# Generate images with your fine-tuned Stable Diffusion model

Use this notebook to interactively generate images once you have fine-tuned a Stable Diffusion model and have a model checkpoint available to load. See the README for fine-tuning instructions.

In [None]:
import os

# TODO: Change this to the path of your fine-tuned model checkpoint!
# This is the $TUNED_MODEL_DIR variable defined in the run script. If that
# variable is exported (non-empty) in the environment that launched this
# kernel, it is used directly; otherwise the default path below applies.
TUNED_MODEL_PATH = os.environ.get("TUNED_MODEL_DIR") or "/tmp/model-tuned"


First, load the model checkpoint as a Hugging Face 🤗 Diffusers pipeline.
Load the model onto a GPU and define a function that generates images from a text prompt.

In [None]:
import torch
from diffusers import DiffusionPipeline

# Load the fine-tuned checkpoint from TUNED_MODEL_PATH in half precision (float16).
# NOTE(review): float16 weights are meant for GPU inference and this cell has no
# CPU fallback; it presumably fails on a machine without a CUDA device — confirm
# that is acceptable for your environment.
pipeline = DiffusionPipeline.from_pretrained(
    TUNED_MODEL_PATH, torch_dtype=torch.float16
)
# Move the pipeline onto the GPU. Because this is the cell's last expression,
# its return value is also displayed as the cell output.
pipeline.to("cuda")


In [None]:
def generate(
    pipeline: DiffusionPipeline,
    prompt: str,
    img_size: int = 512,
    num_samples: int = 1,
    **pipeline_kwargs,
) -> list:
    """Generate ``num_samples`` square images for a single text prompt.

    Args:
        pipeline: A loaded diffusion pipeline. It is called with a list of
            prompts plus ``height``/``width`` keywords and must return an
            object with an ``images`` attribute.
        prompt: Text prompt. For a fine-tuned subject, include its unique
            identifier and class name (e.g. "photo of a unqtkn dog").
        img_size: Height and width of each generated image, in pixels.
        num_samples: How many images to generate; must be at least 1.
        **pipeline_kwargs: Extra keyword arguments forwarded unchanged to the
            pipeline call, if the pipeline accepts them (for example
            ``num_inference_steps`` or a seeded ``generator`` for
            reproducible output).

    Returns:
        The ``images`` attribute of the pipeline's result, expected to hold
        one image per sample.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        # `[prompt] * 0` (or a negative count) would send an empty batch to
        # the pipeline; fail early with a clear message instead.
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    # One copy of the prompt per requested sample, generated as one batch.
    result = pipeline(
        [prompt] * num_samples,
        height=img_size,
        width=img_size,
        **pipeline_kwargs,
    )
    return result.images


## Try out your model!

Now, experiment with your fine-tuned diffusion model using this simple GUI. Remember to include your subject's unique identifier and class name in each prompt.

In [None]:
import time
import ipywidgets as widgets
from IPython.display import display, clear_output

# TODO: When writing prompts, include both your subject's unique identifier
# and its class name. For example, if the identifier is "unqtkn" and the
# subject is a dog, a good prompt is "photo of a unqtkn dog on the beach".

# IPython GUI layout: a 500px-wide prompt box, a "Generate!" button, and an
# output area that the click handler renders into.
input_text = widgets.Text(
    description="Prompt:",
    value="photo of a unqtkn dog on the beach",
    placeholder="",
    layout=widgets.Layout(width="500px"),
    disabled=False,
)

button = widgets.Button(description="Generate!")
output = widgets.Output()

# Number of images generated per button click.
NUM_SAMPLES = 2


# Define button click event
def on_button_clicked(b):
    """Generate images for the prompt in `input_text` and show them in `output`.

    Registered below via `button.on_click`, so `b` is the clicked button. The
    button is disabled while the (blocking) generation runs to discourage
    repeated clicks, and is re-enabled in a `finally` clause so a failed run
    does not leave it stuck disabled.

    Reported time covers image generation only, measured with a monotonic
    clock; displaying the images is not included.
    """
    with output:
        clear_output()
        print("Generating images...")
        print(
            "(The output image may be completely black if it's filtered by "
            "HuggingFace diffusers safety checkers.)"
        )
        b.disabled = True
        try:
            start_time = time.perf_counter()
            images = generate(
                pipeline=pipeline,
                prompt=input_text.value,
                num_samples=NUM_SAMPLES,
            )
            elapsed = time.perf_counter() - start_time
        finally:
            # Any exception still propagates out of the `with output:` block
            # after the button is re-enabled.
            b.disabled = False
        display(*images)
        print(f"Completed in {elapsed:.1f} seconds.")

# Hook the handler up to the button, then render the prompt row
# (text box beside the button) with the output area underneath it.
button.on_click(on_button_clicked)
display(widgets.HBox(children=[input_text, button]), output)
