# Generate images with your fine-tuned Stable Diffusion model

Use this notebook to generate images interactively once you have fine-tuned a Stable Diffusion model and have a model checkpoint available to load. See the README for fine-tuning instructions.

In [None]:
# TODO: Point this at your fine-tuned model checkpoint!
# It corresponds to the $TUNED_MODEL_DIR variable defined in the run script.
TUNED_MODEL_PATH = "/tmp/model-tuned"

# TODO: Only needed if you fine-tuned with LoRA: the original model snapshot and
# the directory holding the LoRA weights (both are passed to `get_pipeline`).
ORIG_MODEL_PATH = (
    "/tmp/model-orig/models--CompVis--stable-diffusion-v1-4/snapshots/"
    "b95be7d6f134c3a9e62ee616f310733567f069ce/"
)
LORA_WEIGHTS_DIR = "/tmp/model-tuned"

First, define helpers that load the model checkpoint as a Hugging Face 🤗 pipeline and move it onto a GPU, plus a function that generates images from a text prompt.
The model itself is loaded when you pick a fine-tuning mode in the GUI below.

In [None]:
import gc
from os import environ

import torch
from diffusers import DiffusionPipeline

from dreambooth.generate_utils import load_lora_weights, get_pipeline

# Currently loaded pipeline; stays None until a fine-tuning mode is chosen in the GUI.
pipeline = None


def _release_pipeline():
    """Drop the currently loaded pipeline and return its cached GPU memory.

    Called before loading another model so that switching between full and
    LoRA fine-tuning does not keep two models resident on the GPU at once.
    """
    global pipeline
    pipeline = None
    gc.collect()  # collect reference cycles so the model's tensors are really freed
    torch.cuda.empty_cache()


def _load_pipeline(*pipeline_args, device: str = "cuda"):
    """Replace the global pipeline with get_pipeline(*pipeline_args), moved to `device`."""
    global pipeline
    _release_pipeline()
    pipeline = get_pipeline(*pipeline_args)
    pipeline.to(device)


def on_full_ft(device: str = "cuda"):
    """Load the fully fine-tuned checkpoint at TUNED_MODEL_PATH onto `device`."""
    _load_pipeline(TUNED_MODEL_PATH, device=device)


def on_lora_ft(device: str = "cuda"):
    """Load ORIG_MODEL_PATH with the LoRA weights in LORA_WEIGHTS_DIR onto `device`.

    Raises:
        ValueError: if ORIG_MODEL_PATH or LORA_WEIGHTS_DIR is empty. Checked
            before the current pipeline is released, so a misconfiguration
            leaves the previously loaded model usable.
    """
    if not ORIG_MODEL_PATH or not LORA_WEIGHTS_DIR:
        raise ValueError(
            "Set ORIG_MODEL_PATH and LORA_WEIGHTS_DIR before loading LoRA weights."
        )
    _load_pipeline(ORIG_MODEL_PATH, LORA_WEIGHTS_DIR, device=device)


In [None]:
def generate(
    pipeline: "DiffusionPipeline",
    prompt: str,
    img_size: int = 512,
    num_samples: int = 1,
    **pipeline_kwargs,
) -> list:
    """Generate `num_samples` square images for a single text prompt.

    Args:
        pipeline: A loaded diffusers pipeline. It is called with a list of
            prompts, and the `.images` attribute of its result is returned.
        prompt: Text prompt, repeated once per requested sample.
        img_size: Height and width, in pixels, requested for each image.
        num_samples: Number of copies of `prompt` sent in one call; must be >= 1.
        **pipeline_kwargs: Extra keyword arguments forwarded unchanged to the
            pipeline call (for example ``num_inference_steps``,
            ``guidance_scale``, or a seeded ``generator`` for reproducibility).

    Returns:
        The `.images` attribute of the pipeline output.

    Raises:
        ValueError: if `num_samples` is less than 1 (an empty prompt batch).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")
    # Repeating the prompt builds a batch with one entry per requested sample.
    prompts = [prompt] * num_samples
    result = pipeline(prompts, height=img_size, width=img_size, **pipeline_kwargs)
    return result.images


## Try out your model!

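Now play with your fine-tuned diffusion model through this simple GUI. First choose whether you used full or LoRA fine-tuning, which loads the model. Then enter a prompt and click **Generate!**.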

In [None]:
import time
import ipywidgets as widgets
from IPython.display import display, clear_output

# TODO: When giving prompts, make sure to include your subject's unique identifier,
# as well as its class name.
# For example, if your subject's unique identifier is "unqtkn" and is a dog,
# you can give the prompt "photo of a unqtkn dog on the beach".

# Number of images generated per click of the "Generate!" button.
NUM_SAMPLES = 2

# Maps each toggle-button label to the function that loads the matching model.
# The toggle options are built from these keys so the two can never drift apart.
FINE_TUNING_LOADERS = {
    "Full fine-tuning": on_full_ft,
    "LoRA fine-tuning": on_lora_ft,
}

# IPython GUI Layouts

output = widgets.Output()
toggle_buttons = widgets.ToggleButtons(
    options=list(FINE_TUNING_LOADERS),
    disabled=False,
    button_style="",  # 'success', 'info', 'warning', 'danger' or ''
    value=None,  # nothing selected: no model is loaded until the user picks a mode
)


def toggle_callback(change):
    """Load the model that matches the newly selected toggle button."""
    loader = FINE_TUNING_LOADERS.get(change["new"])
    if loader is None:
        # Selection was cleared (e.g. value reset to None); there is nothing to load.
        return
    with output:
        clear_output()
        print(f"Loading model ({change['new']})...")
        loader()
        print("Model loaded. Enter a prompt and click 'Generate!'.")


toggle_buttons.observe(toggle_callback, names="value")

input_text = widgets.Text(
    value="photo of a unqtkn dog on the beach",
    placeholder="",
    description="Prompt:",
    disabled=False,
    layout=widgets.Layout(width="500px"),
)

button = widgets.Button(description="Generate!")


def on_button_clicked(b):
    """Generate NUM_SAMPLES images for the current prompt and show them in `output`."""
    with output:
        clear_output()
        if pipeline is None:
            # Without this guard, generate() would try to call None and raise a TypeError.
            print("No model loaded yet: choose a fine-tuning mode above first.")
            return
        print("Generating images...")
        print(
            "(The output image may be completely black if it's filtered by "
            "HuggingFace diffusers safety checkers.)"
        )
        start_time = time.perf_counter()
        images = generate(pipeline=pipeline, prompt=input_text.value, num_samples=NUM_SAMPLES)
        elapsed = time.perf_counter() - start_time
        display(*images)
        print(f"Completed in {elapsed:.1f} seconds.")


button.on_click(on_button_clicked)

# Display the widgets
display(toggle_buttons, widgets.HBox([input_text, button]), output)


In [None]:
# Release the model's GPU memory once you are done generating images.
# Rebinding to None (rather than `del`) keeps this cell safe to re-run and keeps
# the `pipeline` name defined, so other cells do not hit a NameError afterwards.
pipeline = None
gc.collect()  # collect reference cycles so the model's tensors are really freed
torch.cuda.empty_cache()