In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Defining the forward diffusion process

We first need to build the inputs for our model: progressively noisier versions of an image. Rather than adding noise one step at a time, we can use the closed form given in the papers to compute the noisy image at any timestep directly.

The forward diffusion process gradually adds noise to an image drawn from the real data distribution, over $T$ timesteps. The amount of noise added at each step follows a variance schedule. The original DDPM authors used a linear schedule:

> We set the forward process variances to constants increasing linearly from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$

However, Nichol et al. (2021) showed that a cosine schedule can achieve better results.
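
For reference, the closed form mentioned above can be written out as follows. This uses the standard DDPM notation; the symbols $\alpha_t$ and $\bar{\alpha}_t$ are not defined elsewhere in this section and are introduced here as shorthand. A single forward step adds Gaussian noise with variance $\beta_t$:

$$
q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1 - \beta_t}\, x_{t-1},\ \beta_t \mathbf{I}\right)
$$

With $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, composing these steps gives the distribution of $x_t$ given the clean image $x_0$:

$$
q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \sqrt{\bar{\alpha}_t}\, x_0,\ (1 - \bar{\alpha}_t)\, \mathbf{I}\right)
$$

Equivalently, $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$ with $\epsilon \sim \mathcal{N}(0, \mathbf{I})$, which is why only the schedule, and no model, is needed to produce a noisy image.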

Key Takeaways:

- The noise levels (variances) can be precomputed
- There are different types of variance schedules
- We can sample the image at each timestep independently (a sum of independent Gaussians is also Gaussian)
- No model is needed in this forward step
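
The third takeaway can be checked numerically. The sketch below is an illustration rather than part of the original notebook: it noises a toy input step by step and, separately, jumps to the final timestep in one shot, then compares the resulting statistics. The choice of `T = 200`, the all-ones toy input, and the variable names are assumptions made for this demo.

```python
import torch

torch.manual_seed(0)

T = 200  # assumed (small) number of timesteps, to keep the demo fast
betas = torch.linspace(1e-4, 0.02, T)               # linear schedule from the DDPM quote
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of (1 - beta_t)

# Toy "image": many identical pixels, so mean and std are easy to interpret.
x0 = torch.ones(100_000)

# Sequential route: apply one noising step at a time.
x_seq = x0.clone()
for beta in betas:
    x_seq = torch.sqrt(1.0 - beta) * x_seq + torch.sqrt(beta) * torch.randn_like(x_seq)

# Closed-form route: jump straight to the last timestep.
a_bar = alphas_cumprod[-1]
x_direct = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * torch.randn_like(x0)

print(f"sequential : mean={x_seq.mean().item():.3f}, std={x_seq.std().item():.3f}")
print(f"closed form: mean={x_direct.mean().item():.3f}, std={x_direct.std().item():.3f}")
print(f"theory     : mean={a_bar.sqrt().item():.3f}, std={(1 - a_bar).sqrt().item():.3f}")
```

Both routes should agree with the theoretical mean and standard deviation up to sampling noise, which is what allows training to draw a random timestep per image instead of simulating the whole chain.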

The cell below defines several variance schedules:

In [2]:
# Functions for various variance schedules
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """ linear schedule as used in the DDPM paper """
    return torch.linspace(beta_start, beta_end, timesteps)

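def cosine_beta_schedule(timesteps, s=0.008):
    """ cosine schedule as proposed in https://arxiv.org/abs/2102.09672 """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    # Cumulative signal level follows a squared cosine; s is a small offset
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # Normalize so that the cumulative product starts at 1
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # Recover per-step betas from ratios of consecutive cumulative products
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Clip to keep every beta strictly inside (0, 1)
    return torch.clip(betas, 0.0001, 0.9999)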

def quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """ quadratic schedule: interpolate linearly in sqrt(beta), then square """
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

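def sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """ sigmoid schedule: S-shaped ramp rescaled into the range (beta_start, beta_end) """
    x = torch.linspace(-6, 6, timesteps)  # sigmoid inputs, not yet betas
    return torch.sigmoid(x) * (beta_end - beta_start) + beta_start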
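
As a rough sketch of how such a schedule might be used, the terms of the closed form can be precomputed once and then indexed per sample, so that a whole batch is noised to different timesteps in one call. The helper names `precompute_forward_terms` and `q_sample`, the batch shape, and the chosen timesteps are assumptions for illustration, not definitions from this notebook. To keep the block self-contained, the betas are built directly with `torch.linspace`, which matches the values of `linear_beta_schedule(1000)`.

```python
import torch

def precompute_forward_terms(betas):
    """Return sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) for every timestep."""
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return torch.sqrt(alphas_cumprod), torch.sqrt(1.0 - alphas_cumprod)

def q_sample(x0, t, sqrt_ac, sqrt_one_minus_ac, noise=None):
    """Noise each image in the batch x0 to its own timestep t (closed form)."""
    if noise is None:
        noise = torch.randn_like(x0)
    # Reshape the per-sample coefficients to (batch, 1, 1, ...) for broadcasting.
    shape = (x0.shape[0],) + (1,) * (x0.dim() - 1)
    return sqrt_ac[t].view(shape) * x0 + sqrt_one_minus_ac[t].view(shape) * noise

# Same values as linear_beta_schedule(1000); any other schedule could be swapped in.
betas = torch.linspace(1e-4, 0.02, 1000)
sqrt_ac, sqrt_one_minus_ac = precompute_forward_terms(betas)

x0 = torch.rand(4, 3, 8, 8) * 2 - 1   # hypothetical batch of images scaled to [-1, 1]
t = torch.tensor([0, 250, 500, 999])  # a different timestep for each image
x_t = q_sample(x0, t, sqrt_ac, sqrt_one_minus_ac)
print(x_t.shape)
```

Passing `noise` explicitly is a design choice that makes the function deterministic when needed, for example when the same noise tensor later serves as a training target.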