<a href="https://colab.research.google.com/github/sayakpaul/stable-diffusion-keras-ft/blob/main/notebooks/generate_images_with_finetuned_stable_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Overview

This notebook shows a side-by-side comparison of images generated by the original and the fine-tuned KerasCV Stable Diffusion models. The fine-tuning method we used can be found in this [repository](https://github.com/sayakpaul/stable-diffusion-keras-ft).

## Setup

In [None]:
!pip install git+https://github.com/keras-team/keras-cv -q

In [None]:
import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt

print(keras_cv.__version__)

## Prepare the models

### Download and load fine-tuned weights

We fine-tuned Stable Diffusion on the [Pokemon Dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), which is hosted on Hugging Face Datasets. The fine-tuned model weights are available in this Hugging Face Hub [repository](https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/tree/main).

Note that you can also provide a path to any other fine-tuned weights here. 

In [None]:
weights_path = keras.utils.get_file(
    origin="/models/ckpt_epochs_72_res_512_mp_True.h5"
)

img_height = img_width = 512
pokemon_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height
)
pokemon_model.diffusion_model.load_weights(weights_path)

Although the checkpoint name suggests "72 epochs", note that the model was not trained for the full 72 epochs. This checkpoint was retrieved after 60 epochs of fine-tuning.
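As noted above, any other fine-tuned weights can be loaded in place of this checkpoint. The following is a minimal sketch, assuming the weights sit either at a local path or behind a direct-download URL. `resolve_weights` is a hypothetical helper (not part of KerasCV), and the file is assumed to have been saved from a diffusion model with the same architecture as `pokemon_model.diffusion_model`.

```python
import os


def resolve_weights(path_or_url):
    """Return a local file path for the given weights location.

    Hypothetical helper: http(s) URLs are downloaded with
    `keras.utils.get_file`; local paths are checked and returned unchanged.
    """
    if path_or_url.startswith(("http://", "https://")):
        # Imported lazily so the local-path branch works without TensorFlow.
        from tensorflow import keras

        return keras.utils.get_file(origin=path_or_url)
    if not os.path.isfile(path_or_url):
        raise FileNotFoundError(f"No weights file at {path_or_url!r}")
    return path_or_url


# Example usage ("my_weights.h5" is a placeholder file name):
# pokemon_model.diffusion_model.load_weights(resolve_weights("my_weights.h5"))
```

Checking the path before calling `load_weights` surfaces a missing file early, before any model is built.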

### Load original weights

In [None]:
original_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height
)

## Try out some prompts

In [None]:
def plot_images(images, title):
    """Plot a batch of images in a single row, labeling each one with `title`."""
    plt.figure(figsize=(20, 20))
    for i, image in enumerate(images):
        plt.subplot(1, len(images), i + 1)
        plt.title(title)
        plt.imshow(image)
        plt.axis("off")
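The prompt cells in this section all follow the same pattern: generate a batch with each model, then plot both batches. The following is a minimal sketch of a helper that wraps that pattern. `compare_models` and its arguments are hypothetical names; the helper assumes only that each model exposes `text_to_image(prompt, batch_size=..., **kwargs)` as used in this notebook, and it takes the plotting function as an argument so that it does not depend on matplotlib itself.

```python
def compare_models(prompt, models, plot_fn, batch_size=3):
    """Generate images for `prompt` with every model and plot each batch.

    `models` maps a plot title to a `(model, extra_kwargs)` pair, where
    `extra_kwargs` is forwarded to `model.text_to_image`.
    Returns a dict mapping each title to the generated images.
    """
    results = {}
    for title, (model, extra_kwargs) in models.items():
        images = model.text_to_image(prompt, batch_size=batch_size, **extra_kwargs)
        plot_fn(images, title)
        results[title] = images
    return results


# Example usage with the objects defined in this notebook:
# compare_models(
#     "Yoda",
#     {
#         "original": (original_model, {}),
#         "finetuned to pokemon dataset": (
#             pokemon_model,
#             {"unconditional_guidance_scale": 50},
#         ),
#     },
#     plot_fn=plot_images,
# )
```

Passing per-model keyword arguments keeps the differing guidance scale explicit while the prompt and batch size stay shared across both models.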

### "Yoda"

In [None]:
PROMPT = "Yoda"
BATCH_SIZE = 3

images_original = original_model.text_to_image(PROMPT, batch_size=BATCH_SIZE)
images_finetuned = pokemon_model.text_to_image(
    PROMPT, batch_size=BATCH_SIZE, unconditional_guidance_scale=50
)

plot_images(images_original, "original")
plot_images(images_finetuned, "finetuned to pokemon dataset")

### "robotic cat with wings"

In [None]:
PROMPT = "robotic cat with wings"
BATCH_SIZE = 3

images_original = original_model.text_to_image(PROMPT, batch_size=BATCH_SIZE)
images_finetuned = pokemon_model.text_to_image(
    PROMPT, batch_size=BATCH_SIZE, unconditional_guidance_scale=50
)

plot_images(images_original, "original")
plot_images(images_finetuned, "finetuned to pokemon dataset")

### "Girl with a pearl earring"

In [None]:
PROMPT = "Girl with a pearl earring"
BATCH_SIZE = 3

images_original = original_model.text_to_image(PROMPT, batch_size=BATCH_SIZE)
images_finetuned = pokemon_model.text_to_image(
    PROMPT, batch_size=BATCH_SIZE, unconditional_guidance_scale=50
)

plot_images(images_original, "original")
plot_images(images_finetuned, "finetuned to pokemon dataset")

### "Hello Kitty"

In [None]:
PROMPT = "Hello Kitty"
BATCH_SIZE = 3

images_original = original_model.text_to_image(PROMPT, batch_size=BATCH_SIZE)
images_finetuned = pokemon_model.text_to_image(
    PROMPT, batch_size=BATCH_SIZE, unconditional_guidance_scale=50
)

plot_images(images_original, "original")
plot_images(images_finetuned, "finetuned to pokemon dataset")