This notebook will guide you through fine-tuning a Stable Diffusion model with [DreamBooth](https://dreambooth.github.io/).

Unlike similar notebooks, this one focuses on securely fetching the training dataset from a Storj bucket and securely uploading the trained checkpoints back to the bucket for future use.

## Hardware Check

Check the GPU model and how much VRAM is available.

In [None]:
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader
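
The command above prints one comma-separated line per GPU. As a minimal sketch, assuming the usual `name, total MiB, free MiB` layout, the line can be parsed in Python so the free VRAM can be compared against the memory figures later in this notebook. The helper name `parse_gpu_line` and the sample line are hypothetical, not part of the original notebook.

```python
def parse_gpu_line(line):
    """Parse one 'name, memory.total, memory.free' CSV line into a dict (sizes in MiB)."""
    name, total, free = [field.strip() for field in line.split(",")]
    return {
        "name": name,
        # Values look like "15360 MiB"; keep only the number.
        "total_mib": int(total.split()[0]),
        "free_mib": int(free.split()[0]),
    }

# Illustrative line only, not measured output.
sample = "Tesla T4, 15360 MiB, 15101 MiB"
gpu = parse_gpu_line(sample)
print(f"{gpu['name']}: {gpu['free_mib'] / 1024:.2f} GiB free of {gpu['total_mib'] / 1024:.2f} GiB")
```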

## Install dependencies

In [None]:
%pip install git+https://github.com/ShivamShrirao/diffusers
%pip install -U --pre triton
%pip install accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort safetensors xformers matplotlib

## Storj Settings

Configure the [access grant](https://docs.storj.io/dcs/concepts/access/access-grants) to the Storj bucket that contains the training dataset and where trained checkpoints will be uploaded.

In [None]:
######### SET THE VALUE OF THE VARIABLE BELOW #########

STORJ_ACCESS = "" # Leave empty to use the Storj access grant configured as STORJ_ACCESS env var

######### DO NOT MODIFY THE CODE BELOW #########

import os
if not STORJ_ACCESS:
    STORJ_ACCESS = os.getenv('STORJ_ACCESS')
!uplink access inspect $STORJ_ACCESS --interactive=false > /dev/null && echo "Storj access grant is valid" || echo "Invalid Storj access grant"

In [None]:
######### SET THE VALUE OF THE VARIABLE BELOW #########

STORJ_BUCKET = "" # Leave empty to use the Storj bucket configured as STORJ_BUCKET env var

######### DO NOT MODIFY THE CODE BELOW #########

if not STORJ_BUCKET:
    STORJ_BUCKET = os.getenv('STORJ_BUCKET')
!uplink ls --access $STORJ_ACCESS --interactive=false sj://$STORJ_BUCKET > /dev/null && echo "Will use this Storj bucket: $STORJ_BUCKET" || echo "Inaccessible Storj bucket: $STORJ_BUCKET"
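
The `uplink` commands in this notebook address objects as `sj://<bucket>/<path>`. The sketch below shows one way to build such a URL without doubled or missing slashes. The helper `storj_url` and the bucket name are hypothetical and only illustrate the convention.

```python
def storj_url(bucket, path=""):
    """Join a bucket name and an object path into an sj:// URL."""
    bucket = bucket.strip("/")
    path = path.lstrip("/")  # paths are relative to the bucket root
    return f"sj://{bucket}/{path}" if path else f"sj://{bucket}"

print(storj_url("my-bucket"))
print(storj_url("my-bucket", "/path/to/the/dataset/"))
```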

## Choose the rare unique identifier and class name for the training

In [None]:
######### SET THE VALUE OF THE VARIABLES BELOW #########

UNIQUE_IDENTIFIER = "clgr" # Should be a rare string of characters, not a real word
CLASS_NAME = "cat" # The subject class the unique identifier relates to

## Download training dataset

Set the path to the input images relative to the root of the Storj bucket.

In [None]:
######### SET THE VALUE OF THE VARIABLE BELOW #########

STORJ_DATASET_PATH = "path/to/the/dataset/"

######### DO NOT MODIFY THE CODE BELOW #########

dataset_dir = f"/workspace/data/{UNIQUE_IDENTIFIER}"
!mkdir -p $dataset_dir

if STORJ_DATASET_PATH:
    !echo "Downloading dataset from Storj"
    !uplink cp --access $STORJ_ACCESS --recursive sj://$STORJ_BUCKET/$STORJ_DATASET_PATH $dataset_dir

# Generate a grid of preview images from the dataset.
import os
from PIL import Image
import matplotlib.pyplot as plt

def display_image_grid(folder_path, columns):
    image_files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]
    num_images = len(image_files)
    rows = max(1, (num_images + columns - 1) // columns)  # ceiling division, avoids an extra empty row
    images = []
    
    for i in range(num_images):
        image_path = os.path.join(folder_path, image_files[i])
        image = Image.open(image_path)
        images.append(image)
    
    fig = plt.figure(figsize=(16, 16))
    
    for i in range(num_images):
        plt.subplot(rows, columns, i+1)
        plt.imshow(images[i])
        plt.axis('off')

    plt.tight_layout()
    plt.show()

display_image_grid(dataset_dir, 5)
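
The preview grid opens every file in the dataset folder, so a stray non-image file would make `Image.open` fail. A minimal sketch of two helpers that could guard against this: one filters file names by extension and one computes the grid shape. The helper names and the extension set are assumptions, not part of the original notebook.

```python
import math
import os

# Assumed set of extensions; adjust to match the dataset.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}

def list_image_files(names):
    """Keep only names with a known image extension, in sorted order."""
    return sorted(n for n in names if os.path.splitext(n)[1].lower() in IMAGE_EXTENSIONS)

def grid_shape(num_images, columns):
    """Return (rows, columns) with just enough rows to hold all images."""
    rows = max(1, math.ceil(num_images / columns))
    return rows, columns

files = list_image_files(["b.PNG", "notes.txt", "a.jpg"])
print(files, grid_shape(len(files), 5))
```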

## Choose the base model for the training

If unsure, leave it as it is.

In [None]:
######### SET THE VALUE OF THE VARIABLE BELOW #########

MODEL_NAME = "runwayml/stable-diffusion-v1-5"

######### DO NOT MODIFY THE CODE BELOW #########

OUTPUT_DIR = f"/workspace/stable-diffusion-weights/{UNIQUE_IDENTIFIER}"

print(f"[*] Weights will be saved at {OUTPUT_DIR}")

!mkdir -p $OUTPUT_DIR

In [None]:
######### DO NOT MODIFY THE CODE BELOW #########

concepts_list = [
    {
        "instance_prompt":      f"photo of {UNIQUE_IDENTIFIER} {CLASS_NAME}",
        "class_prompt":         f"photo of a {CLASS_NAME}",
        "instance_data_dir":    dataset_dir,
        "class_data_dir":       f"/workspace/data/{CLASS_NAME}"
    },
]

# `class_data_dir` contains regularization images
import json
import os
for c in concepts_list:
    os.makedirs(c["instance_data_dir"], exist_ok=True)

with open("dreambooth/concepts_list.json", "w") as f:
    json.dump(concepts_list, f, indent=4)

## Start Training

Use the table below to choose the best flags based on your memory and speed requirements. Tested on Tesla T4 GPU.


| `mixed_precision` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | VRAM usage (GB) | Speed (it/s) |
| ----------------- | ------------------ | ----------------------------- | ------------------------ | --------------- | --------------- | ------------ |
| fp16              | 1                  | 1                             | TRUE                     | TRUE            | 9.92            | 0.93         |
| no                | 1                  | 1                             | TRUE                     | TRUE            | 10.08           | 0.42         |
| fp16              | 2                  | 1                             | TRUE                     | TRUE            | 10.4            | 0.66         |
| fp16              | 1                  | 1                             | FALSE                    | TRUE            | 11.17           | 1.14         |
| no                | 1                  | 1                             | FALSE                    | TRUE            | 11.17           | 0.49         |
| fp16              | 1                  | 2                             | TRUE                     | TRUE            | 11.56           | 1            |
| fp16              | 2                  | 1                             | FALSE                    | TRUE            | 13.67           | 0.82         |
| fp16              | 1                  | 2                             | FALSE                    | TRUE            | 13.7            | 0.83         |
| fp16              | 1                  | 1                             | TRUE                     | FALSE           | 15.79           | 0.77         |


Add the `--gradient_checkpointing` flag to bring VRAM usage down to around 9.92 GB.

Remove the `--use_8bit_adam` flag to use the full-precision optimizer. This requires 15.79 GB with `--gradient_checkpointing`, or 17.8 GB without it.

Remove the `--train_text_encoder` flag to reduce memory usage further, at the cost of lower output quality.

If unsure, leave everything as is.
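
As a rough sketch of how the table can be used, the snippet below encodes its rows and picks the fastest configuration that fits a given amount of free VRAM. The `Config` class and `fastest_config` helper are hypothetical. The numbers are copied from the table and apply to the Tesla T4 only. Note that it/s is not directly comparable across different batch sizes, since a larger batch processes more images per iteration.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Config:
    mixed_precision: str
    train_batch_size: int
    gradient_accumulation_steps: int
    gradient_checkpointing: bool
    use_8bit_adam: bool
    vram_gb: float
    speed_it_s: float

# Rows copied from the table above.
CONFIGS = [
    Config("fp16", 1, 1, True, True, 9.92, 0.93),
    Config("no", 1, 1, True, True, 10.08, 0.42),
    Config("fp16", 2, 1, True, True, 10.4, 0.66),
    Config("fp16", 1, 1, False, True, 11.17, 1.14),
    Config("no", 1, 1, False, True, 11.17, 0.49),
    Config("fp16", 1, 2, True, True, 11.56, 1.0),
    Config("fp16", 2, 1, False, True, 13.67, 0.82),
    Config("fp16", 1, 2, False, True, 13.7, 0.83),
    Config("fp16", 1, 1, True, False, 15.79, 0.77),
]

def fastest_config(free_vram_gb):
    """Return the highest it/s configuration that fits, or None if nothing fits."""
    candidates = [c for c in CONFIGS if c.vram_gb <= free_vram_gb]
    return max(candidates, key=lambda c: c.speed_it_s) if candidates else None

print(fastest_config(12.0))
```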

In [None]:
!python3 dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --revision="fp16" \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=50 \
  --sample_batch_size=4 \
  --max_train_steps=800 \
  --save_interval=10000 \
  --save_sample_prompt="photo of {UNIQUE_IDENTIFIER} {CLASS_NAME}" \
  --concepts_list="dreambooth/concepts_list.json"

# Set `--save_interval` lower than `--max_train_steps` to save weights from intermediate steps.
# `--save_sample_prompt` can be the same as `--instance_prompt` to generate intermediate samples (saved along with the weights in the samples directory).

In [None]:
######### DO NOT MODIFY THE CODE BELOW #########

WEIGHTS_DIR = ""
if WEIGHTS_DIR == "":
    from natsort import natsorted
    from glob import glob
    import os
    WEIGHTS_DIR = natsorted(glob(OUTPUT_DIR + os.sep + "*"))[-1]
print(f"[*] WEIGHTS_DIR={WEIGHTS_DIR}")
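
The cell above picks the last entry of the output directory in natural sort order. Assuming the subdirectories are named after the training step at which they were saved (the preview cell in this notebook parses folder names with `int`), a plain lexicographic sort would pick the wrong one once step counts differ in length. The step names below are illustrative.

```python
# Hypothetical step directories; names are illustrative.
names = ["200", "1000", "800"]

lexicographic_last = sorted(names)[-1]       # compares character by character
numeric_last = sorted(names, key=int)[-1]    # compares by step count

print(f"lexicographic: {lexicographic_last}, numeric: {numeric_last}")
```

`natsorted` behaves like the numeric variant for purely numeric names, which is why the latest checkpoint is selected.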

## Generate a grid of preview images from the last saved weights

In [None]:
######### DO NOT MODIFY THE CODE BELOW #########

import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

weights_folder = OUTPUT_DIR
folders = sorted([f for f in os.listdir(weights_folder) if f != "0"], key=lambda x: int(x))

row = len(folders)
col = len(os.listdir(os.path.join(weights_folder, folders[0], "samples")))
scale = 4
fig, axes = plt.subplots(row, col, figsize=(col*scale, row*scale), gridspec_kw={'hspace': 0, 'wspace': 0}, squeeze=False)

for i, folder in enumerate(folders):
    folder_path = os.path.join(weights_folder, folder)
    image_folder = os.path.join(folder_path, "samples")
    images = [f for f in os.listdir(image_folder)]
    for j, image in enumerate(images):
        # squeeze=False keeps `axes` two-dimensional even for a single row or column
        currAxes = axes[i, j]
        if i == 0:
            currAxes.set_title(f"Image {j}")
        if j == 0:
            currAxes.text(-0.1, 0.5, folder, rotation=0, va='center', ha='center', transform=currAxes.transAxes)
        image_path = os.path.join(image_folder, image)
        img = mpimg.imread(image_path)
        currAxes.imshow(img, cmap='gray')
        currAxes.axis('off')
        
plt.tight_layout()
plt.savefig('dreambooth/grid.png', dpi=72)

## Convert weights to checkpoints to use in Stable Diffusion Web UIs

In [None]:
######### DO NOT MODIFY THE CODE BELOW #########

steps = os.path.basename(WEIGHTS_DIR)
ckpt_path = f"/workspace/stable-diffusion-webui/models/Stable-diffusion/sd15-{UNIQUE_IDENTIFIER}-{CLASS_NAME}-{steps}.ckpt"

half_arg = ""
# Whether to convert to fp16, which halves the checkpoint size (to about 2 GB).
fp16 = True
if fp16:
    half_arg = "--half"
!python dreambooth/convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR  --checkpoint_path $ckpt_path $half_arg
print(f"[*] Converted ckpt saved at {ckpt_path}")

## Test inference

In [None]:
######### DO NOT MODIFY THE CODE BELOW #########

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
from IPython.display import display

model_path = WEIGHTS_DIR

pipe = StableDiffusionPipeline.from_pretrained(model_path, safety_checker=None, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
g_cuda = None

# Can set random seed here for reproducibility.
g_cuda = torch.Generator(device='cuda')
seed = 52362
g_cuda.manual_seed(seed)

# Generate the images.

prompt = f"photo of {UNIQUE_IDENTIFIER} {CLASS_NAME} in a bucket"
negative_prompt = ""
num_samples = 4
guidance_scale = 7.5
num_inference_steps = 24
height = 512
width = 512

with autocast("cuda"), torch.inference_mode():
    images = pipe(
        prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_samples,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=g_cuda
    ).images

for img in images:
    display(img)

## Refresh model list in Stable Diffusion Web UI

In [None]:
!curl https://$RUNPOD_POD_ID-7860.proxy.runpod.net/sdapi/v1/refresh-checkpoints  -d ''

## Upload trained model to Storj

Set the path to upload the trained checkpoint relative to the root of the Storj bucket.

In [None]:
######### SET THE VALUE OF THE VARIABLE BELOW #########

STORJ_MODEL_PATH = "" # Leave empty to use the model path configured as STORJ_MODEL_PATH env var

######### DO NOT MODIFY THE CODE BELOW #########

import os
if not STORJ_MODEL_PATH:
    STORJ_MODEL_PATH = os.getenv('STORJ_MODEL_PATH')

if STORJ_MODEL_PATH:
    !uplink cp --access $STORJ_ACCESS --parallelism 16 $ckpt_path sj://$STORJ_BUCKET/$STORJ_MODEL_PATH
else:
    print("[*] Path not set, won't upload the checkpoint.")
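
The upload cell interpolates Python variables into a shell command, which can break if a path contains spaces or shell metacharacters. As a hedged alternative, the sketch below builds the same `uplink cp` invocation as an argument list that could be passed to `subprocess.run`. It uses only the flags that appear in the cell above. The helper name and all example values are hypothetical.

```python
import subprocess

def build_upload_command(access, ckpt_path, bucket, model_path, parallelism=16):
    """Return the argument list for `uplink cp`, mirroring the flags used in the upload cell."""
    destination = f"sj://{bucket}/{model_path}"
    return [
        "uplink", "cp",
        "--access", access,
        "--parallelism", str(parallelism),
        ckpt_path,
        destination,
    ]

# Placeholder values for illustration only.
cmd = build_upload_command(
    "<access-grant>",
    "/workspace/model with spaces.ckpt",
    "my-bucket",
    "models/model.ckpt",
)
print(cmd[-2:])  # source and destination; the access grant is not printed

# To actually run the upload (requires the uplink CLI and a valid grant):
# subprocess.run(cmd, check=True)
```

In the notebook, the arguments would be `STORJ_ACCESS`, `ckpt_path`, `STORJ_BUCKET`, and `STORJ_MODEL_PATH`. Passing a list avoids shell quoting entirely, because each element reaches `uplink` as a single argument.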