Notebook created by [Sam Dauncey](https://disco.ethz.ch/members/sdauncey) for FS 2025. 

# Generative Computer Vision

"Creating noise from data is easy; creating data from noise is generative modeling."

&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;\- [Yang Song et al.](https://openreview.net/pdf?id=PxTIG12RRHS)

In [8]:
# Collapsed cell: imports
import tqdm

import requests
import os

from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.distributions import Poisson, Uniform
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torch.utils.data import Dataset

import torchvision
from torchvision import transforms

from diffusers import DiffusionPipeline
import peft

assert torch.cuda.is_available(), "This notebook requires a GPU"
device = torch.device("cuda")
print(f"Using device: {device}")



In [None]:
# Collapsed cell: helper functions to count the number of parameters in a torch.nn.Module
def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def parameter_count_string(module):
    n_params = count_parameters(module)
    if n_params > 10**9:
        return f"{n_params/10**9:.1f}B"
    elif n_params > 10**6:
        return f"{n_params/10**6:.1f}M"
    elif n_params > 10**3:
        return f"{n_params/10**3:.1f}k"
    else:
        return f"{n_params}"

## The image generation problem

Why care about AI image generation? Yes, it lets us create cool images, but generation also provides the strongest possible signal for _unsupervised learning_. To faithfully generate an image of a turtle with a city riding on its shell, the model must first have an internal world model in which it understands what turtles and cities are. This differs from supervised learning: unless your dataset has sufficient counter-examples, your model can achieve zero loss by learning spurious correlations such as "wolf labels appear on snowy backgrounds and dog labels appear on grassy backgrounds".

![turtle with a city riding on its shell](https://polybox.ethz.ch/index.php/s/WDQPJr8etriqjyS/download/turtle_city_stable_diffusion3.png)

A turtle with a city riding on its shell generated by [Stable Diffusion v3.5](https://stability.ai/stable-image)

## Warmup: transposed convolutions

A common theme in image generation is taking a small grid and upsampling it into a larger grid. The most common way to do this is with a transposed convolution operator, implemented by the `torch.nn.ConvTranspose2d` class and the `torch.nn.functional.conv_transpose2d` function. This figure gives a nice depiction of a transposed convolution with `kernel_size=3, stride=2, padding=1`:

![Conv_transpose_example](https://polybox.ethz.ch/index.php/s/om9xqyE96p99pRX/download/conv_transpose_example.gif)

Figure from [Vincent Dumoulin, Francesco Visin - A guide to convolution arithmetic for deep learning](https://github.com/vdumoulin/conv_arithmetic/tree/master). 

The `stride=2` parameter here acts to dilate the input, placing a `0` cell in between each pair of spatially adjacent input cells.
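The output size of a transposed convolution follows from this dilate-then-convolve picture. As a small sketch (assuming `dilation=1`; `conv_transpose_out_size` is a hypothetical helper), we can check the formula $H_{\text{out}} = (H_{\text{in}} - 1) \cdot s - 2p + k + p_{\text{out}}$ from the PyTorch `ConvTranspose2d` documentation, with stride $s$, padding $p$, kernel size $k$ and output padding $p_{\text{out}}$, against the layer itself. This includes the `output_padding=1` setting that the decoders later in this notebook use to exactly double the spatial size.

```python
import torch
import torch.nn as nn


def conv_transpose_out_size(h_in, kernel_size, stride=1, padding=0, output_padding=0):
    """Output height/width of a ConvTranspose2d layer (assumes dilation=1)."""
    return (h_in - 1) * stride - 2 * padding + kernel_size + output_padding


configs = [
    # (h_in, kernel_size, stride, padding, output_padding)
    (2, 3, 2, 1, 0),   # the 2x2 -> 3x3 example in this section
    (8, 3, 2, 1, 1),   # output_padding=1 exactly doubles 8 -> 16
    (16, 3, 2, 1, 1),  # and 16 -> 32
]

for h_in, k, s, p, op in configs:
    layer = nn.ConvTranspose2d(1, 1, k, stride=s, padding=p, output_padding=op)
    out = layer(torch.zeros(1, 1, h_in, h_in))
    expected = conv_transpose_out_size(h_in, k, s, p, op)
    # The layer's actual output shape matches the formula
    assert out.shape[-1] == out.shape[-2] == expected
    print(f"in={h_in}, k={k}, stride={s}, pad={p}, out_pad={op} -> out={expected}")
```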

See the example below for this in action:

In [None]:
# Example: how does a convolutional transpose behave if we fix its weights?

my_conv_transpose = nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, bias=False)
filter = torch.arange(9, dtype=torch.float32).reshape(1, 1, 3, 3)
my_conv_transpose.weight.data = filter

x = (10**torch.arange(4, dtype=torch.float32)).reshape(1, 1, 2, 2)
print(f"x = \n{x}\nfilter = \n{filter}\nmy_conv_transpose(x) = \n{my_conv_transpose(x)}")


**Exercise 1:** complete the implementation of the subclass `MyUpsampler`, which should compute a transposed convolution with `kernel_size=3, stride=2, padding=1` (as in the figure). 

Your solution can _not_ use `torch.nn.ConvTranspose2d` or `torch.nn.functional.conv_transpose2d`.

Hint: your solution _can_ use [numpy-style striding](https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding) (which can be applied to torch tensors) as well as `torch.nn.functional.conv2d` ([docs](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d)); correctly composing these should do the trick.

Hint: technically, transposed convolution flips the filter in the spatial dimensions whereas standard convolution does not; we handle this for you already. 

In [None]:
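class MyUpsampler(nn.Module):
    def __init__(self, weight, bias):
        super(MyUpsampler, self).__init__()
        # ConvTranspose2d stores weights as (in_channels, out_channels, kH, kW), whereas conv2d
        # expects (out_channels, in_channels, kH, kW), so we swap the first two dimensions.
        # Transposed convolution also flips the filter in the spatial dimensions, whereas standard convolution does not.
        self.weight = weight.transpose(0, 1).flip(2, 3)
        self.bias = bias

    def forward(self, x):
        b, c, h, w = x.shape
        # Dilate: place the inputs on every other cell of a (2h - 1) x (2w - 1) grid of zeros
        out = x.new_zeros(b, c, 2 * h - 1, 2 * w - 1)
        out[:, :, ::2, ::2] = x
        # Pad by kernel_size - 1 - padding = 1 on each side, then apply a standard convolution
        out = F.conv2d(F.pad(out, (1, 1, 1, 1)), self.weight, bias=self.bias)
        return out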

# This should be the same as the output of my_conv_transpose(x) above. 
bias = torch.zeros(1, dtype=torch.float32)
us = MyUpsampler(filter, bias)
print(f"us(x) = \n{us(x)}")

## Learning a simple generation process

In this section, we'll build up to the state-of-the-art generative computer vision models using a simple example dataset.

Let's use a training dataset consisting of grayscale images containing a random number of rectangles with random sizes and positions. To faithfully generate new images, our model must learn the process that we used to generate them. 

In [None]:
# Collapsed cell: Make a simple dataset

def generate_rectangle_images(batch_size=32):
    # Initialize white background images (batch_size x 1 x 32 x 32)
    images = torch.ones(batch_size, 1, 32, 32)
    
    # Sample number of rectangles per image from Poisson(3) + 1
    n_rectangles = Poisson(torch.tensor(3.0)).sample((batch_size,)).int() + 1
    
    # For each image in the batch
    for i in range(batch_size):
        # Generate n rectangles
        for _ in range(n_rectangles[i].item()):
            # Sample rectangle dimensions
            width = Uniform(2, 10).sample().int().item()
            height = Uniform(2, 10).sample().int().item()
            
            # Sample position (ensure rectangle fits within image)
            x = Uniform(0, 32 - width).sample().int().item()
            y = Uniform(0, 32 - height).sample().int().item()
            
            # Draw black rectangle (value of 0)
            images[i, 0, y:y+height, x:x+width] = 0
            
    return images

In [None]:
# Generate a dataset of 10000 images
class UnlabeledTensorDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

rectangle_images = generate_rectangle_images(10000)
train_dataset = UnlabeledTensorDataset(rectangle_images[1000:])
validation_dataset = UnlabeledTensorDataset(rectangle_images[:1000])

# Plot first 4 images in the dataset
plt.figure(figsize=(8, 8), dpi=64)  # Set dpi to 64
for i in range(4):
    ax = plt.subplot(2, 2, i + 1)
    ax.imshow(train_dataset[i].squeeze(), cmap='gray')
    # Remove ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')  # Ensure aspect ratio is equal
plt.tight_layout(pad=0)  # Remove padding between subplots
plt.show()

Can we take these images $x$ and learn the process that generated them? There is a problem, which we already encountered in the audio notebook: neural networks themselves aren't random, and we cannot simply output logits over all possible images.
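To see why logits over whole images are hopeless, count the possibilities. Even restricting to purely black-and-white $32 \times 32$ images like the ones above, each of the 1024 pixels takes one of two values, giving $2^{1024}$ distinct images. A quick check with Python's arbitrary-precision integers:

```python
# Number of distinct binary (black/white) 32x32 images: 2 choices per pixel
num_binary_images = 2 ** (32 * 32)

# Number of decimal digits in that count
print(len(str(num_binary_images)))  # prints 309
```

A 309-digit number of classes is far beyond anything a softmax output layer could represent.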

### A naïve approach: decoder image generation

To generate images, a naïve approach would be to sample a vector of normally distributed random seeds $z$, which we will call the _latent variables_, and then use a decoder neural network $f^{\theta}$ to map $z$ to an image:

$$x = f^{\theta}(z), \quad z \sim \mathcal{N}(0, I)$$

Then, we could train the decoder to map latent variables to the images in a dataset. Let's try this: at every training step we'll select a batch of images $x_{1:B}$ and a fresh batch of random seeds $z_{1:B}$, and then try to learn a mapping from the latter to the former using the mean squared error (MSE) loss:

$$\mathcal{L}(x) = \left\lVert x - f^{\theta}(z) \right\rVert^2$$

(Note: in generative computer vision we often use the MSE loss, which, up to scaling and an additive constant, is the negative log-likelihood of the data under a normal distribution with unit variance centred on the model's prediction. The papers introducing the concepts in this notebook all use this statistical language when formulating their methods; we will avoid it for simplicity.)
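The correspondence in the note above can be checked numerically: under a Gaussian with unit variance centred on the prediction, the negative log-likelihood of the data equals half the sum of squared errors plus a constant that does not depend on the model. A minimal sketch using `torch.distributions.Normal`, with random tensors standing in for images and decoder outputs:

```python
import math

import torch
from torch.distributions import Normal

torch.manual_seed(0)
x = torch.rand(4, 1, 32, 32)     # stand-in "images"
pred = torch.rand(4, 1, 32, 32)  # stand-in decoder outputs

# Negative log-likelihood of x under N(pred, 1), summed over all pixels
nll = -Normal(pred, 1.0).log_prob(x).sum()

# Half the sum of squared errors, plus a constant independent of the model
sse = ((x - pred) ** 2).sum()
const = 0.5 * x.numel() * math.log(2 * math.pi)

assert torch.allclose(nll, 0.5 * sse + const)
```

Since the constant has no gradient, minimizing the squared error and maximizing this likelihood move the parameters in the same direction.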

![decoder-only](https://polybox.ethz.ch/index.php/s/3Gq9rQaeeKc2je9/download/decoder-only-figure.png)

*The information flow (left to right) of the decoder at train time. $x' = f^{\theta}(z)$ is the reconstructed image.*

In [None]:
latent_dim = 64

In [None]:
# Collapsed cell: define the decoder architecture mapping (batch_size, latent_dim) -> (batch_size, 1, 32, 32)

decoder = nn.Sequential(
    nn.Linear(latent_dim, 64 * 8 * 8),
    nn.Unflatten(1, (64, 8, 8)),
    nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1), # We use output padding to exactly double the size of the spatial dimensions 
    nn.ReLU(),
    nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 1, 3, stride=1, padding=1),
    nn.Sigmoid()
).to(device)

In [None]:
print(f"decoder has {parameter_count_string(decoder)} parameters")

**Exercise 2.1:** Fill out the `compute_loss_decoder` function to decode and compute the loss described above. 

In [None]:
def compute_loss_decoder(z, decoder, batch):
    """Returns a torch scalar loss of the mean-squared error between the batch and the decoder reconstruction of the latent variables"""
    # Decode the latent variables to images
    recon_batch = decoder(z)
    # Compute the MSE loss
    loss = F.mse_loss(recon_batch, batch)
    return loss

In [None]:
# Collapsed cell: test your solution to exercise 2.1

# Test: your solution should give zero loss if the batch is zeros and the decoder returns zeros
def dummy_decoder(z):
    return torch.zeros(z.shape[0], 1, 32, 32)

assert torch.allclose(compute_loss_decoder(torch.randn(5, 64), dummy_decoder, torch.zeros(5, 1, 32, 32)), torch.tensor(0., device=device))

# Test: your solution should give zero loss if the perfect decoder predicts the batch exactly
z = torch.randn(5, latent_dim, device=device)
assert torch.allclose(compute_loss_decoder(z, decoder, decoder(z)), torch.tensor(0., device=device))

# Test: your solution should give the same as the deviation from the "perfect" decoder
z = torch.randn(5, latent_dim, device=device)
noise = torch.randn(5, 1, 32, 32, device=device)
assert torch.allclose(compute_loss_decoder(z, decoder, decoder(z) + noise), (noise**2).mean())

print("All tests passed!")

Now, let's wrap our decoder in a training loop. You should observe the loss almost instantly converging to $\approx 0.095$.

In [None]:
# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
decoder.to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)

# Assuming train_dataset is already defined
dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):

    decoder.train()
    total_loss = 0

    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Generate random latent variables
        z = torch.randn(batch.size(0), latent_dim).to(device)

        loss = compute_loss_decoder(z, decoder, batch)

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.5f}')

Now, let's try and generate some images.

**Exercise 2.2**: fill out the `generate_images_decoder` function.

In [None]:
def generate_images_decoder(num_samples, decoder):
    """Returns a tensor of shape (num_samples, 1, 32, 32) representing images """
    # Generate a batch of num_samples latent vectors
    z = torch.randn(num_samples, latent_dim).to(device)
    # Decode these latent vectors into images
    images = decoder(z)
    return images


If every one of your generations shows the same gray blur in the middle, don't worry: this is expected!
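Why a gray blur? The latents $z$ are sampled independently of the images they are paired with, so they carry no information about which image to produce. Under the MSE loss, the best that any function of $z$ can then do is output the pixel-wise mean of the dataset, whose loss equals the average per-pixel variance. A sketch of this argument on a random binary stand-in dataset (`constant_prediction_mse` is a hypothetical helper, not part of the notebook):

```python
import torch


def constant_prediction_mse(images, prediction):
    """MSE when the same image `prediction` is output for every image in `images`."""
    return ((images - prediction) ** 2).mean()


torch.manual_seed(0)
# Stand-in for the rectangle dataset: mostly white (1) pixels with some black (0) ones
images = (torch.rand(1000, 1, 32, 32) > 0.1).float()

mean_image = images.mean(dim=0, keepdim=True)
mse_mean = constant_prediction_mse(images, mean_image)

# The MSE of the mean image equals the average per-pixel (population) variance
assert torch.allclose(mse_mean, images.var(dim=0, unbiased=False).mean())

# Any other constant prediction does worse
perturbed = mean_image + 0.05 * torch.randn_like(mean_image)
assert constant_prediction_mse(images, perturbed) > mse_mean
print(f"MSE of the mean image: {mse_mean:.4f}")
```

On the real data, `constant_prediction_mse(rectangle_images, rectangle_images.mean(0, keepdim=True))` computes the same quantity; we would expect it to sit close to the plateau the decoder's loss converged to.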

In [None]:
# Number of images to generate
num_samples = 20

# Generate images
decoder.eval()
with torch.no_grad():
    generated_images = generate_images_decoder(num_samples, decoder)

# Plot the generated images
plt.figure(figsize=(10, 5))
for i in range(num_samples):
    plt.subplot(2, 10, i+1)
    plt.imshow(generated_images[i].cpu().squeeze().clip(0, 1), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
plt.tight_layout()
plt.show()

### Autoencoders

So, our approach didn't work. Why? The problem is that neural networks are continuous (they tend to map similar latent variables to similar images), but our training process doesn't account for this. If we sample two similar training images, we would like for the corresponding latent variables to also be similar.

To solve this, let's use another neural network $g^{\phi}$ to learn the inverse mapping from images to latent variables, which should map similar images to similar latents:

$$z = g^{\phi}(x)$$

Considering this pair of neural networks, you should notice that it looks a lot like the autoencoder we covered in the Computer Vision & Audio notebook; in fact, it's the same concept! Let's try to use it for generating images. We'll choose a latent space much smaller than the image (64 dimensions versus $32 \times 32 = 1024$ pixels) and then minimize the difference between the original image $x$ and the reconstructed image $f^{\theta}(g^{\phi}(x))$:

$$\mathcal{L}(x) = \left\lVert x - f^{\theta}(g^{\phi}(x)) \right\rVert^2 $$

Hopefully, we should then be able to sample new images by decoding $z$ values that we sample randomly.

![autoencoder](https://polybox.ethz.ch/index.php/s/SMX7wGAC3BGJGW4/download/autoencoder-figure.png)

*The information flow (left to right) of an autoencoder at train time.*
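The encoder in the `AutoEncoder` class below applies three stride-2 convolutions, which halve the $32 \times 32$ input three times down to $4 \times 4$; that is where the `64 * 4 * 4` input size of `fc_z` comes from. A quick shape trace through a copy of the encoder backbone layers confirms this:

```python
import torch
import torch.nn as nn

# Same layers as AutoEncoder.encoder_backbone, copied so this cell runs on its own
backbone = nn.Sequential(
    nn.Conv2d(1, 64, 3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(128, 64, 3, stride=2, padding=1),
    nn.ReLU(),
    nn.Flatten(),
)

x = torch.zeros(2, 1, 32, 32)
for layer in backbone:
    x = layer(x)
    # Print the shape after every layer: 32 -> 16 -> 8 -> 4, then flattened
    print(f"{layer.__class__.__name__:>8}: {tuple(x.shape)}")

assert x.shape == (2, 64 * 4 * 4)
```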

In [None]:
# Collapsed cell: define the Autoencoder architecture with .encode and .decode methods mapping (batch_size, 1, 32, 32) -> (batch_size, latent_dim) -> (batch_size, 1, 32, 32)
class AutoEncoder(nn.Module):
    def __init__(self, latent_dim=64):
        super(AutoEncoder, self).__init__()
        
        # Encoder
        self.encoder_backbone = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        
        # Project the flattened 64 x 4 x 4 feature map to the latent vector
        self.fc_z = nn.Linear(64 * 4 * 4, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64 * 4 * 4),
            nn.Unflatten(1, (64, 4, 4)),
            nn.ConvTranspose2d(64, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
        
    def encode(self, x):
        return self.fc_z(self.encoder_backbone(x))
        
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        return self.decode(self.encode(x))

In [None]:
autoencoder = AutoEncoder().to(device)
print(f"autoencoder has {parameter_count_string(autoencoder)} parameters")

**Exercise 3.1:** fill out the `compute_loss_autoencoder` function to compute the loss as described above.

In [None]:
def compute_loss_autoencoder(autoencoder, batch):
    """Returns a torch scalar loss of the mean-squared error between the batch and its autoencoder reconstruction"""
    # Encode the images as latent variables
    z = autoencoder.encode(batch)
    # Decode the latent variables to images
    recon_batch = autoencoder.decode(z)
    # Compute the MSE loss
    loss = F.mse_loss(recon_batch, batch)
    return loss

In [None]:
# Collapsed cell: test your solution to exercise 3.1

class DummyAutoencoder(nn.Module):
    def __init__(self):
        super(DummyAutoencoder, self).__init__()

    def encode(self, x):
        return x

    def decode(self, z):
        return z

    def forward(self, x):
        return x

# Test: your solution should give zero loss if the autoencoder is the identity
dummy_autoencoder = DummyAutoencoder()

test_batch = torch.randn(5, 1, 32, 32)
assert torch.allclose(compute_loss_autoencoder(dummy_autoencoder, test_batch), torch.tensor(0., device=device))

# Test: your solution should give zero loss when given the perfect autoencoder.
noise = torch.randn(5, 1, 32, 32)
assert torch.allclose(compute_loss_autoencoder(dummy_autoencoder, noise), torch.tensor(0., device=device))

# Test: your solution should give the same loss as the deviation from the "perfect" autoencoder.
class NoisyAutoencoder(nn.Module):
    def __init__(self, noise):
        super(NoisyAutoencoder, self).__init__()
        self.noise = noise

    def encode(self, x):
        return x

    def decode(self, z):
        return z + self.noise

    def forward(self, x):
        return self.decode(self.encode(x))

noise = torch.randn(5, 1, 32, 32)
test_batch = torch.randn(5, 1, 32, 32)

noisy_autoencoder = NoisyAutoencoder(noise)

assert torch.allclose(compute_loss_autoencoder(noisy_autoencoder, test_batch), (noise**2).mean())

print("All tests passed!")

Let's train our autoencoder:

In [None]:
# Training setup
optimizer = optim.Adam(autoencoder.parameters(), lr=1e-4) # We need a smaller learning rate to prevent instabilities.
dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    autoencoder.train()
    total_loss = 0
    for batch in dataloader:

        batch = batch.to(device)
        optimizer.zero_grad()
        
        # Total loss
        loss = compute_loss_autoencoder(autoencoder, batch)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    
    if (epoch + 1) % 10 == 0:
        # Calculate validation loss
        autoencoder.eval()
        val_loss = 0
        val_dataloader = DataLoader(validation_dataset, batch_size=128, shuffle=True)
        with torch.no_grad():
            for batch in val_dataloader:

                batch = batch.to(device)
                loss = compute_loss_autoencoder(autoencoder, batch)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_dataloader)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

The loss decreasing is a much more positive sign. Again, let's generate some images by sampling iid standard normal latents.

**Exercise 3.2:** fill out the `generate_images_autoencoder` function. 

Hint: `torch.randn` ([docs](https://docs.pytorch.org/docs/stable/generated/torch.randn.html)) can be used to sample iid standard normal tensors of arbitrary shapes. 

In [None]:
def generate_images_autoencoder(num_samples, autoencoder):
    """Returns a tensor of shape (num_samples, 1, 32, 32) representing images."""
    # Generate a batch of num_samples latent vectors normally distributed
    z = torch.randn(num_samples, latent_dim).to(device)
    # Decode these latent vectors into images
    images = autoencoder.decode(z)
    return images

Your solution should generate gray blurs with some deviations in them.

In [None]:
# Number of images to generate
num_samples = 20

# Generate images
autoencoder.eval()
with torch.no_grad():
    generated_images = generate_images_autoencoder(num_samples, autoencoder)

# Plot the generated images
plt.figure(figsize=(10, 5))
for i in range(num_samples):
    plt.subplot(2, 10, i+1)
    plt.imshow(generated_images[i].cpu().squeeze().clip(0, 1), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
plt.tight_layout()
plt.show()

So sampling random latent variables didn't allow us to generate images like our dataset. As we saw in the Computer Vision and Audio notebook, standard autoencoders _do_ learn to reconstruct images that are like what they see at train time. With images, we can plot this:

In [None]:
# Collapsed cell: visualize autoencoder reconstructions of images in the validation set
autoencoder.eval()  # Set to evaluation mode

# Get a batch of images from the validation set
val_dataloader = DataLoader(validation_dataset, batch_size=10, shuffle=True)
original_images = next(iter(val_dataloader)).to(device)

# Generate reconstructions
with torch.no_grad():
    reconstructed_images = autoencoder(original_images)

# Plot original vs reconstructed images
plt.figure(figsize=(10, 5))

# Plot original images
for i in range(5):
    plt.subplot(2, 5, i + 1)
    plt.imshow(original_images[i].cpu().squeeze(), cmap='gray')
    plt.title("Original")
    plt.axis('off')

# Plot reconstructed images
for i in range(5):
    plt.subplot(2, 5, i + 6)
    plt.imshow(reconstructed_images[i].cpu().squeeze(), cmap='gray')
    plt.title("Reconstructed")
    plt.axis('off')

plt.tight_layout()
plt.show()


### _Variational_ AutoEncoders (VAEs)

What we can see from the above autoencoder results is that the decoder _does_ learn to reconstruct unseen images when given some very specific latent variables. The problem is that we have no way of knowing which latent variables to choose to elicit these generations!

What we can do is _force_ our autoencoder to make its latents have some nice statistical property (such as being normally distributed). This is precisely what _Variational AutoEncoders (VAEs)_ do. To get this to work, they add noise to the encoder's latent prediction by making the encoder output the mean and variance of a distribution over latents:

$$\mu, \sigma^2 = g^{\phi}(x)$$

$$z \sim \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$$
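Drawing $z$ directly from this distribution would block gradients from flowing back into the encoder. The usual workaround, the *reparameterization trick* (which is what `sample_z` in the VAE code cell below does), writes the sample as $z = \mu + \sigma \odot \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, I)$, so all the randomness sits in $\varepsilon$ and $z$ is a differentiable function of $\mu$ and $\log \sigma^2$. A minimal standalone sketch with made-up values for $\mu$ and $\log \sigma^2$:

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([1.0, -2.0], requires_grad=True)
log_var = torch.tensor([0.0, 2.0], requires_grad=True)

# z = mu + sigma * eps, with sigma = exp(0.5 * log_var); eps carries all the randomness
n = 100_000
eps = torch.randn(n, 2)
z = mu + torch.exp(0.5 * log_var) * eps

# Empirical mean and variance should be close to mu and sigma^2 = exp(log_var)
print(z.mean(0), z.var(0))

# Gradients flow back to mu and log_var through the samples
z.sum().backward()
print(mu.grad, log_var.grad)
```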

Now, we can make the latent variables normally distributed by incentivising the encoder to output $\mu$ close to $0$ and $\operatorname{diag}(\sigma^2)$ close to $I$. The new KL term on the right-hand side of the loss formula below does just this (don't worry if you don't know the precise formula for this term). We need to add two tricks before this completely works. 

$$\mathcal{L}(x) = \frac{1}{2 b^2} \left\lVert x - f^{\theta}(\texttt{sample}(g^{\phi}(x))) \right\rVert^2 + \ln(b) + KL(g^{\phi}(x) \Vert \mathcal{N}(0,I))$$

The additional $b$ balances the mean squared error term against the KL term, and it represents the model's uncertainty in its reconstruction (the standard deviation of the reconstruction error). The code below initialises it to $0.1$ and then lets the model learn its value. 

This loss is based on the Evidence Lower BOund (ELBO): minimizing it corresponds to maximizing the ELBO. In this notebook we won't go into its precise derivation.
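For a diagonal Gaussian compared against $\mathcal{N}(0, I)$, the KL term has the closed form $\mathrm{KL} = \tfrac{1}{2}\sum_j \left(\sigma_j^2 + \mu_j^2 - \ln \sigma_j^2 - 1\right)$, which is the expression given for free in Exercise 4.1 below. As a sanity check, a sketch comparing it with PyTorch's `torch.distributions.kl_divergence` on random encoder outputs:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 64)
log_var = torch.randn(8, 64)

# Closed form, summed over batch and latent dimensions
kl_closed_form = 0.5 * torch.sum(log_var.exp() + mu.pow(2) - log_var - 1)

# Reference: elementwise KL(N(mu, sigma^2) || N(0, 1)), then summed
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum()

assert torch.allclose(kl_closed_form, kl_reference, rtol=1e-4)
```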

![VAE](https://polybox.ethz.ch/index.php/s/yNcTpzfcKHks7qa/download/vae-figure.png)

*The information flow (left to right) of a VAE at train time.*

In [None]:
# Collapsed cell: define the VAE architecture by extending the AutoEncoder class.
class VAE(AutoEncoder):
    def __init__(self, latent_dim=64):
        super(VAE, self).__init__(latent_dim)
        self.fc_mu = self.fc_z
        self.fc_var = nn.Linear(64 * 4 * 4, latent_dim)
        self.b = nn.Parameter(torch.tensor(.1)) # Initial standard deviation (scale) of the reconstruction error

    def encode(self, x):
        x = self.encoder_backbone(x)
        mu = self.fc_mu(x)
        # In reality, we model the log of the variance, not the variance itself for numerical stability
        log_var = self.fc_var(x) 
        return mu, log_var
    
    def sample_z(self, mu, log_var):
        std = torch.exp(0.5 * log_var) # 1/2 as standard deviation is sqrt of variance
        eps = torch.randn_like(std)
        return mu + eps * std
        
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.sample_z(mu, log_var)
        return self.decode(z), mu, log_var

In [None]:
# Create a VAE to train
model = VAE().to(device)
print(f"VAE has {parameter_count_string(model)} parameters")

**Exercise 4.1:** fill out the `compute_loss_vae` function. Note that the losses will only be balanced if you take the **sum** of the squared errors (rather than the mean, as has been done in previous sections).

In [None]:
def compute_loss_vae(model, batch):
    # Forward pass
    recon_batch, mu, log_var = model(batch)
    b = model.b
    # Reconstruction loss: sum of squared errors
    sum_squared_loss = torch.sum((recon_batch - batch) ** 2)
    # KL divergence loss (given for free)
    kl_loss = 0.5 * torch.sum(log_var.exp() + mu.pow(2) - log_var - 1)
    # Total loss
    loss = sum_squared_loss / (2*b*b) + kl_loss + torch.log(b)
    return loss

In [None]:
# Collapsed cell: test your solution
class DummyVAE(nn.Module):
    def __init__(self, mu, log_var, b, recon_x):
        super(DummyVAE, self).__init__()
        self.mu = mu
        self.log_var = log_var
        self.b = b
        self.recon_x = recon_x

    def forward(self, x):
        return self.recon_x, self.mu, self.log_var

# Test: if the image is perfectly reconstructed with b=1., the loss should be the KL divergence

test_mu = torch.randn(2, 64).to(device)
test_log_var = torch.randn(2, 64).to(device)
test_b = torch.tensor(1.).to(device)
test_x = torch.randn(2, 1, 32, 32).to(device)

dummy_model = DummyVAE(test_mu, test_log_var, test_b, test_x).to(device)

kl_term = 0.5 * torch.sum(test_log_var.exp() + test_mu.pow(2) - test_log_var - 1)

assert torch.allclose(compute_loss_vae(dummy_model, test_x), kl_term)

# Test: if the latent distribution is standard normal and b=1, the loss should be half the _sum_ of the squared error.

test_mu = torch.zeros_like(test_mu)
test_log_var = torch.zeros_like(test_log_var)
test_x_recon = torch.randn(2, 1, 32, 32).to(device)

dummy_model = DummyVAE(test_mu, test_log_var, test_b, test_x_recon).to(device)

assert torch.allclose(compute_loss_vae(dummy_model, test_x), torch.sum((test_x - test_x_recon)**2) / 2)

# Test: extend the above test to check for b!= 1

test_b = torch.tensor(0.73).to(device)

dummy_model = DummyVAE(test_mu, test_log_var, test_b, test_x_recon).to(device)

assert torch.allclose(compute_loss_vae(dummy_model, test_x), torch.sum((test_x - test_x_recon)**2) / (2*test_b**2) + torch.log(test_b))

print("All tests passed!")


While waiting for your VAE to train, you can read the section *Aside: Generative Adversarial Networks* below.

In [None]:
# Create a VAE to train
model = VAE().to(device)
print(f"VAE has {parameter_count_string(model)} parameters")

# Training setup
optimizer = optim.Adam(model.parameters(), lr=1e-3)
dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in dataloader:

        optimizer.zero_grad()

        batch = batch.to(device)
        loss = compute_loss_vae(model, batch)
        
        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_dataset)
    
    if (epoch + 1) % 10 == 0:

        # Calculate validation loss
        model.eval()
        val_loss = 0
        val_dataloader = DataLoader(validation_dataset, batch_size=128, shuffle=True)
        with torch.no_grad():
            for batch in val_dataloader:

                batch = batch.to(device)
                loss = compute_loss_vae(model, batch)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(validation_dataset)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}')



Now, let's generate some images from the VAE. Your results should be much better than the autoencoder's, but still have quite a bit of blurriness.

**Exercise 4.2:** write the `generate_images_vae` function.

Hint: you may not need to rewrite the function.

In [None]:
generate_images_vae = generate_images_autoencoder

In [None]:

# Generate some samples
model.eval()
with torch.no_grad():
    samples = generate_images_vae(4, model)
    
# Plot generated samples
plt.figure(figsize=(8, 8), dpi=64)
for i in range(4):
    ax = plt.subplot(2, 2, i + 1)
    ax.imshow(samples[i].cpu().squeeze(), cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
plt.tight_layout(pad=0)
plt.show()

Let's plot some reconstructions from the VAE too:

In [None]:
# select a random image from the dataset
image = train_dataset[torch.randint(0, len(train_dataset), (1,))].to(device)

# reconstruct the image with the VAE
with torch.no_grad():
    reconstructed, _, _ = model(image)

# display the original and reconstructed images
plt.figure(figsize=(8, 8), dpi=128)
plt.subplot(1, 2, 1)
plt.imshow(image.cpu().squeeze(), cmap='gray')
plt.xticks([])
plt.yticks([])
plt.subplot(1, 2, 2)
plt.imshow(reconstructed.cpu().squeeze(), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.show()

### Aside: Generative Adversarial Networks

*Reading this section is optional but encouraged.*

Until ~2020, the standard way to solve the blurriness problem of the VAE was to use a Generative Adversarial Network (GAN). Instead of trying to learn the inverse mapping from images to latent variables, GANs learn a discriminator $\delta$ which can tell real images from generated images. The discriminator is trained using labels indicating whether an image is from the training dataset or generated by $f^\theta$, and the generator is trained using the gradient of the discriminator backpropagated onto the pixels of the generated images.

One problem with GANs is that they are hard to validate: with a VAE you can compare the loss on the train and validation datasets, but a GAN has no encoder, so there is no way to pass a held-out image through the model and measure a loss. You _can_ detect the most extreme form of overfitting by searching the train dataset for exact copies of generated images, but this does not account for overfitting to certain motifs in the train dataset. GANs also tend to struggle with very diverse datasets: the generator can get stuck generating only a few of the many classes in the data (a failure known as mode collapse).

![GAN](https://polybox.ethz.ch/index.php/s/6yp84KePe4PGBsc/download/GAN-figure.png)

*The information flow (left to right) of a GAN at train time.*

### Diffusion Models

One way to understand the blurriness of VAEs is to realise that the model may internally know that the image contains hard edges, but with a single-step generation process it must hedge its bets when reconstructing the image and predict a blur around where it believes the edge should be. Additionally, our VAE isn't very sample efficient: even with only ~500k parameters, it produces low-fidelity samples and has still slightly overfit to the training data.

_Diffusion models_ solve these problems by turning generation into a multi-step process: the model partially reconstructs the image, looks at the image again, reconstructs it a bit more, looks at the image again, and so on. The fact that the model is trained on so many partial reconstructions makes the effective training dataset much larger, meaning that we can use a larger model with more compute without overfitting.

Specifically, at train time we pick a random "timestep" $t \in \{0, 1, 2, \dots, 1000\}$, and then combine a training image $x$ with some normally distributed noise $\epsilon$ to get a noised image $x_t$. For this notebook, we will use a simple linear combination:

$$x_t = \left(1 - \frac{t}{1000} \right)x + \frac{t}{1000} \epsilon$$

Thus, at $t=0$ the input $x_0$ is the original image $x$, whereas at time $t=1000$ the input $x_{1000}$ is pure noise. 

The idea is that we can train the model to reverse this process: given a noised image $x_t$, can we remove a small amount of noise to get a slightly less noisy image $x_{t-1}$?

In [None]:
# Collapsed cell: plotting the flow x_t process at different timesteps
plt.figure(figsize=(20, 5))
image = train_dataset[5]

# Generate noise
epsilon = torch.randn_like(image)


for i, t in enumerate(range(0, 1100, 100)):
    t = min(t, 1000)
    # Mix image and noise according to timestep
    x_t = (1 - t/1000) * image + (t/1000) * epsilon
    
    plt.subplot(1, 11, i+1)
    plt.imshow(x_t.cpu().squeeze(), cmap='gray')
    plt.title(f't={t}')
    plt.xticks([])
    plt.yticks([])

plt.tight_layout()
plt.show()


This formulation tells us that $x_{t-1}$ differs from $x_t$ by a small step in the pixel-space direction $\epsilon - x$:

$$x_{t-1} = x_t - \frac{1}{1000}(\epsilon - x)$$
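This follows because $x_t = x + \frac{t}{1000}(\epsilon - x)$ is linear in $t$. A quick numerical sanity check on random tensors (a sketch: the tensor shape, the seed and the helper `noised` are arbitrary choices for illustration, and `T = 1000` mirrors the notebook's number of timesteps):

```python
import torch

T = 1000  # number of timesteps, as in the notebook
torch.manual_seed(0)

x = torch.rand(1, 1, 32, 32, dtype=torch.float64)   # stand-in for a training image
epsilon = torch.randn_like(x)                       # normally distributed noise


def noised(t):
    """Rectified-flow interpolation x_t = (1 - t/T) x + (t/T) epsilon (illustrative helper)."""
    return (1 - t / T) * x + (t / T) * epsilon


t = 600
step = noised(t) - noised(t - 1)                    # x_t - x_{t-1}
print(torch.allclose(step, (epsilon - x) / T))      # True
```

The same helper also recovers the endpoints: `noised(0)` is the image itself and `noised(T)` is the pure noise.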

The brilliant thing is that we can train a neural network $f^{\theta}$ which takes in $x_t$ and $\frac{t}{1000}$ for some sample of $x, t, \epsilon$ and outputs its best approximation of the direction $\epsilon - x$, which we evaluate with the MSE:

$$\mathcal{L}(x) = \left\lVert (\epsilon - x) - f^{\theta} \left(x_t, \frac{t}{1000}\right) \right\rVert^2$$

If we let $v = \epsilon - x$ and $v_t' = f^{\theta}(x_t, \frac{t}{1000})$, the figure below describes the information flow at train time:

![Diffusion](https://polybox.ethz.ch/index.php/s/6KyiNLMMcpEizaG/download/diffusion-figure.png)

*The information flow (left to right) of a diffusion model at train time.*

Then, to generate new samples, we start from pure noise $x_{1000}'$ and get the network's prediction $v_{1000}' = f^{\theta}(x_{1000}', 1)$, its best guess of the direction $v$. We use this prediction to denoise the sample by a small amount, which gives a slightly less noisy sample (our best approximation of what $x_{999}'$ would have been), and then repeat:

$$x_{t-1}' = x_{t}' - \frac{1}{1000} f^{\theta} \left(x_{t}', \frac{t}{1000}\right)$$
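This update is Euler's method run backwards in time with step size $\frac{1}{1000}$, where $f^\theta$ supplies the velocity. To see the sampler in isolation, here is a toy sketch under an illustrative assumption that is not part of the notebook's pipeline: the whole "dataset" is a single image `x_data`. Then the ideal prediction is known in closed form, since $x_t = (1-s)x + s\epsilon$ with $s = \frac{t}{1000}$ gives $\epsilon - x = (x_t - x)/s$, so a hypothetical oracle `oracle_velocity` can stand in for $f^\theta$:

```python
import torch

T = 1000  # number of timesteps, as in the notebook
torch.manual_seed(0)

# Toy assumption: the whole "dataset" is this single image
x_data = torch.rand(1, 1, 8, 8, dtype=torch.float64)


def oracle_velocity(x_t, s):
    """Hypothetical stand-in for f^theta: the exact epsilon - x when the data is the single point x_data.

    From x_t = (1 - s) x_data + s * epsilon it follows that epsilon - x_data = (x_t - x_data) / s.
    """
    return (x_t - x_data) / s


x = torch.randn_like(x_data)                    # x'_1000: pure noise
for step in range(T, 0, -1):                    # t = 1000, 999, ..., 1
    s = step / T
    x = x - (1 / T) * oracle_velocity(x, s)     # x'_{t-1} = x'_t - f(x'_t, t/1000) / 1000

max_error = (x - x_data).abs().max().item()
print(max_error)                                # very close to 0: the sampler lands on x_data
```

Because the ideal paths in this toy problem are straight lines, the Euler steps introduce no error and `x` ends up at `x_data` up to floating-point rounding. With a learned $f^\theta$, each prediction is only an approximation.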



On the diffusion model architecture:
- It can be much larger than the VAE without overfitting.
- As we don't need to squash latents through a bottleneck, we can use the U-Net architecture which we covered in the Computer Vision and Audio notebook. (Technically this is _possible_ with VAEs, but it requires adding a lot of complexity with a hierarchy of latent variables.)

In [None]:
# Define the architecture for our diffusion model. We can make the model much larger than the VAE for the reasons discussed above, and we can also use a U-Net with skip connections, as no latent bottleneck is needed
class DiffusionUNet(nn.Module):
    def __init__(self, base_dim=32, n_channels=1, n_updown_blocks=4, n_middle_blocks=2):
        """base_dim is the number of channels after the first convolution, n_channels is the number of channels in the input image"""
        super(DiffusionUNet, self).__init__()
        
        self.n_channels = n_channels

        self.down_blocks = nn.ModuleList()
        self.up_blocks = nn.ModuleList()

        self.conv_in = nn.Sequential(
            nn.Conv2d(n_channels+1, base_dim, kernel_size=3, padding=1),
            nn.GroupNorm(8, base_dim),
            nn.ReLU()
        )

        for i in range(n_updown_blocks):
            # Encoder layers
            in_channels = base_dim * (2**i)
            out_channels = base_dim * (2**(i+1))
            
            self.down_blocks.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(8, out_channels),
                nn.ReLU()
            ))
            
            # Decoder layers
            dec_in_channels = out_channels * 2  # We double the number of channels for skip connections
            dec_out_channels = in_channels
            
            self.up_blocks.append(nn.Sequential(
                nn.ConvTranspose2d(dec_in_channels, dec_out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.GroupNorm(8, dec_out_channels),
                nn.ReLU(),
                nn.Conv2d(dec_out_channels, dec_out_channels, kernel_size=1, stride=1, padding=0), # Add a 1x1 conv to make future adaptations easier
                nn.GroupNorm(8, dec_out_channels),
                nn.ReLU()
            ))


        mid_block_width = base_dim * 2**(n_updown_blocks)

        self.middle = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(mid_block_width, mid_block_width, 3, padding=1),
                nn.GroupNorm(8, mid_block_width),
                nn.ReLU()
            ) 
        for _ in range(n_middle_blocks)])

        self.conv_out = nn.Conv2d(2*base_dim, n_channels, kernel_size=3, padding=1)
    
    
    def forward(self, x_t, t):
        # Concatenate timestep as another channel in the image
        t = t.expand(-1, 1, x_t.shape[2], x_t.shape[3])
        x = torch.cat([x_t, t], dim=1)

        x = self.conv_in(x)

        x_in = x

        # Store the hidden states as we downsample
        hidden_states = [x]
        for layer in self.down_blocks:
            x = layer(x)
            hidden_states.append(x)

        # Use residual connections in the middle
        for layer in self.middle:
            x = x + layer(x)

        # Use skip connections with the corresponding stored hidden states
        for layer, hidden_state in zip(reversed(self.up_blocks), reversed(hidden_states)):
            x = torch.cat([x, hidden_state], dim=1)
            x = layer(x)

        # Concatenate the features from the input convolution (x_in) before the output convolution
        x = torch.cat([x, x_in], dim=1)
        x = self.conv_out(x)
        
        return x

There are many ways of formulating a diffusion process by combining $(x, \epsilon, t)$ into a loss and a sampling process. The one we presented above is called *rectified flow* or *optimal transport*; it is relatively mathematically simple and was used to train Stable Diffusion 3.

**Exercise 5.1:** fill out the `compute_loss_rectified_flow` function. 

In [None]:
def compute_loss_rectified_flow(model, x):
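    # Sample a batch of shape (batch_size, 1, 1, 1) of random timesteps uniformly from {1, ..., 1000}
    # (this matches the timesteps used by the sampling loop and avoids t = 0, where the test below cannot recover the noise from x_t)
    batch_size = x.shape[0]
    time = torch.randint(1, 1001, (batch_size, 1, 1, 1), device=x.device)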
    
    # Create a batch of the same shape as x of normally distributed noise
    noise = torch.randn_like(x)

    # Create a batch of noised images x_t according to rectified flow process
    x_t = (1 - time/1000) * x + (time/1000) * noise

    # Predict direction v_t using the model
    v_t = model(x_t, time/1000)

    # Compute the true direction v
    v = noise - x

    # Rectified flow loss: MSE between predicted and true direction (this time, you can use the mean)
    loss = F.mse_loss(v_t, v)
    return loss


In [None]:
# Collapsed cell: test your compute_loss_rectified_flow function
class DummyDiffusionUNet(nn.Module):
    def __init__(self, true_x, error):
        super(DummyDiffusionUNet, self).__init__()
        # Memorize the true x and some model error term we will add to the prediction
        self.true_x = true_x
        self.error = error

    def forward(self, x_t, t):
        # Test: the timestep given to the model should contain values in the range [0, 1]
        assert torch.all(t >= 0) and torch.all(t <= 1)
        
        # Compute the true noise using the memorized true x
        true_noise = (1/t) * (x_t - self.true_x * (1 - t))

        # Test: the true noise should have mean 0 and std 1
        # Ideally would use a Goodness-of-Fit (eg. KS test here) to test normality but I doubt students will use the wrong distribution.
        standard_error = 1 / true_noise.numel() **.5
        assert -4 < true_noise.mean()/standard_error < 4

        standard_error = 1 / (2* true_noise.numel()) **.5
        assert -4 < (true_noise.std() - 1)/standard_error < 4

        # Compute the true direction v using the memorized true x
        true_v = (1/t) * (x_t - self.true_x)

        # Return the true direction v plus the predetermined model error term for more tests
        return true_v + self.error

# Test: if the model can compute the true direction v, the loss should be 0

test_x = torch.randn(8, 1, 64, 64).to(device)
test_error = torch.zeros(8, 1, 64, 64).to(device)

my_model = DummyDiffusionUNet(test_x, test_error)

assert torch.allclose(compute_loss_rectified_flow(my_model, test_x), torch.tensor(0.))

# Test: if there is some error in the model, the loss should be the mean square of this error

test_error = torch.randn(8, 1, 64, 64).to(device)

my_model = DummyDiffusionUNet(test_x, test_error)

assert torch.allclose(compute_loss_rectified_flow(my_model, test_x), torch.mean(test_error**2))

print("All tests passed!")


In [None]:
def validate_rectified_flow(model, val_dataloader):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for val_batch in val_dataloader:
            
            val_x = val_batch.to(device)
            loss = compute_loss_rectified_flow(model, val_x)
            val_loss += loss.item()
            
    val_loss /= len(val_dataloader)
    model.train()
    return val_loss


# Training loop with rectified flow objective
def train_rectified_flow(model, dataloader, val_dataloader, num_epochs=100, lr=1e-4):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            
            # Get data and move to device
            x = batch.to(device)
            
            loss = compute_loss_rectified_flow(model, x)

            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
        avg_loss = total_loss / len(dataloader)
        if epoch % 10 == 0:
            # Calculate validation loss
            val_loss = validate_rectified_flow(model, val_dataloader)
            print(f'Epoch {epoch}, Train MSE: {avg_loss:.4f}, Val MSE: {val_loss:.4f}')

The training loop should take ~5 mins to run. While it's running, you can read the section below on conditional image generation.


In [None]:
diffusion_unet = DiffusionUNet(base_dim=32)

print(f"DiffusionUNet has {parameter_count_string(diffusion_unet)} parameters")

diffusion_unet.to(device)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(validation_dataset, batch_size=128, shuffle=True)

train_rectified_flow(diffusion_unet, train_dataloader, val_dataloader)


You should see that the model doesn't overfit, despite having 20x more parameters than the VAE. This means that, with our diffusion model, if we had more time and compute, we could train a larger model for longer and get even better results.

Now, let's generate some images from the rectified flow. Your results won't perfectly match the dataset and there will still be some artefacts, but they should be much better than the VAE's samples.

**Exercise 5.2:** fill out the `generate_images_rectified_flow` function.

In [None]:
def generate_images_rectified_flow(model, num_samples):
    # Initialise the sample from random noise
    x = torch.randn(num_samples, 1, 32, 32).to(device)
    
    # Sample using rectified flow (reverse process)
    # Loop over the timesteps [1000, 999, 998, ..., 1]
    for step in range(1000, 0, -1):
        # Make a timestep tensor of shape (num_samples, 1, 1, 1)
        timestep = torch.full((num_samples, 1, 1, 1), step).to(device)
        
        # Predict the direction v = epsilon - x
        v = model(x, timestep/1000)
        
        # Update sample using the rectified flow equation x_{t-1} = x_t - v / 1000
        x = x - 1/1000 * v

    return x

In [None]:

# Generate and visualize samples
num_samples = 4
with torch.no_grad():
    samples = generate_images_rectified_flow(diffusion_unet, num_samples)

# Plot generated samples
plt.figure(figsize=(8, 8), dpi=64)
for i in range(4):
    ax = plt.subplot(2, 2, i + 1)
    ax.imshow(samples[i].cpu().squeeze().clip(0, 1), cmap='gray', vmin=0, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
plt.tight_layout(pad=0)
plt.show()

## Practical Diffusion Models

Now that we know the principles behind diffusion models, the general class of models that dominates the state of the art as of 2025, let's see how we can actually use them in practice.


### Conditional Image Generation

Say that we have a dataset of different classes of images, or even images with textual captions. We don't want to train a new set of model weights for each class, so how can we share one model across all of them? Perhaps surprisingly, a statistically valid way to do this is simply to _input a representation of the class_ into the networks composing our model.

Specifically, any of the above models can be made class-conditional by taking all of the component networks $h^{\theta}(x, \dots)$ and adding the class information $c$ to get $h^{\theta}(x, c, \dots)$. At train and test time, we pass the class information $c$ into all the components uncorrupted. This works because, at the abstract math level, the class information has essentially the same effect as an additional set of model weights $\theta_c = (\theta, c)$.

#### Learning to generate MNIST

To see how diffusion generalises to real-world datasets, let's train a model to generate MNIST digits:

In [None]:
# Note: for Google Colab/local machine, you can just set cache_dir to "./"
import os
import getpass
user = getpass.getuser()
cache_dir = os.path.join("/scratch", user)
os.makedirs(cache_dir, exist_ok=True)

# Load MNIST dataset
mnist_train = torchvision.datasets.MNIST(root=cache_dir, train=True, transform=transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root=cache_dir, train=False, transform=transforms.ToTensor(), download=True)

valid_size = len(mnist_train) // 10

mnist_valid = torch.utils.data.Subset(mnist_train, range(valid_size))
mnist_train = torch.utils.data.Subset(mnist_train, range(valid_size, len(mnist_train)))

mnist_train_dataloader = DataLoader(mnist_train, batch_size=128, shuffle=True)
mnist_valid_dataloader = DataLoader(mnist_valid, batch_size=128, shuffle=True)

In [None]:
# Display some examples of the MNIST dataset
plt.figure(figsize=(10, 5))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(mnist_train[i][0].squeeze(), cmap='gray')
    plt.axis('off')
plt.tight_layout()

Passing the class information as input can be done in many ways: for large text-conditioned models we often use cross-attention, but for our model we'll use the `nn.Embedding` class ([docs](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html)). We'll wrap our pre-existing `DiffusionUNet` class so that it adds some extra channels to our image, where these channels are the class embedding copied across the width and height of the image.

**Exercise 6:** Finish the `ConditionalDiffusionUNet.forward` function.

In [None]:
class ConditionalDiffusionUNet(DiffusionUNet):
    def __init__(self, base_dim=32, n_channels=1, n_updown_blocks=4, n_middle_blocks=2, n_classes=None, class_embedding_dim=16):
        # We concatenate the class embedding to the image as some extra channels

        super(ConditionalDiffusionUNet, self).__init__(base_dim=base_dim, n_channels=n_channels+class_embedding_dim, n_updown_blocks=n_updown_blocks, n_middle_blocks=n_middle_blocks)

        self.class_embedding = nn.Embedding(n_classes, class_embedding_dim)

        # The superclass sets self.n_channels to include the embedding channels, so we store the number of image channels separately
        self.n_channels_img = n_channels
        self.n_classes = n_classes

    def forward(self, x, t, c):
        # Here c is a tensor of shape (batch_size, ) with integer values in the range [0, n_classes)
        # Embed the class information
        c = self.class_embedding(c)

        # Copy the embedding across the width and height of the image
        c = c.unsqueeze(2).unsqueeze(3)
        c = c.expand(-1, -1, x.shape[2], x.shape[3])

        # Concatenate the embedding with the image in the channel dimension
        x = torch.cat([x, c], dim=1)

        x = super().forward(x, t)

        # Remove the extra channels
        x = x[:, :self.n_channels_img, :, :]

        return x

In [None]:
# As MNIST images are 28x28, we can only downsample twice before we get to 7x7, so we use 2 updown blocks to avoid complications
mnist_unet = ConditionalDiffusionUNet(base_dim=32, n_classes=10, n_updown_blocks=2, n_middle_blocks=8).to(device)
print("mnist_unet has", parameter_count_string(mnist_unet), "parameters")

In [None]:
# Collapsed cell: test your solution to Exercise 6
test_x = torch.randn(5, 1, 28, 28).to(device)
test_c = torch.randint(0, 10, (5,)).to(device)
test_t = torch.ones(5, 1, 1, 1).to(device) * 0.5  # the model expects timesteps scaled to [0, 1]
out = mnist_unet(test_x, test_t, test_c)
assert out.shape == (5, 1, 28, 28)

**Exercise 7.1:** Adapt your solution to Exercise 5.1 to complete the `compute_loss_rectified_flow_conditional` function.

In [None]:
def compute_loss_rectified_flow_conditional(model, x, c):
    # Sample a batch of shape (batch_size, 1, 1, 1) of random timesteps uniformly from {1, ..., 1000}, matching the sampling loop
    batch_size = x.shape[0]
    time = torch.randint(1, 1001, (batch_size, 1, 1, 1), device=x.device)
    
    # Create a batch of the same shape as x of normally distributed noise
    noise = torch.randn_like(x)

    # Create a batch of noised images x_t according to rectified flow process
    x_t = (1 - time/1000) * x + (time/1000) * noise

    # Predict direction v_t using the model
    v_t = model(x_t, time/1000, c)

    # Compute the true direction v
    v = noise - x

    # Rectified flow loss: MSE between predicted and true direction (this time, you can use the mean)
    loss = F.mse_loss(v_t, v)
    return loss

Now, let's train our MNIST model for about $20$ epochs (we use `num_epochs=21` so that the validation loss is also printed after the final epoch):

In [None]:
# Collapsed cell: adapted training loops from above to be class-conditional

def validate_rectified_flow_conditional(model, val_dataloader):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for val_batch in val_dataloader:
            val_x, val_c = val_batch
            val_x = val_x.to(device)
            val_c = val_c.to(device)
            
            loss = compute_loss_rectified_flow_conditional(model, val_x, val_c)
            
            val_loss += loss.item()
            
    val_loss /= len(val_dataloader)
    model.train()
    return val_loss


# Training loop with rectified flow objective
def train_rectified_flow_conditional(model, dataloader, val_dataloader, num_epochs=100, lr=1e-4):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            
            # Get data and move to device
            x, c = batch
            x = x.to(device)
            c = c.to(device)
            
            loss = compute_loss_rectified_flow_conditional(model, x, c)

            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
        avg_loss = total_loss / len(dataloader)
        if epoch % 10 == 0:
            # Calculate validation loss
            val_loss = validate_rectified_flow_conditional(model, val_dataloader)
            print(f'Epoch {epoch}, Train MSE: {avg_loss:.4f}, Val MSE: {val_loss:.4f}')

In [None]:
train_rectified_flow_conditional(mnist_unet, mnist_train_dataloader, mnist_valid_dataloader, num_epochs=21)

**Exercise 7.2:** Adapt your solution to Exercise 5.2 to complete the `generate_images_rectified_flow_conditional` function. Make sure to initialise your image with spatial dimensions $28 \times 28$!

In [None]:
def generate_images_rectified_flow_conditional(model, c, num_samples):
    # Initialise the sample from random noise
    x = torch.randn(num_samples, 1, 28, 28).to(device)
    
    # Sample using rectified flow (reverse process)
    # Loop over the timesteps [1000, 999, 998, ..., 1]
    for step in range(1000, 0, -1):
        # Make a timestep tensor of shape (num_samples, 1, 1, 1)
        timestep = torch.full((num_samples, 1, 1, 1), step).to(device)
        
        # Predict the direction v = epsilon - x
        v = model(x, timestep/1000, c)
        
        # Update sample using the rectified flow equation x_{t-1} = x_t - v / 1000
        x = x - 1/1000 * v

    return x

In [None]:
n_mnist_samples = 20
# Use ascending order of classes as the condition, wrapping around after 9.
c_mnist_samples = torch.arange(n_mnist_samples).fmod(10).to(device)

with torch.no_grad():
    samples = generate_images_rectified_flow_conditional(mnist_unet, c_mnist_samples, num_samples=n_mnist_samples)

# Plot generated samples
plt.figure(figsize=(10, 5))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(samples[i].cpu().squeeze().clip(0, 1), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
plt.tight_layout()

It's getting there, but still a bit under-baked. Let's load a model that has been trained with the above code for 100 epochs, and see the result:

In [None]:
# Collapsed cell: loading a pretrained mnist_unet_100_epochs model

import os
import getpass
user = getpass.getuser()
cache_dir = os.path.join("/scratch", user)
os.makedirs(cache_dir, exist_ok=True)

unet_save_path = os.path.join(cache_dir, "mnist_unet_100_epochs_model.pth")

if not os.path.exists(unet_save_path):
    unet_url = "https://polybox.ethz.ch/index.php/s/wjBkDgWSAMZJj6T/download/mnist_unet_100_epochs_model.pth"
    print(f"Model not found at {unet_save_path}, downloading from {unet_url}")
    response = requests.get(unet_url)
    response.raise_for_status()  # Fail loudly instead of saving an error page as the model file
    with open(unet_save_path, "wb") as f:
        f.write(response.content)
    print(f"Downloaded model to {unet_save_path}")

# Load the saved model
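# The file stores a whole pickled nn.Module (not just a state_dict), so weights_only=False is needed
# on PyTorch >= 2.6, where it defaults to True. Only load pickled files from sources you trust.
mnist_unet_100_epochs = torch.load(unet_save_path, weights_only=False)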
mnist_unet_100_epochs.to(device)  # Move model to GPU
mnist_unet_100_epochs.eval()  # Set to evaluation mode
print(f"Loaded model from {unet_save_path}")

In [None]:
with torch.no_grad():
    samples = generate_images_rectified_flow_conditional(mnist_unet_100_epochs, c_mnist_samples, num_samples=n_mnist_samples)

# Plot generated samples
plt.figure(figsize=(10, 5))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(samples[i].cpu().squeeze().clip(0, 1), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
plt.tight_layout()

## State-of-the-art generative computer vision

We don't have the time \& compute to train the latest diffusion models from scratch, so we'll use the 🤗 HuggingFace 🧨 Diffusers [library](https://huggingface.co/docs/diffusers/index) to load pre-trained models and play with them.


In [None]:
# Clear GPU
for name in dir():
    if not name.startswith('_'):
        del globals()[name]

import gc
import torch
gc.collect()
torch.cuda.empty_cache()

# Move huggingface cache to scratch space to avoid filling up /net-scratch/
import os
import getpass
user = getpass.getuser()
cache_dir = os.path.join("/scratch", user)
os.makedirs(cache_dir, exist_ok=True)
cache_dir = os.path.join(cache_dir, "hugging_face")
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = os.path.abspath(cache_dir)
os.environ["HF_HUB_CACHE"] = os.path.abspath(cache_dir)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from diffusers import DiffusionPipeline
import peft

import matplotlib.pyplot as plt

from copy import deepcopy

import torchvision
from torchvision import transforms


The model we'll load is Stable Diffusion XL (SDXL), which was [released](https://arxiv.org/pdf/2307.01952) in 2023. It uses _latent diffusion_: an autoencoder downsamples the image to a smaller grid of latents, and diffusion is performed on these latents rather than on the image itself. In the full SDXL system, generation can also include a second noise-denoise stage (called refinement) performed by a separate refiner model; here we only load the base model. Instead of conditioning on a class label, generation is conditioned on a text prompt entered by the user, which is fed into the diffusion UNet via cross-attention.

![StableDiffusionXL](https://polybox.ethz.ch/index.php/s/WCNSocn9LbCnfJC/download/stablediffusionXL-architecture.png)

Figure credit: [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/pdf/2307.01952), StabilityAI. *arXiv, 2023.*
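To get a feel for why diffusing in latent space saves compute, here is the size arithmetic as a sketch. It assumes an autoencoder that downsamples by a factor of 8 per side and outputs 4 latent channels, the configuration commonly used by Stable Diffusion autoencoders; we have not checked it against this checkpoint. After loading the pipeline below, the actual values could be inspected via `pipe.vae.config.latent_channels` and `pipe.vae_scale_factor`. The helper `latent_shape` is hypothetical.

```python
def latent_shape(height, width, downsample_factor=8, latent_channels=4):
    """Shape (C, H, W) of the latent for a height x width image, under the assumed VAE config."""
    return (latent_channels, height // downsample_factor, width // downsample_factor)


image_shape = (3, 1024, 1024)         # RGB image at SDXL's native 1024x1024 resolution
z_shape = latent_shape(1024, 1024)    # (4, 128, 128)

n_pixels = image_shape[0] * image_shape[1] * image_shape[2]
n_latents = z_shape[0] * z_shape[1] * z_shape[2]
print(z_shape, n_pixels / n_latents)  # (4, 128, 128) 48.0
```

Under these assumptions, every denoising step of the UNet operates on 48x fewer values than the RGB image has.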

You need $\approx 6$ GB of free VRAM to run the Stable Diffusion XL example in this notebook. To fit the model on the GPU we're going to use `fp16` precision, which uses 2 bytes/parameter (rather than the 4 bytes/parameter of `fp32`), and we're going to configure HuggingFace to move parts of the model on and off the GPU as they are needed.
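The bytes-per-parameter arithmetic can be checked on any module. A sketch that sums `numel() * element_size()` over a module's parameters; it counts weights only, not activations or optimizer state, and the small `nn.Linear` is just a stand-in for a real model (`weight_memory_bytes` is a hypothetical helper):

```python
import torch
import torch.nn as nn


def weight_memory_bytes(module):
    """Bytes occupied by a module's parameters (weights only)."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


layer = nn.Linear(1024, 1024)                    # 1024*1024 weights + 1024 biases
fp32_bytes = weight_memory_bytes(layer)          # float32: 4 bytes per parameter
fp16_bytes = weight_memory_bytes(layer.half())   # .half() converts the parameters to float16 in place

print(fp32_bytes, fp16_bytes)                    # 4198400 2099200
```

By the same arithmetic, a model with $N$ parameters needs roughly $2N$ bytes for its weights in `fp16`.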

In [None]:
device = torch.device("cuda")
def display_gpu_memory():
    # Note: memory_allocated only counts tensors allocated by this PyTorch process,
    # not memory used by other processes sharing the GPU
    total_memory_GB = torch.cuda.get_device_properties(device).total_memory / 1024**3
    allocated_memory_GB = torch.cuda.memory_allocated(device) / 1024**3
    remaining_gpu_memory_GB = total_memory_GB - allocated_memory_GB
    print(f"Remaining GPU memory: {remaining_gpu_memory_GB:.2f}/{total_memory_GB:.2f} GB")

display_gpu_memory()

In [None]:
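# Pass cache_dir explicitly: huggingface_hub reads HF_HOME/HF_HUB_CACHE when it is first imported,
# which may already have happened when diffusers was imported at the top of the notebook
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16", cache_dir=cache_dir)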
pipe.enable_model_cpu_offload()

def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def parameter_count_string(module):
    n_params = count_parameters(module)
    if n_params > 10**9:
        return f"{n_params/10**9:.1f}B"
    elif n_params > 10**6:
        return f"{n_params/10**6:.1f}M"
    elif n_params > 10**3:
        return f"{n_params/10**3:.1f}k"
    else:
        return f"{n_params}"
print(f"pipeline contains a vae, two text encoders and a unet with {parameter_count_string(pipe.vae)}, {parameter_count_string(pipe.text_encoder)} + {parameter_count_string(pipe.text_encoder_2)} and {parameter_count_string(pipe.unet)} parameters respectively")

By default, HuggingFace wraps the model in a nice pipeline abstraction, making inference easy:

In [None]:
prompt = "ETH Zurich main building"
with torch.no_grad():
    image = pipe(prompt).images[0]
    
image

### Low Rank Adaptation (LoRA)

We see that our model roughly knows what a university campus building looks like, but doesn't know what the ETH Zurich main building looks like. How could we (in theory) change this?

We could collect a bunch of images of the ETH Zurich main building and use them to finetune our model. This has problems:

- We would need $\approx 3 \times$ the GPU RAM that we use for inference, to store the gradients and the rest of the optimizer state.
- We would have to collect $>100$ images to make sure our model doesn't overfit.

What we can do instead is freeze our model parameters and train an _adapter_: a small module that we train and add to our model for a specific task. The adapter can have many fewer parameters than the base model, making it less prone to overfitting and making it much less memory intensive to store the gradients.

Low Rank Adapters (LoRAs) are a type of adapter which runs in parallel to each linear map in the base model. Each LoRA adaptation of a given linear map $W$ is composed of two smaller linear maps: $A$, which maps the input of the linear map to a small number $r$ of entries (the rank), and $B$, which maps these $r$ entries to a vector that we add to the base output, giving $Wx + BAx$. We then freeze the base model and initialize $B = 0$, so that the adapted model behaves the same as the base model until we start to update $A$ \& $B$.

Figure 1 of the seminal paper gives a good depiction of this process:

![LoRA](https://polybox.ethz.ch/index.php/s/GrMoyPRiNqneDKp/download/LoRA_Hu_Fig1.png)

Figure credit: [LoRA: Low-Rank Adaptation of Large Language Models](https://openreview.net/forum?id=nZeVKeeFYf9), Hu et al. *ICLR 2022.*

The frozen model parameters are blue, the trainable LoRA parameters are orange.
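A minimal sketch of a LoRA-adapted linear layer in plain PyTorch, following the description above: the base layer is frozen, $A$ maps to `rank` entries, and $B$ maps back and starts at zero. The class name `LoRALinear`, its arguments and the `alpha / rank` scaling factor are our own assumptions; the scaling mirrors the `lora_alpha / r` convention of the PEFT library used later in this notebook, but this is not PEFT's implementation.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical wrapper computing y = W x + (alpha / rank) * B(A(x)), with W frozen."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)                        # freeze the base weights
        self.A = nn.Linear(base.in_features, rank, bias=False)
        self.B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.B.weight)                          # B = 0: the adapter starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.B(self.A(x))


torch.manual_seed(0)
base = nn.Linear(512, 512, bias=False)
lora = LoRALinear(base, rank=4)

x = torch.randn(10, 512)
print(torch.allclose(lora(x), base(x)))                        # True: identical to the base layer at init

n_trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
print(n_trainable, "trainable vs", base.weight.numel(), "frozen")  # 4096 trainable vs 262144 frozen
```

With rank 4, the adapter trains $4 \times (512 + 512) = 4096$ parameters instead of the $512 \times 512$ of the full matrix.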

### LoRA toy problem: regularizing linear regression 

To see LoRA in action in a toy problem, we're going to use it to train a linear mapping $W$ between two noisy vectors with `512` entries each. Let's generate such a dataset with only `50` examples.

(aside: solving this problem may actually not be _too_ far from what LoRA does on large models: there is [some evidence](https://transformer-circuits.pub/2022/toy_model/index.html#motivation-directions) that common large neural networks represent concepts as linear combinations of vectors)

In [None]:
# Toy problem setup:
embedding_dim = 512

mu_x = torch.randn(1, embedding_dim)
mu_y = torch.randn(1, embedding_dim)

n_train = 50

n_test = 50

# Create a batch of 50 training examples.
x_train = 0.3 * torch.randn(n_train, embedding_dim) + mu_x
y_train = 0.3 * torch.randn(n_train, embedding_dim) + mu_y

# Create a batch of 50 test examples.
x_test = 0.3 * torch.randn(n_test, embedding_dim) + mu_x
y_test = 0.3 * torch.randn(n_test, embedding_dim) + mu_y

W = nn.Linear(embedding_dim, embedding_dim, bias=False)

print("W has", parameter_count_string(W), "parameters")

To visualize these data, here's a scatter plot of the first two entries of the train dataset vectors and their means:

In [None]:
def plot_toy_problem(W, x_train, y_train, x_test, y_test):
    plt.figure(figsize=(10, 5))
    plt.scatter(x_train[:, 0], x_train[:, 1], c='blue', label='Training inputs x')
    plt.scatter(mu_x[0, 0], mu_x[0, 1], c='black', label='Mean of x')
    plt.scatter(y_train[:, 0], y_train[:, 1], c='red', label='Training targets y')
    plt.scatter(mu_y[0, 0], mu_y[0, 1], c='black', label='Mean of y')
    plt.legend()
    plt.show()

plot_toy_problem(W, x_train, y_train, x_test, y_test)

If we naively finetune $W$ on this dataset using Adam, we see that our model dramatically overfits, and that every one of $W$'s parameters is trainable:

In [None]:
def full_finetune(W, x_train, y_train, x_test, y_test):

    print(f"W has {parameter_count_string(W)} parameters")
    optimizer = torch.optim.Adam(W.parameters(), lr=1e-3)
    
    for epoch in range(200):
        optimizer.zero_grad()

        y_train_pred = W(x_train)

        train_loss = F.mse_loss(y_train_pred, y_train)
        train_loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            y_test_pred = W(x_test)
            test_loss = F.mse_loss(y_test_pred, y_test)

        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}")


W_copy = deepcopy(W)
full_finetune(W_copy, x_train, y_train, x_test, y_test)

**Exercise 8:** Fill in the `lora_finetune` function by implementing LoRA.

In [None]:
def lora_finetune(W, x_train, y_train, x_test, y_test):

    # Freeze the base weights: only the adapter is trained
    W.requires_grad_(False)

    # LoRA rank r: A maps the 512-dim input down to r entries, B maps them back up
    # (r = 4 matches the PEFT config used below to check this solution)
    rank = 4

    # Initialize A and B as nn.Linear layers
    A = nn.Linear(embedding_dim, rank, bias=False)
    B = nn.Linear(rank, embedding_dim, bias=False)
   
    # Set B's weights to zero (the default initialization of A is fine)
    B.weight.data.zero_()
    print(f"A has {parameter_count_string(A)} parameters, B has {parameter_count_string(B)} parameters")

    # Create a list of the parameters we want to update at each step
    optimization_parameters = list(A.parameters()) + list(B.parameters())

    optimizer = torch.optim.Adam(optimization_parameters, lr=1e-3)

    for epoch in range(200):
        optimizer.zero_grad()

        # Add the LoRA adapter output B(A(x)) to the prediction
        y_train_pred = W(x_train) + B(A(x_train))

        train_loss = F.mse_loss(y_train_pred, y_train)
        train_loss.backward()
        optimizer.step()

        with torch.no_grad():
            # Add the LoRA adapter output B(A(x)) to the prediction
            y_test_pred = W(x_test) + B(A(x_test))
            test_loss = F.mse_loss(y_test_pred, y_test)

        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}")

# Pass the copy so that the original W is left untouched
W_copy = deepcopy(W)
lora_finetune(W_copy, x_train, y_train, x_test, y_test)

To test your solution to Exercise 8, we will use the 🤗 HuggingFace Parameter-Efficient Fine-Tuning (PEFT) [library](https://huggingface.co/docs/peft/index), which allows you to easily add LoRA to pre-specified parts of an `nn.Module`. If your solution works, it should produce very similar results (converged losses within $10^{-3}$) to the code below:

In [None]:
W_wrapper = nn.Sequential(deepcopy(W))
W_wrapper

lora_config = peft.LoraConfig(
    task_type=None,
    inference_mode=False,
    r=4,
    lora_alpha=4,
    target_modules="0"
)

W_lora = peft.get_peft_model(W_wrapper, lora_config)
W_lora.print_trainable_parameters()

full_finetune(W_lora, x_train, y_train, x_test, y_test)


For large models, we can train a separate LoRA for each linear map in the model. We can in fact make the adapted model use exactly the same amount of compute as the base model by merging the adapter into the base weights, $W_{lora} = W_{base} + BA$, at the end of training.
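A sketch of that merge with plain tensors, where random layers stand in for trained weights and the rank of 4 is arbitrary. Since $W_{base}x + BAx = (W_{base} + BA)x$, folding $BA$ into one matrix gives the same outputs with a single matrix multiply. `nn.Linear` stores its weight as `(out_features, in_features)`, so $BA$ is `B.weight @ A.weight`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, rank = 512, 4

W_base = nn.Linear(d, d, bias=False)
A = nn.Linear(d, rank, bias=False)      # stand-in for a trained A
B = nn.Linear(rank, d, bias=False)      # stand-in for a trained B (non-zero after training)

x = torch.randn(10, d)
with torch.no_grad():
    y_adapter = W_base(x) + B(A(x))     # base output plus the LoRA branch: two extra matmuls

    W_merged = nn.Linear(d, d, bias=False)
    W_merged.weight.copy_(W_base.weight + B.weight @ A.weight)   # W_base + BA, shape (d, d)
    y_merged = W_merged(x)              # a single matmul, same cost as the base model

print(torch.allclose(y_adapter, y_merged, atol=1e-5))           # True
```

The trade-off is that once merged, the adapter can no longer be switched off or swapped without keeping a copy of $W_{base}$.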

One great thing is that the HuggingFace Hub hosts many pre-trained LoRA adapters, and these adapters are relatively lightweight, as the $A$ and $B$ matrices are often much smaller than the base model layers. [Here's one](https://huggingface.co/TheLastBen/Papercut_SDXL) which adds a papercut effect to the generation. Let's try generating a papercut-style image of the ETH Zurich main building using the original model, and then compare it with one generated using the LoRA-adapted model.

In [None]:
prompt = "papercut ETH Zurich main building"
with torch.no_grad():
    image = pipe(prompt).images[0]
    
image

In [None]:
pipe.load_lora_weights("TheLastBen/Papercut_SDXL")

In [None]:
# The LoRA is trained to be activated when the prompt contains "papercut"
prompt = "papercut Gobbledigoober"

with torch.no_grad():
    image = pipe(prompt).images[0]
    
image