## Hardware Check

Check the GPU model and how much VRAM it has (total and currently free).

In [None]:
# Print the GPU name plus total and free memory as CSV without a header row.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

## Storj Settings

In [None]:
import os
STORJ_ACCESS = "" # Leave empty to use the Storj access grant configured as STORJ_ACCESS env var
if not STORJ_ACCESS:
    STORJ_ACCESS = os.getenv('STORJ_ACCESS')
!uplink access inspect $STORJ_ACCESS > /dev/null && echo "Will use this Storj access grant: $STORJ_ACCESS" || echo "Invalid Storj access grant: $STORJ_ACCESS"

In [None]:
STORJ_BUCKET = "" # Leave empty to use the Storj bucket configured as STORJ_BUCKET env var
if not STORJ_BUCKET:
    STORJ_BUCKET = os.getenv('STORJ_BUCKET')
!uplink ls --access $STORJ_ACCESS sj://$STORJ_BUCKET > /dev/null && echo "Will use this Storj bucket: $STORJ_BUCKET" || echo "Inaccessible Storj bucket: $STORJ_BUCKET"

## Training dataset location

In [None]:
instance_name = "abc"
class_name = "cat"
dataset_dir = f"/workspace/data/{instance_name}"
!mkdir -p $dataset_dir
STORJ_DATASET_PATH = "dataset/512/"

if STORJ_DATASET_PATH:
    !echo "Downloading dataset from Storj"
    !uplink cp --access $STORJ_ACCESS --recursive sj://$STORJ_BUCKET/$STORJ_DATASET_PATH $dataset_dir

#@markdown Run to generate a grid of preview images from the dataset.
import os
from PIL import Image
import matplotlib.pyplot as plt

def display_image_grid(folder_path, columns):
    """Show every file in ``folder_path`` as an image in a grid ``columns`` wide.

    Files are shown in sorted name order so the grid is stable between runs;
    sub-directories are skipped. Every remaining file is opened with PIL, so
    the folder is expected to contain only images.

    Args:
        folder_path: Directory holding the dataset images.
        columns: Number of images per grid row.
    """
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )
    num_images = len(image_files)
    if num_images == 0:
        print(f"No images found in {folder_path}")
        return
    # Ceiling division: just enough rows, without the empty trailing row that
    # `num_images // columns + 1` adds when num_images is a multiple of columns.
    rows = -(-num_images // columns)

    plt.figure(figsize=(16, 16))
    for i, name in enumerate(image_files):
        plt.subplot(rows, columns, i + 1)
        # imshow copies the pixel data, so each file can be closed right away
        # instead of keeping every handle open until garbage collection.
        with Image.open(os.path.join(folder_path, name)) as image:
            plt.imshow(image)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


display_image_grid(dataset_dir, 5)

## Settings and run

In [None]:
#@markdown Name/Path of the initial model.
MODEL_NAME = "runwayml/stable-diffusion-v1-5" #@param {type:"string"}

#@markdown Enter the directory name to save model at.

OUTPUT_DIR = f"/workspace/stable-diffusion-weights/{instance_name}" #@param {type:"string"}

print(f"[*] Weights will be saved at {OUTPUT_DIR}")

!mkdir -p $OUTPUT_DIR

In [None]:
# You can also add multiple concepts here. Try tweaking `--max_train_steps` accordingly.
import json
import os

# Each concept pairs the instance images (the subject being learned) with a
# class directory; `class_data_dir` contains regularization images.
concepts_list = [
    dict(
        instance_prompt=f"photo of {instance_name} {class_name}",
        class_prompt=f"photo of a {class_name}",
        instance_data_dir=dataset_dir,
        class_data_dir=f"/workspace/data/{class_name}",
    ),
]

for concept in concepts_list:
    os.makedirs(concept["instance_data_dir"], exist_ok=True)

# The training command reads this file via --concepts_list.
with open("concepts_list.json", "w") as fh:
    json.dump(concepts_list, fh, indent=4)

In [None]:
# Install the training/inference dependencies into the kernel's environment.
# NOTE(review): only bitsandbytes is pinned; the diffusers fork is installed from
# its default branch and the other packages float to their latest releases.
# TODO pin known-good versions so reruns are reproducible.
%pip install git+https://github.com/ShivamShrirao/diffusers
%pip install -U --pre triton
%pip install accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort safetensors xformers

## Start Training

Use the table below to choose the flags that best fit your memory and speed requirements. All measurements were taken on a Tesla T4 GPU.


| `fp16` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | GB VRAM usage | Speed (it/s) |
| ---- | ------------------ | ----------------------------- | ----------------------- | --------------- | ---------- | ------------ |
| fp16 | 1                  | 1                             | TRUE                    | TRUE            | 9.92       | 0.93         |
| no   | 1                  | 1                             | TRUE                    | TRUE            | 10.08      | 0.42         |
| fp16 | 2                  | 1                             | TRUE                    | TRUE            | 10.4       | 0.66         |
| fp16 | 1                  | 1                             | FALSE                   | TRUE            | 11.17      | 1.14         |
| no   | 1                  | 1                             | FALSE                   | TRUE            | 11.17      | 0.49         |
| fp16 | 1                  | 2                             | TRUE                    | TRUE            | 11.56      | 1            |
| fp16 | 2                  | 1                             | FALSE                   | TRUE            | 13.67      | 0.82         |
| fp16 | 1                  | 2                             | FALSE                   | TRUE            | 13.7       | 0.83          |
| fp16 | 1                  | 1                             | TRUE                    | FALSE           | 15.79      | 0.77         |


Add the `--gradient_checkpointing` flag to bring VRAM usage down to about 9.92 GB.

Remove the `--use_8bit_adam` flag to use a full-precision optimizer. This requires 15.79 GB with `--gradient_checkpointing`, or 17.8 GB without it.

Remove the `--train_text_encoder` flag to reduce memory usage further, at the cost of output quality.

In [None]:
!python3 dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --revision="fp16" \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=50 \
  --sample_batch_size=4 \
  --max_train_steps=800 \
  --save_interval=10000 \
  --save_sample_prompt=f"photo of {instance_name} {class_name}" \
  --concepts_list="concepts_list.json"

# Set `--save_interval` lower than `--max_train_steps` to also save weights at intermediate steps.
# `--save_sample_prompt` can be the same as a concept's `instance_prompt` to generate intermediate samples (saved with the weights in a `samples` directory).

In [None]:
#@markdown Specify the weights directory to use (leave blank for latest)
WEIGHTS_DIR = "" #@param {type:"string"}
if WEIGHTS_DIR == "":
    from natsort import natsorted
    from glob import glob
    import os
    # Pick the highest training-step directory. natsort orders "800" after "90".
    # Plain files and the "0" directory (also skipped by the preview-grid cell)
    # are not weight snapshots, so they are excluded.
    step_dirs = [
        path for path in natsorted(glob(os.path.join(OUTPUT_DIR, "*")))
        if os.path.isdir(path) and os.path.basename(path) != "0"
    ]
    if not step_dirs:
        raise FileNotFoundError(f"No saved weights found in {OUTPUT_DIR}; run the training cell first")
    WEIGHTS_DIR = step_dirs[-1]
print(f"[*] WEIGHTS_DIR={WEIGHTS_DIR}")

In [None]:
#@markdown Run to generate a grid of preview images from the last saved weights.
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# One row per saved training step, one column per sample image; saved as grid.png.
weights_folder = OUTPUT_DIR
# Step directories are named after the training step. "0" is skipped as before;
# non-numeric entries and steps saved without a samples folder are ignored
# rather than crashing int() / os.listdir().
folders = sorted(
    (f for f in os.listdir(weights_folder)
     if f.isdigit() and f != "0"
     and os.path.isdir(os.path.join(weights_folder, f, "samples"))),
    key=int,
)
# Sort sample names so column j shows the same sample index in every row.
samples = {
    folder: sorted(os.listdir(os.path.join(weights_folder, folder, "samples")))
    for folder in folders
}

row = len(folders)
col = max((len(names) for names in samples.values()), default=0)
if row == 0 or col == 0:
    raise FileNotFoundError(f"No sample images found under {weights_folder}")
scale = 4
# squeeze=False keeps `axes` 2-D even when row or col is 1, so axes[i, j] always works.
fig, axes = plt.subplots(row, col, figsize=(col*scale, row*scale),
                         gridspec_kw={'hspace': 0, 'wspace': 0}, squeeze=False)

for ax in axes.flat:
    ax.axis('off')  # also blanks the cells of rows that have fewer samples
for j in range(col):
    axes[0, j].set_title(f"Image {j}")

for i, folder in enumerate(folders):
    image_folder = os.path.join(weights_folder, folder, "samples")
    # Label each row with its training step, just left of the first column.
    axes[i, 0].text(-0.1, 0.5, folder, rotation=0, va='center', ha='center',
                    transform=axes[i, 0].transAxes)
    for j, image in enumerate(samples[folder]):
        img = mpimg.imread(os.path.join(image_folder, image))
        axes[i, j].imshow(img, cmap='gray')

plt.tight_layout()
plt.savefig('grid.png', dpi=72)

## Convert weights to ckpt to use in web UIs like AUTOMATIC1111.

In [None]:
#@markdown Run conversion.
steps = os.path.basename(WEIGHTS_DIR)
ckpt_path = f"/workspace/stable-diffusion-webui/models/Stable-diffusion/sd15-{instance_name}-{class_name}-{steps}.ckpt"

half_arg = ""
#@markdown  Whether to convert to fp16, takes half the space (2GB).
fp16 = True #@param {type: "boolean"}
if fp16:
    half_arg = "--half"
!python dreambooth/convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR  --checkpoint_path $ckpt_path $half_arg
print(f"[*] Converted ckpt saved at {ckpt_path}")

## Inference

In [None]:
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
from IPython.display import display

# Diffusers-format weights to load. Defaults to the snapshot selected above;
# replace with the full path of another weights directory to reuse an earlier model.
model_path = WEIGHTS_DIR

# Load in float16 with the safety checker disabled, then move everything to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")
# Swap in a DDIM scheduler built from the pipeline's existing scheduler config.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
# No seeded generator yet; the seeding cell can replace this.
g_cuda = None

In [None]:
#@markdown Can set random seed here for reproducibility.
seed = 52362 #@param {type:"number"}
# Dedicated CUDA generator that the generation cell passes to the pipeline;
# reseeding it here makes sampling repeatable.
g_cuda = torch.Generator(device='cuda')
g_cuda.manual_seed(seed)

In [None]:
#@title Run for generating images.

prompt = f"photo of {instance_name} {class_name} in a bucket" #@param {type:"string"}
negative_prompt = "" #@param {type:"string"}
num_samples = 4 #@param {type:"number"}
guidance_scale = 7.5 #@param {type:"number"}
num_inference_steps = 24 #@param {type:"number"}
height = 512 #@param {type:"number"}
width = 512 #@param {type:"number"}

# Sampler settings gathered in one place; g_cuda may be None (unseeded) or a seeded generator.
pipe_kwargs = dict(
    height=height,
    width=width,
    negative_prompt=negative_prompt,
    num_images_per_prompt=num_samples,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=g_cuda,
)

# Mixed-precision autocast on CUDA, with autograd disabled for pure inference.
with autocast("cuda"), torch.inference_mode():
    images = pipe(prompt, **pipe_kwargs).images

display(*images)

## Refresh model list in Stable Diffusion Web UI

In [None]:
# Ask the Stable Diffusion Web UI (port 7860 behind the RunPod proxy) to rescan its
# checkpoints so the converted .ckpt shows up. `-d ''` makes curl send an empty POST.
# Assumes RUNPOD_POD_ID is set in the pod's environment (it is not a Python variable,
# so IPython leaves it for the shell to expand).
!curl https://$RUNPOD_POD_ID-7860.proxy.runpod.net/sdapi/v1/refresh-checkpoints  -d ''

## Upload trained model to Storj

In [None]:
STORJ_MODEL_PATH = "" # Leave empty to use the model path configured as STORJ_MODEL_PATH env var
if not STORJ_MODEL_PATH:
    STORJ_MODEL_PATH = os.getenv('STORJ_MODEL_PATH')
!uplink cp --access $STORJ_ACCESS --parallelism 16 $ckpt_path sj://$STORJ_BUCKET/$STORJ_MODEL_PATH

In [None]:
#@title Free runtime memory
# Exit the interpreter, ending this kernel process so the GPU memory held by the
# pipeline is released. Every variable defined above is lost; presumably the
# notebook front-end restarts the kernel afterwards — confirm for your Jupyter setup.
exit()