## Step 1: Download the base model and sample some images from it

Install the Hugging Face Python packages

In [None]:
%%capture
!pip install git+https://github.com/huggingface/diffusers.git
!pip install accelerate
!pip install datasets
!pip install bitsandbytes

Set an environment variable for the base model to be fine-tuned

In [None]:
%env MODEL_NAME=stabilityai/stable-diffusion-2-1
# Smaller alternative that fits in Colab's 15 GB of GPU memory, at lower image quality:
# %env MODEL_NAME=OFA-Sys/small-stable-diffusion-v0
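
Later cells read this variable with `os.getenv('MODEL_NAME')`, which returns `None` if the cell above was skipped. As a minimal sketch (the helper name `resolve_model_name` and the fallback constant are illustrative assumptions, not part of the original notebook), the lookup can be made more defensive:

```python
import os

# Hypothetical fallback; the notebook itself sets MODEL_NAME via %env.
DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1"

def resolve_model_name(env=os.environ, default=DEFAULT_MODEL):
    """Return the model id from MODEL_NAME, or the default if unset or blank."""
    name = env.get("MODEL_NAME", "").strip()
    return name or default

print(resolve_model_name())
```

Passing the result to `from_pretrained` instead of the raw `os.getenv(...)` value avoids a confusing error when the variable is missing.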

Define a simple helper that plots a list of images returned by the model

In [None]:
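from matplotlib import pyplot as plt

def plot_images(images):
    """Show the given images side by side in a single row."""
    # squeeze=False keeps axarr a 2-D array, so a single image also works
    fig, axarr = plt.subplots(1, len(images), figsize=(20, 10), squeeze=False)
    for ax, img in zip(axarr.flatten(), images):
        ax.imshow(img)
        ax.axis('off')
    plt.show()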

Import the Python libraries needed to load and sample from the Stable Diffusion model

In [None]:
import os
import torch

from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline

In this step we initialise the model and move it to the GPU (on Google Colab, select a GPU runtime before running this cell). This step also triggers the download of the model. The model is a few GB, so it may take a while, but it will be faster than you expect. 😏

In [None]:
pipe = StableDiffusionPipeline.from_pretrained(os.getenv('MODEL_NAME'), torch_dtype=torch.float16)
pipe = pipe.to("cuda")
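
The `.to("cuda")` call fails on a runtime without a GPU. As a hedged sketch (the helper `pick_device_and_dtype` is a hypothetical name, not part of the original notebook), the device and precision can be chosen from what the runtime actually offers:

```python
def pick_device_and_dtype():
    """Return a (device, dtype) pair suited to the current runtime."""
    try:
        import torch
    except ImportError:
        # torch is not installed, so no model can be loaded anyway
        return "cpu", None
    if torch.cuda.is_available():
        return "cuda", torch.float16   # half precision, as in the cell above
    return "cpu", torch.float32        # float16 is poorly supported on CPU

device, dtype = pick_device_and_dtype()
print(device, dtype)
# Possible use with the pipeline from the cell above:
# pipe = StableDiffusionPipeline.from_pretrained(
#     os.getenv('MODEL_NAME'), torch_dtype=dtype).to(device)
```

If this prints `cpu`, the rest of the notebook will be impractically slow, so it is worth switching the Colab runtime to a GPU first.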

We set the text prompt for generating the images and run the actual generation.

In [None]:
prompt = "isometric view of a skyscraper in the style of a city building game"
images = pipe(prompt, num_images_per_prompt=6).images
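
Generating six images in one call keeps all of them in GPU memory at once. If that runs out of memory, one option is to split the request into smaller calls. The sketch below only computes the split; `chunk_sizes` is a hypothetical helper, and the loop using `pipe` is shown as a comment because it assumes the pipeline from the cells above:

```python
def chunk_sizes(total, per_call):
    """Split `total` images into calls of at most `per_call` images each."""
    if total < 0 or per_call < 1:
        raise ValueError("total must be >= 0 and per_call must be >= 1")
    full, rest = divmod(total, per_call)
    return [per_call] * full + ([rest] if rest else [])

print(chunk_sizes(6, 2))  # [2, 2, 2]
print(chunk_sizes(6, 4))  # [4, 2]

# Possible use with the pipeline defined above:
# images = []
# for n in chunk_sizes(6, 2):
#     images.extend(pipe(prompt, num_images_per_prompt=n).images)
```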

Let's plot the images from the base model with our function:

In [None]:
plot_images(images)

**IMPORTANT:** We need to free the GPU memory before we can start the actual training. We delete the Python variables, run the garbage collector, and finally use torch to empty the CUDA memory cache.

In [None]:
# Flush the GPU memory to be able to run the training
del pipe
del images

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()
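
Deleting a variable only frees the underlying object once no other reference to it remains. The following sketch illustrates this with a stand-in class (`FakePipeline` is hypothetical and holds no GPU memory); it uses a weak reference to observe when the object actually disappears:

```python
import gc
import weakref

class FakePipeline:
    """Stand-in for a large model object (hypothetical, for illustration)."""

pipe = FakePipeline()
alias = pipe                # a second reference, e.g. a leftover variable
probe = weakref.ref(pipe)   # observes the object without keeping it alive

del pipe
gc.collect()
still_alive_with_alias = probe() is not None   # the alias keeps it alive

del alias
gc.collect()
freed_after_all_refs_gone = probe() is None    # now it can be reclaimed

print(still_alive_with_alias, freed_after_all_refs_gone)
```

In the notebook this means that any other variable still pointing at the pipeline would keep its weights on the GPU, and `torch.cuda.empty_cache()` would have nothing to release.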

## Step 2: Fine-tune the model

Let's clone the dataset and the Hugging Face diffusers repository, which contains the fine-tuning script

In [None]:
!git clone https://huggingface.co/datasets/Loulblemo/diffusion_skyscrapers_city_building
!git clone https://github.com/huggingface/diffusers.git

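Set some variables for the fine-tuning script: the dataset to use and the number of training steps. Since the dataset we chose is small and we want to fine-tune quickly on the Colab free tier, let's train for only 50 steps. Note that the variable is called `max_training_epochs`, but it is passed to `--max_train_steps`, so it counts optimizer steps rather than epochs.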

In [None]:
%%capture
%env dataset_name=diffusion_skyscrapers_city_building
# No need to train the model for long to see meaningful results.
%env max_training_epochs = 50
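
To get a feel for how much data 50 steps cover, the step count can be combined with the batch settings passed to the training script (`--train_batch_size=1` and `--gradient_accumulation_steps=4`). A small sketch of the arithmetic, where `samples_seen` is a hypothetical helper and a single GPU process is assumed:

```python
def samples_seen(max_train_steps, train_batch_size,
                 gradient_accumulation_steps, num_processes=1):
    """Return (effective batch size, total training samples processed)."""
    effective_batch = (train_batch_size
                       * gradient_accumulation_steps
                       * num_processes)
    return effective_batch, max_train_steps * effective_batch

effective_batch, total = samples_seen(50, 1, 4)
print(effective_batch, total)  # 4 200
```

Each optimizer step therefore consumes 4 images, and the whole run sees 200 images. How many passes over the dataset that amounts to depends on the dataset size.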

In [None]:
!pip install peft


Now we run the actual fine-tuning script. **IMPORTANT:** If you want to train on Google's free T4 GPU, it is crucial to add the flag

```
--use_8bit_adam
```
It switches to the 8-bit Adam optimizer from bitsandbytes, which reduces the memory needed for the optimizer state. We will save the model in the *city-building-model* folder


In [None]:
!accelerate launch diffusers/examples/text_to_image/train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --use_8bit_adam \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=$max_training_epochs \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="city-building-model"
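
If the run should be launched from Python rather than with `!`, the same command can be assembled as an argument list, which makes it easy to toggle `--use_8bit_adam`. This is only a sketch: `build_train_command` is a hypothetical helper, and it reuses exactly the flags from the cell above:

```python
import shlex

def build_train_command(model_name, dataset_name, max_train_steps,
                        output_dir="city-building-model", use_8bit_adam=True):
    """Assemble the accelerate command as an argument list (illustrative)."""
    cmd = [
        "accelerate", "launch",
        "diffusers/examples/text_to_image/train_text_to_image.py",
        f"--pretrained_model_name_or_path={model_name}",
        f"--dataset_name={dataset_name}",
        "--use_ema",
        "--resolution=512", "--center_crop", "--random_flip",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=4",
        "--gradient_checkpointing",
        "--mixed_precision=fp16",
        f"--max_train_steps={max_train_steps}",
        "--learning_rate=1e-05",
        "--max_grad_norm=1",
        "--lr_scheduler=constant", "--lr_warmup_steps=0",
        f"--output_dir={output_dir}",
    ]
    if use_8bit_adam:
        cmd.append("--use_8bit_adam")   # needed on the free T4 GPU
    return cmd

cmd = build_train_command("stabilityai/stable-diffusion-2-1",
                          "diffusion_skyscrapers_city_building", 50)
print(shlex.join(cmd))
# To run it from Python: subprocess.run(cmd, check=True)
```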


## Step 3: Sample from the finetuned model

We load the new model onto the GPU (this time nothing needs to be downloaded, since the fine-tuned model was saved locally) and generate some more images

In [None]:
pipe = StableDiffusionPipeline.from_pretrained('city-building-model', torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "isometric view of a skyscraper in the style of a city building game"
images = pipe(prompt, num_images_per_prompt=6).images
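
If training was interrupted, the output folder may be missing or incomplete, and loading it then fails. A quick pre-check is sketched below; it relies on the fact that a saved diffusers pipeline contains a `model_index.json` file at its root, and `looks_like_saved_pipeline` is a hypothetical helper name:

```python
from pathlib import Path

def looks_like_saved_pipeline(output_dir):
    """Return True if the folder contains a pipeline index file."""
    return (Path(output_dir) / "model_index.json").is_file()

print(looks_like_saved_pipeline("city-building-model"))
```

If this prints `False`, the training cell most likely did not finish and should be run again before loading the model.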

The new images! Hopefully the quality improved thanks to the fine-tuning. You can play with the number of training steps to see how fine-tuning affects the final output 💪

In [None]:
plot_images(images)
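
To compare runs with different step counts, it can help to keep the generated images on disk. A minimal sketch, assuming each image exposes a `save(path)` method as PIL images do; `save_images` and the output folder name are hypothetical:

```python
from pathlib import Path

def save_images(images, out_dir, prefix="sample"):
    """Save each image as <prefix>_NN.png in out_dir and return the paths."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for i, img in enumerate(images):
        path = out / f"{prefix}_{i:02d}.png"
        img.save(path)   # PIL images accept a path-like target
        paths.append(path)
    return paths

# Possible use with the images generated above:
# save_images(images, "samples-50-steps")
```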