<a href="https://colab.research.google.com/github/wongus/GAN-experimentation/blob/main/GAN_custom_subjects.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Upload multiple images and train the model (takes ~25 min.)
Run this cell once to set up and train the model. When prompted, upload the subject photos and then enter your key (your Hugging Face access token).

In [None]:
#@title

import os
from google.colab import files
import shutil
from IPython.display import clear_output
import json
import ipywidgets as widgets
from natsort import natsorted
from glob import glob
import torch
from torch import autocast
from IPython.display import display

# One DreamBooth concept: the placeholder token "T*" stands for the subject in the
# uploaded photos, and "person" is the class prompt used for prior preservation
# (see --with_prior_preservation in the training command).
concepts_list = [
    dict(
        instance_prompt="T*",
        class_prompt="person",
        instance_data_dir="/content/data/T*",
        class_data_dir="/content/data/person",
    )
]

# Only the instance-image directories are created here; the class directory is not
# (presumably the training script creates it when it generates class images).
for concept in concepts_list:
    os.makedirs(concept["instance_data_dir"], exist_ok=True)

# Handed to the training script via --concepts_list="concepts_list.json".
with open("concepts_list.json", "w") as f:
    json.dump(concepts_list, f, indent=4)

# For every concept, ask the user to upload the subject photos. files.upload() saves
# them in the working directory, so each one is then moved into the concept's
# instance_data_dir, where training reads it from.
for c in concepts_list:
    print(f"Uploading instance images for `{c['instance_prompt']}`")
    uploaded = files.upload()
    if not uploaded:
        # Fail fast: otherwise the long training run below would start with no
        # instance images for this concept.
        raise RuntimeError(
            f"No images were uploaded for `{c['instance_prompt']}`; "
            "re-run this cell and select at least one image."
        )
    for filename in uploaded:
        dst_path = os.path.join(c['instance_data_dir'], filename)
        shutil.move(filename, dst_path)

clear_output(wait=True)

from getpass import getpass
print('Key:')
key = getpass()
clear_output(wait=True)

!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py
%pip install -qq git+https://github.com/ShivamShrirao/diffusers --quiet
%pip install -q -U --pre triton --quiet
%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort --quiet
%pip install --no-deps -q https://github.com/brian6091/xformers-wheels/releases/download/0.0.15.dev0%2B4c06c79/xformers-0.0.15.dev0+4c06c79.d20221205-cp38-cp38-linux_x86_64.whl --quiet

from diffusers import StableDiffusionPipeline, DDIMScheduler

!mkdir -p ~/.huggingface
!echo -n f'{key}' > ~/.huggingface/token

clear_output(wait=True)

# NOTE(review): not referenced anywhere else in this notebook — presumably a
# leftover toggle; kept so any external use keeps working.
save_to_gdrive = False

# Base checkpoint to fine-tune, and the directory the training script writes its
# weights to (the cell below loads the newest sub-directory of it).
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/content/stable_diffusion_weights/T*"

# Created from Python: `!mkdir -p $OUTPUT_DIR` hands the unquoted path to the
# shell, where the "*" in "T*" is a glob wildcard that may expand to other
# directories instead of creating the literal "T*" directory.
os.makedirs(OUTPUT_DIR, exist_ok=True)

out = widgets.Output(
  layout=widgets.Layout(height='200px', overflow='scroll')
)

display(out)

with out:
  !accelerate launch train_dreambooth.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
    --output_dir=$OUTPUT_DIR \
    --revision="fp16" \
    --with_prior_preservation --prior_loss_weight=1.0 \
    --seed=1337 \
    --resolution=512 \
    --train_batch_size=1 \
    --train_text_encoder \
    --mixed_precision="fp16" \
    --use_8bit_adam \
    --gradient_accumulation_steps=1 \
    --learning_rate=1e-6 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --num_class_images=50 \
    --sample_batch_size=4 \
    --max_train_steps=800 \
    --save_interval=10000 \
    --save_sample_prompt="T*" \
    --concepts_list="concepts_list.json"

  clear_output(wait=True)

# Pick the newest checkpoint directory the training script saved under OUTPUT_DIR
# (natural sort, so e.g. "800" sorts after "0"). OUTPUT_DIR contains a literal "*"
# ("T*"), so it is glob-escaped before the wildcard is appended; unescaped it would
# also match any other directory starting with "T". Only directories are considered.
from glob import escape as glob_escape

checkpoint_dirs = natsorted(
    path
    for path in glob(os.path.join(glob_escape(OUTPUT_DIR), "*"))
    if os.path.isdir(path)
)
if not checkpoint_dirs:
    # Training did not produce any weights; fail with a readable message instead
    # of an IndexError.
    raise FileNotFoundError(
        f"No trained weights found in {OUTPUT_DIR!r}; check the training output for errors."
    )
WEIGHTS_DIR = checkpoint_dirs[-1]

model_path = WEIGHTS_DIR

# DDIM sampler with a scaled-linear beta schedule (presumably matching the
# Stable Diffusion v1.x training schedule — confirm against the base model config).
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
# Loaded in fp16 on the GPU, without a safety checker.
pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16).to("cuda")

# Placeholder; the image-generation cell creates and seeds its own generator.
g_cuda = None

clear_output(wait=True)
print('Done, proceed to image generation.')

## Image generation (use `T*` in the prompt to refer to the subject of the uploaded photos)
This cell can be run multiple times with the same or different prompts. Results should vary each time.

In [None]:
# Generate `num_samples` images of the fine-tuned subject from `prompt`.
# seed < 0 (the default) draws a fresh random seed on every run, so results vary
# between runs; the seed used is printed, and setting `seed` to that (or any other
# value >= 0) reproduces a result exactly.
g_cuda = torch.Generator(device='cuda')

prompt = "high definition picture of T*" #@param {type:"string"}
negative_prompt = "" #@param {type:"string"}
num_samples = 5 #@param {type:"number"}
guidance_scale = 7.5 #@param {type:"number"}
num_inference_steps = 50 #@param {type:"number"}
seed = -1 #@param {type:"number"}
height = 512 #@param {type:"number"}
width = 512 #@param {type:"number"}

if seed < 0:
    # Generator.seed() picks a non-deterministic seed, applies it and returns it.
    used_seed = g_cuda.seed()
else:
    used_seed = int(seed)
    g_cuda.manual_seed(used_seed)
print(f"Seed: {used_seed}")

# Colab "number" fields accept non-integers; the pipeline needs ints for these.
with autocast("cuda"), torch.inference_mode():
    images = pipe(
        prompt,
        height=int(height),
        width=int(width),
        negative_prompt=negative_prompt,
        num_images_per_prompt=int(num_samples),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=g_cuda
    ).images

for img in images:
    display(img)