# Denoising Diffusion Probabilistic Models (DDPM) Practical Work

## Imports

This forces Jupyter to reload all `.py` files that you are using on the side. Otherwise it will load them only once, and if you modify the code in your `.py` files you'll have to restart your kernel for the changes to take effect.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
!pip install -q -U einops datasets matplotlib tqdm

In [None]:

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F

## Load the Dataset

In [None]:
from datasets import load_dataset

# Fetch Fashion-MNIST from the Hugging Face Hub, then fix the image geometry
# and loader settings used throughout the rest of the notebook.
dataset = load_dataset("fashion_mnist")
image_size, channels, batch_size = 28, 1, 128

In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader

from torchvision.transforms import Compose

# Per-image preprocessing pipeline built with torchvision.
transform = Compose([
    transforms.RandomHorizontalFlip(),       # data augmentation
    transforms.ToTensor(),                   # image -> float tensor in [0, 1]
    transforms.Lambda(lambda x: x * 2 - 1),  # rescale [0, 1] -> [-1, 1]
])

# define function
def transforms(examples):
    examples["pixel_values"] = [transform(image) for image in examples["image"]]
    del examples["image"]

    return examples


# Apply the batch transform on the fly when items are accessed, and drop the
# label column (the diffusion model is trained unconditionally).
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# Shuffled mini-batches over the training split.
dataloader = DataLoader(
    dataset=transformed_dataset["train"],
    shuffle=True,
    batch_size=batch_size,
)

## Implement the Denoising Diffusion Process

### Implement the Beta Schedule

#### Linear Beta Schedule

In [None]:
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return `timesteps` betas evenly spaced from `beta_start` to `beta_end`.

    The defaults are the linear schedule of the original DDPM paper
    (Ho et al., 2020, https://arxiv.org/abs/2006.11239).

    Args:
        timesteps: number of diffusion steps T (length of the returned tensor).
        beta_start: beta at the first step.
        beta_end: beta at the last step.

    Returns:
        1-D float tensor of length `timesteps`.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

#### Cosine Beta Schedule

In [None]:
def cosine_beta_schedule(timesteps, s=0.008, beta_min=0.0001, beta_max=0.02):
    """Cosine beta schedule as proposed in https://arxiv.org/abs/2102.09672.

    With f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2):
        alpha_bar(t) = f(t) / f(0)
        beta_t       = 1 - alpha_bar(t) / alpha_bar(t - 1),   t = 1..T

    Args:
        timesteps: number of diffusion steps T.
        s: small offset that keeps beta_t from being too small near t = 0.
        beta_min: lower clipping bound for the returned betas.
        beta_max: upper clipping bound for the returned betas. The paper clips
            at 0.999; the default 0.02 keeps this notebook's original clipping.

    Returns:
        1-D float tensor of length `timesteps`.
    """
    import math

    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps)  # t = 0, 1, ..., T

    f = torch.cos(((t / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = f / f[0]  # normalize so that alpha_bar(0) == 1
    # Ratio of consecutive cumulative products gives the per-step alpha_t.
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])

    # Clip betas values (near t = T the raw betas approach 1).
    return torch.clip(betas, beta_min, beta_max)

### Constants

In [None]:
timesteps = 600

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)  # or `cosine_beta_schedule(timesteps=timesteps)`

# define alphas
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t = prod_{s <= t} alpha_s

# This is just the previous step of the cumulative product above:
# alphas_cumprod without the last value and with a 1.0 padding at the beginning
# (alpha_bar_{t-1}, with alpha_bar before the first step taken to be 1).
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_0) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

# The posterior q(x_{t-1} | x_t, x_0) has variance
#   beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
# `sigma` stores its square root (a standard deviation), i.e. the sigma_t that
# multiplies z during sampling. It is 0 at the first step since alpha_bar_{t-1} = 1 there.
sigma = torch.sqrt(betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))

### Inference

In [1]:
def extract(a, t, x_shape):
    """Pick the per-sample schedule values `a[t]` and reshape them for broadcasting.

    This extracts, for example, the current `beta_t` for every image from the
    array of all `betas` (the `_t` part our formulas need). The result is
    reshaped to `(batch_size, 1, ..., 1)` with `len(x_shape) - 1` trailing
    singleton dimensions, so it broadcasts against a tensor of shape `x_shape`.

    Args:
        a: 1-D tensor of per-timestep constants (torch.gather requires it to
            have the same number of dims as the 1-D index `t`).
        t: 1-D int64 tensor of time steps, one per batch element.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        Tensor of shape `(t.shape[0], 1, ..., 1)` on `t.device`.
    """
    batch_size = t.shape[0]
    # gather needs the index on the same device as `a`. Moving `t` to `a.device`
    # (rather than always to the CPU) also lets `a` live on the GPU.
    out = a.gather(-1, t.to(a.device))
    # Reshape to (batch_size, 1, ..., 1) and return on the caller's device.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

In [None]:
@torch.no_grad()
def p_sample(model, x_t, ts, current_t):
    """One reverse-diffusion step x_t -> x_{t-1} (DDPM sampling, Ho et al. 2020, Algorithm 2).

    Args:
        model: the noise-prediction network, called as `model(x_t, ts)`; its output
            is used as eps_theta(x_t, t).
        x_t: the noisy images of the current time step `t` (the whole batch).
        ts: 1-D long tensor holding `t` once per image in the batch. We always
            compute the formulas for all images in the batch at the same time.
        current_t: the integer value of `t` (the value stored in `ts`). It is more
            convenient for the `t == 0` test below than reading it back from `ts`.

    Returns:
        x_{t-1}: the posterior mean when `current_t == 0`, otherwise the mean
        plus `sigma_t * z` with `z ~ N(0, I)`.
    """
    # Extract the current time step constants `*_t`, broadcastable against `x_t`.
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, ts, x_t.shape)
    betas_t = extract(betas, ts, x_t.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, ts, x_t.shape)

    # mu_theta = 1 / sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
    predicted_noise = model(x_t, ts)
    mean_t = sqrt_recip_alphas_t * (
        x_t - betas_t / sqrt_one_minus_alphas_cumprod_t * predicted_noise
    )

    # The condition on line 3 of the algorithm: `if t = 0: z = 0`, so the last
    # step returns the mean alone.
    if current_t == 0:
        return mean_t

    sigma_t = extract(sigma, ts, x_t.shape)
    z = torch.randn_like(x_t)
    return mean_t + sigma_t * z

In [None]:
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse process, starting from pure Gaussian noise.

    Draws x_T ~ N(0, I) with the given `shape` (batch dimension first) on the
    model's device, then applies `p_sample` for t = timesteps - 1 down to 0.

    Returns:
        A list with one numpy array per step, each holding the whole batch after
        that step. The last entry is the final generated batch.
    """
    device = next(model.parameters()).device
    n = shape[0]

    x = torch.randn(shape, device=device)  # pure noise for every example in the batch
    history = []

    steps = range(timesteps - 1, -1, -1)
    for t in tqdm(steps, desc="sampling loop time step", total=timesteps):
        # The same time step `t` repeated for every image in the batch.
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        x = p_sample(model, x, t_batch, t)
        history.append(x.cpu().numpy())
    return history


@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate `batch_size` square `image_size` x `image_size` images with `channels` channels.

    Thin wrapper around `p_sample_loop`; returns its list of per-step batches.
    """
    shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=shape)

### Training Loss

In [None]:
from utils import generate_transform_tensor_to_pil_image


# forward diffusion
def q_sample(x_0, ts, noise=None):
    """Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        x_0: the original images that we want to add noise to, using the beta
            schedule precomputed above.
        ts: all the `t` for the current time step, i.e. a 1-D long tensor with one
            time step per image. We always compute the formulas for all images
            in the batch at the same time.
        noise: the epsilon to add. A fresh standard-normal sample shaped like
            `x_0` is drawn when omitted.

    Returns:
        The noisy images x_t, i.e. the model input.
    """
    if noise is None:
        noise = torch.randn_like(x_0)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, ts, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, ts, x_0.shape
    )

    # The red rectangle part in our formula.
    model_input = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

    return model_input

In [None]:
# This function computes the full loss from the training objective using `q_sample` (the red rectangle part).
# You can choose between 3 loss types: "l1", "l2" (Mean Squared Error, like in the paper) or "huber" (smooth L1).
def p_losses(denoise_model, x_0, t, noise=None, loss_type="l1"):
    """Noise-prediction training loss || epsilon - epsilon_theta(x_t, t) ||.

    Args:
        denoise_model: network called as `denoise_model(x_noisy, t)`; its output
            is compared against the injected noise.
        x_0: batch of clean images.
        t: 1-D long tensor with one time step per image.
        noise: the noise epsilon to inject. It is sampled from N(0, I) when omitted.
        loss_type: "l1", "l2" or "huber".

    Returns:
        Scalar loss tensor (gradients are computed later by the training loop).

    Raises:
        NotImplementedError: if `loss_type` is not one of the three above.
    """
    # The noise `epsilon` in our equation, to which we compare the model's noise prediction.
    if noise is None:
        noise = torch.randn_like(x_0)

    # `x_noisy` is our model input. `q_sample` names its time-step argument `ts`.
    x_noisy = q_sample(x_0=x_0, ts=t, noise=noise)

    # epsilon_theta from our formula in the green rectangle.
    predicted_noise = denoise_model(x_noisy, t)

    # The `|| epsilon - epsilon_theta ||^2` part of the equation. L2/MSE is the paper's choice.
    if loss_type == "l1":
        # Mean absolute error: like L2 but without the power of 2.
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == "l2":
        # The loss in the paper.
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        # Quadratic for small errors, linear for large ones. This might be slightly better here.
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError(
            f"Unknown loss_type {loss_type!r}; expected 'l1', 'l2' or 'huber'."
        )

    return loss

## Define the model

In [None]:
from torch.optim import AdamW

from model import Unet

# Train on the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Noise-prediction U-Net sized to the dataset images.
model = Unet(
    channels=channels,
    dim=image_size,
    dim_mults=(1, 2, 4),
    resnet_block_groups=1,  # 1 for the ResNet blocks, 8 for ConvNext
    use_convnext=False,  # set to True to experiment with the ConvNext-based architecture
)
model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-3)

## Training loop

In [None]:
from utils import generate_transform_tensor_to_pil_image

reverse_transform = generate_transform_tensor_to_pil_image()

epochs = 20

for epoch in range(epochs):
    # Training mode for the optimization steps. It only matters if the U-Net has
    # layers that behave differently in train/eval (e.g. dropout).
    model.train()
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # Use distinct names so the global `batch_size` hyper-parameter (used to
        # build the DataLoader) is not overwritten by the size of each batch.
        images = batch["pixel_values"].to(device)
        current_batch_size = images.shape[0]

        # One uniformly random time step in [0, timesteps) per image in the batch.
        # torch.randint already returns int64 (long) by default.
        t = torch.randint(0, timesteps, (current_batch_size,), device=device)

        loss = p_losses(model, images, t, loss_type="huber")

        if step % 100 == 0:
            print("Loss:", loss.item())

        loss.backward()
        optimizer.step()

    # Sample 4 images at the end of every epoch, in eval mode.
    model.eval()
    samples = sample(model, image_size=image_size, batch_size=4, channels=channels)

    # Show the final denoised images. `plt.show()` ends the current figure, so
    # the title has to be set again for every image.
    title = f"Epoch {epoch}, step {step}, loss {loss.item()}"
    for i in range(4):
        plt.title(title)
        plt.imshow(reverse_transform(torch.from_numpy(samples[-1][i])), cmap="gray")
        plt.show()

## Test The Model

In [None]:
# Number of images to generate.
bs = 32

samples = sample(model, image_size=image_size, batch_size=bs, channels=channels)

# `samples[-1]` is the batch after the last denoising step, i.e. the generated images.
for i in range(bs):
    img = reverse_transform(torch.from_numpy(samples[-1][i]))
    plt.imshow(img, cmap="gray")
    plt.show()