# A walk through latent space with Stable Diffusion
Latent space walking, or latent space exploration, is the process of sampling a point in latent space and incrementally changing the latent representation. Its most common application is generating animations where each sampled point is fed to the decoder and is stored as a frame in the final animation. For high-quality latent representations, this produces coherent-looking animations. These animations can provide insight into the feature map of the latent space, and can ultimately lead to improvements in the training process.

### References
- https://keras.io/examples/generative/random_walks_with_stable_diffusion/

In [None]:
import inspect
import math

import torch
from diffusers import (
    StableDiffusionPipeline,
    AutoencoderKL,
    UNet2DConditionModel,
    DDIMScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from IPython.display import Image as IImage

In [None]:
import os

# Device that the pipeline and every tensor created in this notebook are placed on.
device = "cuda"

# Hugging Face Hub id, used when no local copy of the weights is available.
hub_model_id = "CompVis/stable-diffusion-v1-4"
# you can download the model weights and save locally
local_model_dir = "/data2/hy/model_weights/stable-diffusion-v1-5/"

# Prefer the local checkpoint directory; fall back to the Hub id when that
# directory does not exist on this machine, so the notebook stays runnable elsewhere.
model_path = local_model_dir if os.path.isdir(local_model_dir) else hub_model_id

In [None]:
# add new method to original StableDiffusionPipeline
class SDPipeline(StableDiffusionPipeline):
    """StableDiffusionPipeline with text encoding and image generation split apart.

    `encode_text` turns a prompt into text-encoder embeddings and `generate_image`
    runs the denoising loop directly from such embeddings, so callers can modify
    the embeddings (e.g. interpolate them) before generating.
    """

    @torch.no_grad()
    def encode_text(self, prompt):
        """Encodes prompt into latent text encoding.

        Args:
            prompt: Text (or list of texts) passed to `self.tokenizer`.

        Returns:
            The first output of `self.text_encoder` for the tokenized prompt,
            located on `self.device`.
        """
        # Local import: the notebook never defined the `logger` the original code
        # referenced, so an over-long prompt raised NameError instead of warning.
        import logging

        max_length = self.tokenizer.model_max_length

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # No `truncation=True` above, so over-long prompts are cut here by hand and
        # the dropped part is reported to the user.
        if text_input_ids.shape[-1] > max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, max_length:])
            logging.getLogger(__name__).warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, :max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
        return text_embeddings

    @torch.no_grad()
    def generate_image(
        self,
        text_embeddings,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        eta=0.0,
        generator=None,
        latents=None,
        output_type="pil"
    ):
        """Generates an image based on text_embeddings.

        Args:
            text_embeddings: 3-D tensor of text encodings; the first dimension is
                the batch size (one generated image per entry).
            height: Image height in pixels, must be divisible by 8.
            width: Image width in pixels, must be divisible by 8.
            num_inference_steps: Number of scheduler timesteps to run.
            guidance_scale: Classifier-free guidance weight; guidance is only
                applied when this is greater than 1.0.
            eta: Forwarded to `self.scheduler.step` only if that method accepts
                an `eta` argument.
            generator: Optional `torch.Generator` used to sample the start noise
                when `latents` is not given.
            latents: Optional 4-D start noise. A batch of size 1 is repeated for
                every entry of `text_embeddings`; otherwise its batch size must
                equal that of `text_embeddings`.
            output_type: `"pil"` converts the result with `self.numpy_to_pil`;
                any other value returns the numpy array.

        Returns:
            The output of `self.numpy_to_pil` when `output_type == "pil"`,
            otherwise a float32 numpy array laid out (batch, height, width,
            channels) with values clamped to [0, 1].

        Raises:
            ValueError: If `height`/`width` are not divisible by 8, or if
                `latents` has the wrong number of dimensions or an incompatible
                batch size.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # The three-way unpacking also enforces that the embeddings are 3-D.
        batch_size, _, _ = text_embeddings.shape

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.dim() != len(latents_shape):
                raise ValueError(f"Unexpected latents dimension, got {latents.shape}, expected {latents_shape}")
            if latents.shape[0] == 1 and batch_size != 1:
                # share one start noise across the whole batch
                latents = latents.repeat(batch_size, 1, 1, 1)
            elif latents.shape[0] != batch_size:
                # Repeating a multi-sample batch would yield a batch the UNet cannot
                # pair with the embeddings, so fail early with a clear message.
                raise ValueError(
                    f"Unexpected latents batch size, got {latents.shape[0]}, expected 1 or {batch_size}"
                )
            # match the model's device and dtype (e.g. float16 weights)
            latents = latents.to(device=self.device, dtype=latents_dtype)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}

        for t in self.progress_bar(timesteps_tensor):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # Undo the VAE latent scaling; use the checkpoint's configured factor when
        # it exposes one and fall back to the value this code previously hard-coded.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        image = self.vae.decode(latents).sample

        # map the decoder output from [-1, 1] to [0, 1]
        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return image

In [None]:
# Define noise scheduler: the parameters must match the original stable diffusion
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False, # don't clip sample, the x0 in stable diffusion not in range [-1, 1]
    set_alpha_to_one=False,
    steps_offset=1,
)
# load Stable Diffusion Pipeline.
# The scheduler must be passed explicitly: without `scheduler=...` the pipeline is
# built with whatever scheduler the checkpoint declares and `noise_scheduler`
# above is never used.
pipe = SDPipeline.from_pretrained(
    model_path,
    scheduler=noise_scheduler,
    torch_dtype=torch.float16,
).to(device)

In [None]:
def image_grid(imgs, rows, cols):
    """Tile `imgs` onto a single RGB canvas of `rows` x `cols` cells.

    Images are placed left to right, top to bottom. Every cell has the size of
    the first image. `len(imgs)` must equal `rows * cols`.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))

    for index, tile in enumerate(imgs):
        # position of this tile inside the grid
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas

## Interpolating between text prompts

In Stable Diffusion, a text prompt is first encoded into a vector, and that encoding is used to guide the diffusion process. The latent encoding vector has shape 77x768 (that's huge!), and when we give Stable Diffusion a text prompt, we're generating images from just one such point on the latent manifold.

To explore more of this manifold, we can interpolate between two text encodings and generate images at those interpolated points:

In [None]:
# The two prompts we interpolate between.
prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"

# Number of interpolated points (both endpoints included) and sampling settings.
interpolation_steps = 8
height = width = 512
num_inference_steps = 25

# Playback speed of the GIF written in a later cell.
frames_per_second = 2

encoding_1, encoding_2 = (pipe.encode_text(text) for text in (prompt_1, prompt_2))

# One blended encoding per evenly spaced weight in [0, 1], concatenated along dim 0.
lerp_weights = np.linspace(0., 1., interpolation_steps)
interpolated_encodings = torch.cat(
    [torch.lerp(encoding_1, encoding_2, weight) for weight in lerp_weights],
    dim=0,
)

# we generate a latents (noise) to let all generated images have same start noise.
generator = torch.Generator(device).manual_seed(12345)
latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
latents = torch.randn(
    latents_shape,
    generator=generator,
    device=device,
    dtype=encoding_1.dtype,
)

In [None]:
# Generate one image per interpolated encoding, all from the same start noise.
sampling_options = dict(
    height=height,
    width=width,
    num_inference_steps=num_inference_steps,
    latents=latents,
)
image = pipe.generate_image(interpolated_encodings, **sampling_options)

# Write the frames as a GIF; `duration` is the per-frame time in milliseconds
# and `loop=0` is passed through to the image writer unchanged.
gif_options = dict(
    save_all=True,
    append_images=image[1:],
    duration=1000 // frames_per_second,
    loop=0,
)
image[0].save("doggo-and-fruit-8.gif", **gif_options)

In [None]:
# Show the GIF file "doggo-and-fruit-8.gif" as this cell's output.
IImage("doggo-and-fruit-8.gif")

In [None]:
# Lay all interpolated frames out side by side in a single row.
grid = image_grid(image, rows=1, cols=len(image))
grid

To best visualize this, we should do a much more fine-grained interpolation, using hundreds of steps. In order to keep batch size small (so that we don't OOM our GPU), this requires manually batching our interpolated encodings.

In [None]:
# Same two prompts as before, now with a much finer interpolation.
prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"

interpolation_steps = 128
height = width = 512
# number of encodings handed to the pipeline per call
batch_size = 8
num_inference_steps = 25

frames_per_second = 8

encoding_1, encoding_2 = (pipe.encode_text(text) for text in (prompt_1, prompt_2))

# One blended encoding per evenly spaced weight in [0, 1], concatenated along dim 0.
lerp_weights = np.linspace(0., 1., interpolation_steps)
interpolated_encodings = torch.cat(
    [torch.lerp(encoding_1, encoding_2, weight) for weight in lerp_weights],
    dim=0,
)

# Single seeded start noise shared by every generated frame.
generator = torch.Generator(device).manual_seed(12345)
latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
latents = torch.randn(
    latents_shape,
    generator=generator,
    device=device,
    dtype=encoding_1.dtype,
)

In [None]:
# Generate the frames batch by batch to keep GPU memory bounded.
# Stepping through the encodings with a stride (instead of looping over
# `interpolation_steps // batch_size`) also covers a final, shorter batch, so no
# frames are dropped when the step count is not a multiple of `batch_size`.
generated_images = []
for start in range(0, len(interpolated_encodings), batch_size):
    image = pipe.generate_image(
        interpolated_encodings[start:start + batch_size],
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        latents=latents
    )
    generated_images.extend(image)

# Write all frames as a GIF; `duration` is the per-frame time in milliseconds.
generated_images[0].save(
    "doggo-and-fruit-128.gif",
    save_all=True,
    append_images=generated_images[1:],
    duration=1000 // frames_per_second,
    loop=0
)

In [None]:
# Show the GIF file "doggo-and-fruit-128.gif" as this cell's output.
IImage("doggo-and-fruit-128.gif")

We can even extend this concept to more than two prompts. For example, we can interpolate between four prompts:

In [None]:
# Four prompts: 1 -> 2 and 3 -> 4 are interpolated first, then those two
# interpolated sequences are interpolated against each other.
prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
prompt_3 = "The eiffel tower in the style of starry night"
prompt_4 = "An architectural sketch of a skyscraper"

height = width = 512
num_inference_steps = 25

# steps per interpolation axis; the final result holds interpolation_steps**2 encodings
interpolation_steps = 6
batch_size = 4

encoding_1, encoding_2, encoding_3, encoding_4 = (
    pipe.encode_text(text) for text in (prompt_1, prompt_2, prompt_3, prompt_4)
)

lerp_weights = np.linspace(0., 1., interpolation_steps)

interpolated_encodings_12 = torch.cat(
    [torch.lerp(encoding_1, encoding_2, weight) for weight in lerp_weights],
    dim=0,
)
interpolated_encodings_34 = torch.cat(
    [torch.lerp(encoding_3, encoding_4, weight) for weight in lerp_weights],
    dim=0,
)
# Blend the whole 1->2 sequence towards the 3->4 sequence, one weight at a time.
interpolated_encodings = torch.cat(
    [
        torch.lerp(interpolated_encodings_12, interpolated_encodings_34, weight)
        for weight in lerp_weights
    ],
    dim=0,
)


# Single seeded start noise shared by every generated image.
generator = torch.Generator(device).manual_seed(12345)
latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
latents = torch.randn(
    latents_shape,
    generator=generator,
    device=device,
    dtype=encoding_1.dtype,
)

In [None]:
# Generate every interpolated encoding in batches, all from the same start noise.
# Striding over the tensor itself (instead of `interpolation_steps**2 // batch_size`
# iterations) keeps the trailing images when the total is not a multiple of
# `batch_size`.
generated_images = []
for start in range(0, len(interpolated_encodings), batch_size):
    image = pipe.generate_image(
        interpolated_encodings[start:start + batch_size],
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        latents=latents
    )
    generated_images.extend(image)

In [None]:
# Arrange the generated images in a square grid, one row per outer interpolation weight.
grid = image_grid(generated_images, rows=interpolation_steps, cols=interpolation_steps)
grid

We can also interpolate while allowing diffusion noise to vary by dropping the `latents` parameter:

In [None]:
# Same batched generation as before, but without `latents`, so the pipeline
# samples fresh start noise on every call.
# Striding over the tensor itself keeps the trailing images when the total is not
# a multiple of `batch_size`.
generated_images = []
for start in range(0, len(interpolated_encodings), batch_size):
    image = pipe.generate_image(
        interpolated_encodings[start:start + batch_size],
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
    )
    generated_images.extend(image)

In [None]:
# Arrange the generated images in a square grid, one row per outer interpolation weight.
grid = image_grid(generated_images, rows=interpolation_steps, cols=interpolation_steps)
grid

## A walk around a text prompt
Our next experiment will be to go for a walk around the latent manifold starting from a point produced by a particular prompt.

In [None]:
prompt = "The Eiffel Tower in the style of starry night"

height = 512
width = 512
num_inference_steps = 25

# number of points on the walk, offset added per step, and encodings per pipeline call
walk_steps = 128
step_size = 0.001
batch_size = 8

encoding = pipe.encode_text(prompt)

# Constant offset added to every element of the encoding at each step.
delta = torch.ones_like(encoding) * step_size

# Collect the encoding at every step; the first entry is the unmodified prompt encoding.
walked_encodings = []
for _ in range(walk_steps):
    walked_encodings.append(encoding)
    encoding = encoding + delta
walked_encodings = torch.cat(walked_encodings, dim=0)

# Single seeded start noise shared by every frame of the walk.
generator = torch.Generator(device).manual_seed(0)
latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
# Use this cell's own `encoding` for the dtype; the previous `encoding_1` only
# exists if an earlier, unrelated cell has been run.
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=encoding.dtype)

In [None]:
# Generate one frame per walked encoding, in batches, all from the same start noise.
# Striding over the tensor itself (instead of `walk_steps // batch_size`
# iterations) keeps the trailing frames when the walk length is not a multiple of
# `batch_size`.
generated_images = []
for start in range(0, len(walked_encodings), batch_size):
    image = pipe.generate_image(
        walked_encodings[start:start + batch_size],
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        latents=latents
    )
    generated_images.extend(image)

In [None]:
# Write the walk as a GIF; `duration` is the per-frame time in milliseconds.
frames_per_second = 8
gif_options = dict(
    save_all=True,
    append_images=generated_images[1:],
    duration=1000 // frames_per_second,
    loop=0,
)
generated_images[0].save("eiffel-tower-starry-night.gif", **gif_options)

In [None]:
# Show the GIF file "eiffel-tower-starry-night.gif" as this cell's output.
IImage("eiffel-tower-starry-night.gif")

Perhaps unsurprisingly, walking too far from the encoder's latent manifold produces images that look incoherent. Try it for yourself by setting your own prompt, and adjusting step_size to increase or decrease the magnitude of the walk. Note that when the magnitude of the walk gets large, the walk often leads into areas which produce extremely noisy images.

## A circular walk through the diffusion noise space for a single prompt

Our final experiment is to stick to one prompt and explore the variety of images that the diffusion model can produce from that prompt. We do this by controlling the noise that is used to seed the diffusion process.

We create two noise components, `x` and `y`, and do a walk from 0 to 2π, summing the cosine of our `x` component and the sine of our `y` component to produce noise. Using this approach, the end of our walk arrives at the same noise inputs where we began our walk, so we get a "loopable" result!

In [None]:
prompt = "An oil paintings of cows in a field next to a windmill in Holland"

height = width = 512
num_inference_steps = 25

walk_steps = 128
batch_size = 8

encoding = pipe.encode_text(prompt)

# Two independent noise tensors, drawn from the globally seeded RNG.
torch.manual_seed(0)
latents_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
walk_noise_x = torch.randn(latents_shape, device=device, dtype=encoding.dtype)
walk_noise_y = torch.randn(latents_shape, device=device, dtype=encoding.dtype)

# Angles from 0 to 2*pi (both ends included), one per walk step.
walk_angles = torch.linspace(0, 2, walk_steps) * math.pi
walk_scale_x = torch.cos(walk_angles).to(device, dtype=encoding.dtype)
walk_scale_y = torch.sin(walk_angles).to(device, dtype=encoding.dtype)

# Broadcast each per-step scale over its noise tensor and sum: one start noise per step.
latents_x = walk_scale_x[:, None, None, None] * walk_noise_x
latents_y = walk_scale_y[:, None, None, None] * walk_noise_y
latents = latents_x + latents_y

# The prompt stays fixed, so every batch uses the same encoding repeated `batch_size` times.
walked_encodings = encoding.repeat(batch_size, 1, 1)

In [None]:
# Generate one frame per start noise, in batches, always with the same prompt encoding.
# Striding over `latents` (instead of `walk_steps // batch_size` iterations) keeps
# the trailing frames when the walk length is not a multiple of `batch_size`; the
# encodings are sliced so a final, shorter batch of latents has a matching batch size.
generated_images = []
for start in range(0, len(latents), batch_size):
    batch_latents = latents[start:start + batch_size]
    image = pipe.generate_image(
        walked_encodings[:len(batch_latents)],
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        latents=batch_latents
    )
    generated_images.extend(image)

In [None]:
# Write the noise walk as a GIF; `duration` is the per-frame time in milliseconds.
frames_per_second = 8
gif_options = dict(
    save_all=True,
    append_images=generated_images[1:],
    duration=1000 // frames_per_second,
    loop=0,
)
generated_images[0].save("cows.gif", **gif_options)

In [None]:
# Show the GIF file "cows.gif" as this cell's output.
IImage("cows.gif")