<a href="https://colab.research.google.com/github/r-isachenko/2024-DGM-MIPT-YSDA-course/blob/main/seminars/seminar12/seminar12_SD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is a lightly modified version of [Understanding Stable Diffusion from "Scratch"](https://scholar.harvard.edu/binxuw/classes/machine-learning-scratch/materials/stable-diffusion-scratch) by Binxu Wang and John Vastola.

Its sections are only loosely ordered, so you can jump between them. You can:
* Play with generating art from a prompt.
* See how the parameters affect the generation process.
* Visualize the diffusion process and the latents.
* Look under the hood of the sampling function.
* Inspect the internal network architecture of the components of Stable Diffusion.

# Stable Diffusion Playground

In [1]:
!pip install diffusers transformers tokenizers

Make sure you are using a runtime with a GPU!

In [None]:
import torch
assert torch.cuda.is_available()
!nvidia-smi

## Loading Stable Diffusion

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt

def plt_show_image(image):
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

Here the `fp16` checkpoint is loaded to save memory and compute time. If you have a GPU with plenty of memory, you can remove the line `revision="fp16", torch_dtype=torch.float16` to load the full-precision weights.
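
As a rough back-of-the-envelope sketch of why half precision helps: weights stored as 16-bit floats take 2 bytes each instead of 4, so the weight footprint roughly halves. The helper `weight_memory_gib` and the parameter count of one billion are illustrative assumptions, not measurements of this checkpoint, and activations and framework overhead are ignored.

```python
def weight_memory_gib(num_params: int, bytes_per_param: int) -> float:
    """Approximate memory needed to store model weights, in GiB."""
    return num_params * bytes_per_param / 2**30


# Illustrative round number, not the actual size of any checkpoint.
num_params = 1_000_000_000

fp32_gib = weight_memory_gib(num_params, 4)  # float32: 4 bytes per weight
fp16_gib = weight_memory_gib(num_params, 2)  # float16: 2 bytes per weight
print(f"fp32: {fp32_gib:.2f} GiB, fp16: {fp16_gib:.2f} GiB")
```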

**StableDiffusionPipeline** [doc](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview)

**CompVis/stable-diffusion-v1-4** [doc](https://huggingface.co/CompVis/stable-diffusion-v1-4)



In [None]:
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
)

text2img = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16", torch_dtype=torch.float16
).to("cuda")
# Disable the safety checker: return the images unchanged and flag none of them as NSFW
def dummy_checker(images, **kwargs):
    return images, [False] * images.shape[0]

text2img.safety_checker = dummy_checker


#text2img = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
inpaint = StableDiffusionInpaintPipeline(**text2img.components)

In [6]:
pipe = text2img

## Generative Playground

In [None]:
prompt = "a lovely cat running in the desert in Van Gogh style, trending art."
image = pipe(prompt).images[0]  # `image` is a PIL image: https://pillow.readthedocs.io/en/stable/

# Save the image to disk; leaving `image` as the last expression displays it in the notebook.
image.save("lovely_cat.png")
image

In [None]:
image #  "a lovely cat running in the desert in Van Gogh style, trending art."

### Fixing the random seed

In [None]:
generator = torch.Generator("cuda").manual_seed(1024)

prompt = "a sleeping cat enjoying the sunshine."
image = pipe(prompt, generator=generator).images[0]  # `image` is a PIL image

# Save the image to disk; leaving `image` as the last expression displays it in the notebook.
image.save("lovely_cat_sun.png")
image

### Changing (Denoising) Diffusion steps

In [None]:
prompt = "a sleeping cat enjoying the sunshine."
image = pipe(prompt, num_inference_steps=25).images[0]  # `image` is a PIL image

# Save under a separate name so the seeded result above is not overwritten.
image.save("lovely_cat_sun_25steps.png")
image

### Adding Negative prompt

A negative prompt steers the generation away from content you do not want in the image.
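
One way to picture this, as a hedged sketch: with classifier-free guidance the sampler combines two noise predictions per step, and the negative prompt takes the place of the empty prompt in the unconditional branch (the simplified sampling function later in this notebook is written that way). The function `guide` and the toy numbers are hypothetical; plain Python lists stand in for the real tensors.

```python
def guide(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: start from the unconditional (or negative-prompt)
    prediction and move toward the text-conditioned one, scaled by guidance_scale."""
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_text)]


noise_negative = [0.5, -1.0, 0.0]  # toy prediction for the negative prompt
noise_prompt = [1.0, -0.5, 0.5]    # toy prediction for the prompt

no_guidance = guide(noise_negative, noise_prompt, 1.0)  # scale 1 returns the prompt prediction
guided = guide(noise_negative, noise_prompt, 7.5)       # pushed further away from the negative prompt
print(no_guidance, guided)
```

With a scale above 1, the result is extrapolated past the prompt prediction in the direction that points away from the negative prompt.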

In [None]:
prompt = "a sleeping cat enjoying the sunshine."
# Re-create the generator so this run starts from the same seed as the seeded example above.
generator = torch.Generator("cuda").manual_seed(1024)
image = pipe(prompt, generator=generator,
             negative_prompt="tree and leaves").images[0]  # `image` is a PIL image

# Save the image to disk; leaving `image` as the last expression displays it in the notebook.
image.save("lovely_cat_sun_no_trees.png")
image

## Visualizing the Diffusion in Action

First, import some utilities for showing videos in Colab.

In [None]:
# https://colab.research.google.com/github/google/mediapy/blob/main/mediapy_examples.ipynb#scrollTo=u0kuKXep2pfr
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import itertools
import math
import mediapy as media

In [13]:
!mkdir -p diffprocess

In [14]:
image_reservoir = []
latents_reservoir = []


@torch.no_grad()
def plot_show_callback(i, t, latents):
    latents_reservoir.append(latents.detach().cpu())
    # 0.18215 is the scaling factor applied to the VAE latents; see https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515
    image = pipe.vae.decode(1 / 0.18215 * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
    # plt_show_image(image)
    plt.imsave(f"diffprocess/sample_{i:02d}.png", image)
    image_reservoir.append(image)


@torch.no_grad()
def save_latents(i, t, latents):
    latents_reservoir.append(latents.detach().cpu())


@torch.no_grad()
def saveimg_callback(i, t, latents):
    latents_reservoir.append(latents.detach().cpu())
    image = pipe.vae.decode(1 / 0.18215 * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
    # plt_show_image(image)
    plt.imsave(f"diffprocess/sample_{i:02d}.png", image)
    image_reservoir.append(image)

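These callbacks append each intermediate decoded image to the list `image_reservoir` and each intermediate latent tensor to `latents_reservoir`. `save_latents` stores only the latents, while `plot_show_callback` and `saveimg_callback` also write every frame to the `diffprocess` folder.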
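
To make the calling convention concrete, here is a minimal sketch of how a sampling loop might invoke such a callback. The callbacks above take `(i, t, latents)`; the loop `run_with_callback` and its fake timesteps and latents are hypothetical stand-ins, not the diffusers implementation.

```python
def run_with_callback(timesteps, callback, callback_steps=1):
    """Toy sampling loop: calls callback(i, t, latents) every callback_steps steps."""
    latents = 1.0  # stand-in for the latent tensor
    for i, t in enumerate(timesteps):
        latents = latents * 0.5  # stand-in for one denoising update
        if i % callback_steps == 0:
            callback(i, t, latents)
    return latents


seen = []  # plays the role of latents_reservoir


def record(i, t, latents):
    """Store the step index, the timestep and the current latents."""
    seen.append((i, t, latents))


final = run_with_callback([999, 666, 333, 0], record, callback_steps=2)
print(seen, final)
```

With `callback_steps=1`, as used in this notebook, the callback fires on every step, which gives one video frame per denoising step.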

In [None]:
prompt = "a lovely cat running in the desert in Van Gogh style, trending art."
with torch.no_grad():
    image = pipe(prompt, callback=plot_show_callback, callback_steps=1).images[0]  # `image` is a PIL image

# Save the final image, then play the intermediate decodings as a video.
image.save("lovely_cat.png")
media.show_video(image_reservoir, fps=5)

In [None]:
prompt = "a lovely cat running in the desert in Van Gogh style, trending art."
image_reservoir = []
latents_reservoir = []  # reset as well, so only this run's latents are kept
with torch.no_grad():
    image = pipe(prompt, callback=plot_show_callback, callback_steps=1).images[0]  # `image` is a PIL image

# Save the final image, then play the intermediate decodings as a video.
image.save("lovely_cat2.png")
media.show_video(image_reservoir, fps=5)

In [None]:
prompt = "a lovely cat running in the desert in Van Gogh style, trending art."
image_reservoir = []
latents_reservoir = []  # reset as well, so only this run's latents are kept
with torch.no_grad():
    image = pipe(prompt, callback=plot_show_callback, callback_steps=1).images[0]  # `image` is a PIL image

# Save the final image, then play the intermediate decodings as a video.
image.save("lovely_cat3.png")
media.show_video(image_reservoir, fps=5)

### Visualizing Image sequence

In [None]:
# video1 = media.moving_circle((65, 65), num_images=10)
media.show_video(image_reservoir, fps=5)

### Visualizing latents

What about the latents? How do they change during the diffusion process?

In [None]:
latents_reservoir[0].shape

Since the latent tensor has 4 channels, we can visualize any 3 of them as RGB. Put any three of the indices 0, 1, 2, 3 into the `Chan2RGB` list and see what the result looks like.
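
If you want to go through the options systematically, the sketch below lists every choice of three distinct channels out of four. It uses only `itertools`; the names `channel_triples` and `ordered_triples` are made up for this example. Each triple can be assigned to `Chan2RGB`.

```python
from itertools import combinations, permutations

num_channels = 4  # the latent tensor has 4 channels

# Unordered choices: which three channels to show at all.
channel_triples = [list(c) for c in combinations(range(num_channels), 3)]
# Ordered choices: also which channel is mapped to R, G and B.
ordered_triples = [list(p) for p in permutations(range(num_channels), 3)]

for chan2rgb in channel_triples:
    print(chan2rgb)
```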

In [None]:
Chan2RGB = [0,1,2]
latents_np_seq = [tsr[0,Chan2RGB].permute(1,2,0).numpy() for tsr in latents_reservoir]

In [None]:
media.show_video(latents_np_seq, fps=5)

## Write a simple text2img sampling function

Here is a simplified version of the sampling function. It shows what happens under the hood when you run `pipe(prompt)`.

Feel free to print out tensors and record their shape within this function!
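
As a starting point for checking shapes, the sketch below reproduces the latent shape computed inside the function that follows: each spatial side of the image is divided by 8, and the latents have 4 channels. The helper `latent_shape` is hypothetical and exists only for this illustration.

```python
def latent_shape(batch_size, height, width, channels=4, downsample=8):
    """Shape of the latent tensor for an image of the given pixel size."""
    if height % downsample or width % downsample:
        raise ValueError("height and width must be multiples of the downsampling factor")
    return (batch_size, channels, height // downsample, width // downsample)


print(latent_shape(1, 512, 512))  # (1, 4, 64, 64)
print(latent_shape(1, 512, 768))  # (1, 4, 64, 96)
```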

In [None]:
@torch.no_grad()
def generate_simplified(
    prompt = ["a lovely cat"],
    negative_prompt = [""],
    num_inference_steps = 50,
    guidance_scale = 7.5):
    # do_classifier_free_guidance
    batch_size = 1
    height, width = 512, 512
    generator = None
    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.

    # get prompt text embeddings
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    text_embeddings = pipe.text_encoder(text_input_ids.to(pipe.device))[0]
    bs_embed, seq_len, _ = text_embeddings.shape

    # get negative prompt text embeddings
    max_length = text_input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = uncond_embeddings.shape[1]
    uncond_embeddings = uncond_embeddings.repeat(batch_size, 1, 1)
    uncond_embeddings = uncond_embeddings.view(batch_size, seq_len, -1)

    # Classifier-free guidance needs two noise predictions per step (unconditional and text-conditioned).
    # Concatenating the two embeddings into one batch lets a single batched forward pass
    # of the UNet produce both predictions.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # get the initial random noise unless the user supplied it
    # Unlike in other pipelines, latents need to be generated in the target device
    # for 1-to-1 results reproducibility with the CompVis implementation.
    # However this currently doesn't work in `mps`.
    latents_shape = (batch_size, pipe.unet.config.in_channels, height // 8, width // 8)
    latents_dtype = text_embeddings.dtype
    latents = torch.randn(latents_shape, generator=generator, device=pipe.device, dtype=latents_dtype)

    # set timesteps
    pipe.scheduler.set_timesteps(num_inference_steps)
    # Some schedulers like PNDM have timesteps as arrays
    # It's more optimized to move all timesteps to correct device beforehand
    timesteps_tensor = pipe.scheduler.timesteps.to(pipe.device)
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * pipe.scheduler.init_noise_sigma

    # Main diffusion process
    for i, t in enumerate(pipe.progress_bar(timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image

In [None]:
image = generate_simplified(
    prompt = ["a lovely cat"],
    negative_prompt = ["Sunshine"],)
plt_show_image(image[0])

In [None]:
image = generate_simplified(
    prompt = ["a cat dressed like a ballerina"],
    negative_prompt = [""],)
plt_show_image(image[0])

## Image to Image Translation Playground

In [23]:
pipe = img2img

In [None]:
import requests
from io import BytesIO
from PIL import Image

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))
init_img

In [None]:
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device="cuda").manual_seed(1024)
with autocast("cuda"):
    image = pipe(prompt=prompt, image=init_img,
                 strength=0.75, guidance_scale=7.5,
                 generator=generator).images[0]

image

## Write a simple img2img sampling function
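
Before reading the function, it may help to see how `strength` decides how many denoising steps actually run. The sketch below repeats the index arithmetic from the function that follows in plain Python; `img2img_schedule` is a hypothetical helper, and `offset` stands for the scheduler's `steps_offset` config value.

```python
def img2img_schedule(num_inference_steps, strength, offset=0):
    """Return (init_timestep, t_start, steps_run) for an img2img run.

    steps_run assumes the scheduler exposes exactly num_inference_steps
    timesteps, which may not hold for every scheduler.
    """
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    steps_run = num_inference_steps - t_start
    return init_timestep, t_start, steps_run


print(img2img_schedule(50, 0.75))  # (37, 13, 37)
print(img2img_schedule(50, 1.0))   # (50, 0, 50)
print(img2img_schedule(50, 0.0))   # (0, 50, 0)
```

A higher `strength` adds more noise to the initial image and runs more denoising steps, so the result departs further from the input.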

In [None]:
import numpy as np
# Preprocess the initial image
def preprocess(image):
    if isinstance(image, Image.Image):
        w, h = image.size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
        image = image.resize((w, h), resample=Image.LANCZOS)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)  # add batch dimension and rearrange
        image = torch.from_numpy(image)
        image = 2.0 * image - 1.0  # scale to [-1, 1]
    return image


@torch.no_grad()
def generate_img2img_simplified(
    init_image: Image.Image,
    prompt=["A fantasy landscape, trending on artstation"],
    negative_prompt=[""],
    strength=0.5,
    batch_size=1,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=None,
    **extra_step_kwargs):

    do_classifier_free_guidance = True

    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    # set timesteps
    pipe.scheduler.set_timesteps(num_inference_steps)

    # get prompt text embeddings
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    text_embeddings = pipe.text_encoder(text_input_ids.to(pipe.device))[0]

    # get unconditional embeddings for classifier free guidance
    uncond_tokens = negative_prompt
    max_length = text_input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # Classifier-free guidance needs two noise predictions per step (unconditional and text-conditioned).
    # Concatenating the two embeddings into one batch lets a single batched forward pass
    # of the UNet produce both predictions.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # encode the init image into latents and scale the latents
    latents_dtype = text_embeddings.dtype
    if isinstance(init_image, Image.Image):
        init_image = preprocess(init_image)
    init_image = init_image.to(device=pipe.device, dtype=latents_dtype)
    init_latent_dist = pipe.vae.encode(init_image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)
    init_latents = 0.18215 * init_latents

    # get the original timestep using init_timestep
    offset = pipe.scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    timesteps = pipe.scheduler.timesteps[-init_timestep]
    timesteps = torch.tensor([timesteps] * batch_size, device=pipe.device)

    # add noise to latents using the timesteps
    noise = torch.randn(init_latents.shape, generator=generator, device=pipe.device, dtype=latents_dtype)
    init_latents = pipe.scheduler.add_noise(init_latents, noise, timesteps)

    latents = init_latents

    t_start = max(num_inference_steps - init_timestep + offset, 0)
    # Some schedulers like PNDM have timesteps as arrays
    # It's more optimized to move all timesteps to correct device beforehand
    timesteps = pipe.scheduler.timesteps[t_start:].to(pipe.device)

    for i, t in enumerate(pipe.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()

    return image

-----------------------

In [None]:
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device="cuda").manual_seed(1024)
with autocast("cuda"):
    image = img2img(prompt=prompt, image=init_img,
                 strength=0.75, guidance_scale=7.5,
                 generator=generator).images[0]

image

In [None]:
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device="cuda").manual_seed(1024)
with autocast("cuda"):
    image = generate_img2img_simplified(prompt=prompt, init_image=init_img,
                 strength=0.75, guidance_scale=7.5,
                 generator=generator)
plt_show_image(image[0])