# Fine-Tune Stable Diffusion for Parameter-Based Pokémon Generation

This project is based on the Stable Diffusion fine-tuning example created by Lambda Labs. See the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples) for more details.

Training this model will likely require at least one GPU with 32 GB or more of VRAM.

In [None]:
!git clone https://github.com/justinpinkney/stable-diffusion.git # clone the Stable Diffusion fork that contains the fine-tuning code
%cd stable-diffusion
!pip install --upgrade pip
!pip install -r requirements.txt

!pip install --upgrade keras

In [None]:
!nvidia-smi # confirm that your instance has at least one visible GPU
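`nvidia-smi` lists the GPUs, but it does not tell you directly whether each one meets the roughly 32 GB of VRAM suggested at the top of this notebook. The sketch below is a hypothetical helper, not part of the fine-tuning repo: it queries `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits` and keeps only GPUs above a threshold. The `MIN_VRAM_MIB` value is an assumption, set slightly under 32 * 1024 MiB because the reported total can fall a little short of a card's nominal capacity.

```python
import subprocess

# Hypothetical threshold (an assumption): a little under 32 * 1024 MiB, since
# nvidia-smi may report slightly less total memory than the nominal capacity.
MIN_VRAM_MIB = 31_000


def parse_gpu_memory(csv_text):
    """Parse `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits` output.

    Returns a list of (gpu_name, total_memory_mib) tuples.
    """
    gpus = []
    for line in csv_text.strip().splitlines():
        if not line.strip():
            continue
        # Split on the last comma so a GPU name containing commas stays intact.
        name, mem = line.rsplit(",", 1)
        gpus.append((name.strip(), int(mem.strip())))
    return gpus


def gpus_with_enough_vram(gpus, min_mib=MIN_VRAM_MIB):
    """Return only the GPUs whose total memory is at least `min_mib` MiB."""
    return [(name, mib) for name, mib in gpus if mib >= min_mib]


def query_gpus():
    """Run nvidia-smi and parse its output (requires an NVIDIA driver)."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_gpu_memory(out)
```

In the notebook, `gpus_with_enough_vram(query_gpus())` returns the GPUs that pass the check; an empty list suggests the default training settings may run out of memory.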

In [None]:
# Load the Pokémon dataset from the Hugging Face Hub
from datasets import load_dataset
ds = load_dataset("pranavs28/pokemon-types", split="train")

In [None]:
# Log in to Hugging Face to get access to the Stable Diffusion base checkpoint
!pip install huggingface_hub
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# Download the Stable Diffusion v1.4 base checkpoint
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4-full-ema.ckpt", use_auth_token=True)

In [None]:
# Set training settings to match your GPU setup.
BATCH_SIZE = 4
N_GPUS = 2
ACCUMULATE_BATCHES = 1

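# The trailing comma makes PyTorch Lightning parse the string as a list of GPU indices, even for a single GPU
gpu_list = ",".join(str(x) for x in range(N_GPUS)) + ","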
print(f"Using GPUs: {gpu_list}")
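With data-parallel training, each optimizer step sees `BATCH_SIZE * N_GPUS * ACCUMULATE_BATCHES` images. The sketch below computes that product and the learning rate it implies. It assumes the fork's `main.py` keeps the original CompVis logic: with `--scale_lr` enabled, the config's base learning rate is multiplied by the effective batch size; with `--scale_lr False`, as in the training cell, it is used unchanged. The base learning rate in the usage example is a placeholder; read the real value from `configs/stable-diffusion/pokemon.yaml`.

```python
def effective_batch_size(batch_size, n_gpus, accumulate_batches):
    """Number of images contributing to each optimizer step under data-parallel training."""
    return batch_size * n_gpus * accumulate_batches


def learning_rate(base_lr, batch_size, n_gpus, accumulate_batches, scale_lr):
    """Learning rate used for training.

    Assumption: mirrors CompVis main.py, which multiplies the base LR by the
    effective batch size only when scale_lr is enabled.
    """
    if scale_lr:
        return base_lr * effective_batch_size(batch_size, n_gpus, accumulate_batches)
    return base_lr


# Same values as the settings cell; 1e-4 is a placeholder base learning rate.
print("Effective batch size:", effective_batch_size(4, 2, 1))
print("Learning rate with --scale_lr False:", learning_rate(1e-4, 4, 2, 1, scale_lr=False))
```

With 4 images per GPU, 2 GPUs, and no accumulation, the effective batch size is 8. Raising `ACCUMULATE_BATCHES` is one way to keep that number when fewer GPUs are available, at the cost of slower steps.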

In [None]:
# Run training
!(python main.py \
    -t \
    --base configs/stable-diffusion/pokemon.yaml \
    --gpus "$gpu_list" \
    --scale_lr False \
    --num_nodes 1 \
    --check_val_every_n_epoch 10 \
    --finetune_from "$ckpt_path" \
    data.params.batch_size="$BATCH_SIZE" \
    lightning.trainer.accumulate_grad_batches="$ACCUMULATE_BATCHES" \
    data.params.validation.params.n_gpus="$N_GPUS" \
)
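The trailing `key.path=value` arguments in the training command override entries in the YAML config. In the CompVis `main.py` that this fork builds on, unrecognized command-line arguments are parsed with OmegaConf's `from_dotlist` and merged over the base config. The pure-Python sketch below only illustrates that idea on plain nested dicts; it is a simplification, not the repo's code, and because OmegaConf parses values as YAML, results can differ for inputs such as `true` or `null`.

```python
import ast
import copy


def parse_value(text):
    """Convert an override value to a Python literal when possible (e.g. "4" -> 4)."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Not a Python literal (e.g. a file path), so keep it as a string.
        return text


def apply_dotlist(config, overrides):
    """Return a copy of nested dict `config` with "a.b.c=value" overrides applied."""
    merged = copy.deepcopy(config)
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = merged
        # Walk (and create, if missing) the intermediate dicts along the dotted path.
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw)
    return merged
```

For example, applying `["data.params.batch_size=4"]` to a config whose `data.params.batch_size` is 1 returns a copy with the value 4 and leaves the original dict untouched.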

In [None]:
# Generate samples with the fine-tuned model. For best results, phrase the prompt as "[type] type pokemon" or "[type1] and [type2] type pokemon", and set --ckpt to the checkpoint produced by training.
!(python scripts/txt2img.py \
    --prompt 'fire type pokemon' \
    --outdir 'outputs/generated_pokemon' \
    --H 512 --W 512 \
    --n_samples 4 \
    --config 'configs/stable-diffusion/pokemon.yaml' \
    --ckpt 'path/to/your/checkpoint')
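The recommended prompt format can also be built programmatically, which helps when sampling many type combinations. `type_prompt` is a hypothetical helper; lowercasing the type names is an assumption that matches the example prompt `'fire type pokemon'`, not a documented requirement of the model.

```python
def type_prompt(types):
    """Build a "[type] type pokemon" or "[type1] and [type2] type pokemon" prompt."""
    # Assumption: lowercase names, matching the example prompt 'fire type pokemon'.
    names = [t.strip().lower() for t in types if t and t.strip()]
    if len(names) == 1:
        return f"{names[0]} type pokemon"
    if len(names) == 2:
        return f"{names[0]} and {names[1]} type pokemon"
    # A Pokémon has one or two types, so anything else is rejected.
    raise ValueError(f"Expected one or two Pokémon types, got {len(names)}")
```

Because IPython expands `$variables` in shell commands (as the training cell does with `$gpu_list`), you can set `prompt = type_prompt(["fire", "flying"])` and pass `--prompt "$prompt"` to `scripts/txt2img.py`.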