# Noise scheduler

When training diffusion models, we gradually add noise to an image until the result is (approximately) pure Gaussian noise, distributed as $\mathcal{N}(0, 1)$.
The model is trained to invert this process and gradually reconstruct an image from noise.

![Diffusion intuition](/home/marchiorot/Desktop/HY673/Tutorial11/figures/diff_intuition.png)

The original image is $x_0$, while pure Gaussian noise is $x_T$, where $T$ (a hyperparameter) is the number of noise steps and therefore also the number of reconstruction steps that the model will use for sampling.

During training, we apply $t$ steps of noise to an image from our training data, with $t$ chosen uniformly in $\{1, 2, \dots, T\}$.

Core idea:

- Noise gets applied according to $x_{t}\sim \mathcal{N}(\sqrt{1-\beta_t} x_{t-1}, \beta_t)$ (mean first, variance second)
- The model's objective is to go back from $x_{t}$ to $x_{t-1}$. It does so by predicting the noise $\epsilon_t$ contained in $x_t$, which is all that is needed to compute that step

So the loss can be, for example, the MSE between the true $\epsilon_t$ and the model's prediction $\hat{\epsilon}_{t}$.

### Problem

Applying noise to get from $x_0$ to $x_t$ is naively an iterative procedure: it requires sampling noise from a Gaussian distribution $t$ times and mixing each sample into the image.
If we actually ran these iterations for every training example, training large diffusion models would be very slow, sometimes infeasible.
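
To make the cost concrete, here is a minimal sketch of the naive iterative procedure. The function name `noise_iteratively`, the stand-in batch, and the schedule values are illustrative assumptions, not part of the tutorial code.

```python
import torch

def noise_iteratively(x_0, betas, t):
    """Apply t single noise steps, x_s ~ N(sqrt(1 - beta_s) x_{s-1}, beta_s), one after another."""
    x = x_0
    for s in range(1, t + 1):
        beta_s = betas[s - 1]          # betas[0] holds beta_1
        eps = torch.randn_like(x)      # fresh Gaussian noise at every step
        x = torch.sqrt(1 - beta_s) * x + torch.sqrt(beta_s) * eps
    return x

torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000)     # assumed linear schedule
x_0 = torch.ones(16, 1, 28, 28)              # stand-in batch of "images"
x_t = noise_iteratively(x_0, betas, t=1000)  # 1000 sequential sampling steps for one batch
```

Every call needs up to $T$ rounds of sampling per batch, which is the overhead the rest of this section removes.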

Luckily, there is a closed-form expression that applies all $t$ noise steps in one single step.

It only needs the noise parameters at step $t$, which are the $\bar{\alpha}_{t}$ coefficients in the algorithms below.

You can find more details in the slides or in the original paper (Denoising Diffusion Probabilistic Models): https://arxiv.org/abs/2006.11239

### Trick to apply multiple noise steps

During training (see loss function below), we need to go from $x_0$ to a specific $x_t$.

A single noise step gets applied according to $x_{t}\sim \mathcal{N}(\sqrt{1-\beta_t} x_{t-1}, \beta_t)$. $\beta_1$ and $\beta_T$ are hyperparameters; the other $\beta_t$ values are linearly spaced within the range $[\beta_1, \beta_T]$.

<b>Problem</b>: Producing $x_t$'s sequentially would be computationally expensive.

<b>Trick</b>: There is a formula to get $x_t$ directly from $x_0$:
$$
x_t \sim \mathcal{N} (\sqrt{\bar{\alpha}_t} x_0, 1-\bar{\alpha}_t)
$$
with
$$
\alpha_t = 1 - \beta_t
$$
and
$$
\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \ \ \text{(Cumulative product of $\alpha$'s)}
$$
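
To see where the formula comes from, it helps to expand two steps by hand. This is a sketch; $\epsilon_1$ and $\epsilon_2$ denote independent samples from $\mathcal{N}(0, 1)$ and are introduced only for this derivation.

$$
\begin{aligned}
x_1 &= \sqrt{\alpha_1}\, x_0 + \sqrt{1-\alpha_1}\, \epsilon_1 \\
x_2 &= \sqrt{\alpha_2}\, x_1 + \sqrt{1-\alpha_2}\, \epsilon_2 \\
    &= \sqrt{\alpha_1 \alpha_2}\, x_0 + \sqrt{\alpha_2 (1-\alpha_1)}\, \epsilon_1 + \sqrt{1-\alpha_2}\, \epsilon_2
\end{aligned}
$$

The two noise terms are independent Gaussians, so their sum is again Gaussian, with variance $\alpha_2(1-\alpha_1) + (1-\alpha_2) = 1 - \alpha_1\alpha_2 = 1-\bar{\alpha}_2$. Hence $x_2 \sim \mathcal{N}(\sqrt{\bar{\alpha}_2}\, x_0, 1-\bar{\alpha}_2)$, and repeating the same argument step by step gives the formula for any $t$.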

In [1]:
import torch
import torch.nn as nn

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader

from unet import Unet
from tqdm import tqdm

In [2]:
def get_schedules(beta_1, beta_T, n_T):
    """
    Linear scheduler.
    Pre-computes all the parameters used later (including fractions, square roots, etc.).
    Every returned tensor has n_T + 1 entries and is indexed directly by the timestep t.
    """

    # n_T + 1 linearly spaced values, from beta_1 (index 0) to beta_T (index n_T)
    beta_t = (beta_T - beta_1) * torch.arange(0, n_T + 1, dtype=torch.float32) / n_T + beta_1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    # Cumulative product of the alphas, computed as exp(cumsum(log(alpha)))
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrt_abar = torch.sqrt(alphabar_t)
    one_over_sqrt_a = 1 / torch.sqrt(alpha_t)

    # Note: "inv" in the names below stands for "1 minus", not for a reciprocal
    sqrt_inv_abar = torch.sqrt(1 - alphabar_t)
    inv_alpha_over_sqrt_inv_abar = (1 - alpha_t) / sqrt_inv_abar

    return {
        "alpha": alpha_t,  # \alpha_t
        "one_over_sqrt_a": one_over_sqrt_a,  # 1/\sqrt{\alpha_t}
        "sqrt_beta": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar": alphabar_t,  # \bar{\alpha_t}
        "sqrt_abar": sqrt_abar,  # \sqrt{\bar{\alpha_t}}
        "sqrt_inv_abar": sqrt_inv_abar,  # \sqrt{1-\bar{\alpha_t}}
        "inv_alpha_over_sqrt_inv_abar": inv_alpha_over_sqrt_inv_abar,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }



In [3]:
n_T = 1000
betas = [1e-4, 0.02]

schedules = get_schedules(betas[0], betas[1], n_T)
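
As a sanity check on these hyperparameters, the sketch below recomputes the schedule independently (in `float64`, with `torch.linspace` and `torch.cumprod` standing in for the code in `get_schedules`) and compares the variance obtained by applying the noise steps one at a time with the closed-form value $1-\bar{\alpha}_t$. The variable names are illustrative.

```python
import torch

n_T = 1000
beta_1, beta_T = 1e-4, 0.02

# Same linear spacing as in get_schedules, in float64 to keep rounding error small
betas = torch.linspace(beta_1, beta_T, n_T + 1, dtype=torch.float64)
alphas = 1 - betas
alphabar = torch.cumprod(alphas, dim=0)

# Track the noise variance step by step: a single noise step scales the
# current variance by alpha and adds beta. x_0 itself carries no noise.
var = torch.zeros((), dtype=torch.float64)
variances = []
for a, b in zip(alphas, betas):
    var = a * var + b
    variances.append(var)
variances = torch.stack(variances)

max_gap = (variances - (1 - alphabar)).abs().max().item()
signal_left = alphabar[-1].sqrt().item()
print(f"max |step-by-step variance - (1 - alphabar)| = {max_gap:.2e}")
print(f"sqrt(alphabar) at the last step = {signal_left:.4f}")
```

The first number should be at the level of floating-point rounding, and the second shows how little of $x_0$ is left in $x_T$, which is why sampling can start from pure noise.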

![Training algorithm](/home/marchiorot/Desktop/HY673/Tutorial11/figures/training_algo.png)

In [4]:
# Initialize dataset, model, loss function, and optimizer
# Normalize((0.5,), (0.5,)) maps pixel values from [0, 1] to [-1, 1]
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
dataset = MNIST("./data", train=True, download=True, transform=transform)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

unet = Unet()

loss_fn = nn.MSELoss()

optim = torch.optim.Adam(unet.parameters(), lr=1e-5)

In [5]:
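# Step 2: take a batch of training images x_0
x, _ = next(iter(dataloader))
x = x.view(-1, 1, 28, 28)

# Step 3: sample one timestep per image, uniformly in {1, ..., n_T}
timesteps = torch.randint(1, n_T + 1, (x.shape[0],))

# Step 4: sample the noise
eps = torch.randn_like(x)

# Step 5: take one gradient step on the MSE between true and predicted noise
optim.zero_grad()

# Go from x_0 to x_t in a single step with the closed-form formula.
# Indexing with [timesteps, None, None, None] gives shape (batch, 1, 1, 1) for broadcasting.
sqrt_abar = schedules["sqrt_abar"][timesteps, None, None, None]
sqrt_inv_abar = schedules["sqrt_inv_abar"][timesteps, None, None, None]
x_t = sqrt_abar * x + sqrt_inv_abar * eps
t = timesteps / n_T  # the model receives the timestep rescaled to (0, 1]
eps_hat = unet(x_t, t)
loss = loss_fn(eps_hat, eps)
loss.backward()

optim.step()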
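
The cell above performs a single gradient step, while the training algorithm repeats it until convergence. A minimal sketch of that repetition follows. `TinyEpsModel`, `train_step`, and the random `fake_batches` are hypothetical stand-ins so that the block runs on its own; they are not the tutorial's `Unet` or the MNIST dataloader.

```python
import torch
import torch.nn as nn

class TinyEpsModel(nn.Module):
    """Hypothetical stand-in for the U-Net: same call signature model(x_t, t), far too small to be useful."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)

    def forward(self, x_t, t):
        # Feed the timestep to the network as an extra constant image channel
        t_map = t.view(-1, 1, 1, 1).expand(-1, 1, *x_t.shape[2:])
        return self.conv(torch.cat([x_t, t_map], dim=1))

def train_step(model, optim, loss_fn, x_0, sqrt_abar, sqrt_inv_abar, n_T):
    """Steps 2-5 of the training algorithm for one batch x_0."""
    timesteps = torch.randint(1, n_T + 1, (x_0.shape[0],))
    eps = torch.randn_like(x_0)
    x_t = sqrt_abar[timesteps, None, None, None] * x_0 + sqrt_inv_abar[timesteps, None, None, None] * eps
    optim.zero_grad()
    loss = loss_fn(model(x_t, timesteps / n_T), eps)
    loss.backward()
    optim.step()
    return loss.item()

torch.manual_seed(0)
n_T = 1000
alphabar = torch.cumprod(1 - torch.linspace(1e-4, 0.02, n_T + 1), dim=0)
model = TinyEpsModel()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
fake_batches = [torch.randn(32, 1, 28, 28) for _ in range(5)]  # random tensors in place of MNIST batches
losses = [
    train_step(model, optim, loss_fn, x_0, alphabar.sqrt(), (1 - alphabar).sqrt(), n_T)
    for x_0 in fake_batches
]
```

In the notebook, the same loop would iterate over `dataloader` (possibly for several epochs) and use `unet`, `schedules["sqrt_abar"]`, and `schedules["sqrt_inv_abar"]` instead of the stand-ins.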


![Sampling algorithm](/home/marchiorot/Desktop/HY673/Tutorial11/figures/sampling_algo.png)

In [6]:
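@torch.no_grad()  # sampling needs no gradients; this avoids building a graph across all n_T steps
def sample(model, n_T, n_samples, sample_shape, schedules):

    # Step 1: start from pure Gaussian noise x_T
    x_T = torch.randn(n_samples, *sample_shape)

    # Step 2: denoise step by step, from t = n_T down to t = 1
    x_i = x_T
    for i in tqdm(range(n_T, 0, -1)):
        # Step 3: fresh noise for every step except the last one
        z = torch.randn(n_samples, *sample_shape) if i > 1 else 0

        # Step 4: predict the noise in x_t and use it to compute x_{t-1}
        ts = torch.tensor(i / n_T).repeat(n_samples)
        eps = model(x_i, ts)
        x_i = schedules["one_over_sqrt_a"][i] * (x_i - eps * schedules["inv_alpha_over_sqrt_inv_abar"][i]) + schedules["sqrt_beta"][i] * z

    # Step 6: the final x_i is the generated sample x_0
    x = x_i
    return x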

In [7]:
x = sample(unet, n_T, 8, (1, 28, 28), schedules)

100%|██████████| 1000/1000 [00:14<00:00, 69.33it/s]
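
The samples live in the same value range as the normalized training data, roughly $[-1, 1]$, because of `Normalize((0.5,), (0.5,))`. Before displaying them, it is usually convenient to map them back to $[0, 1]$. A minimal sketch follows; `to_unit_range` is an illustrative helper and the random tensor stands in for the output of `sample`.

```python
import torch

def to_unit_range(x):
    """Invert Normalize((0.5,), (0.5,)) and clip, so that values lie in [0, 1]."""
    return (x * 0.5 + 0.5).clamp(0.0, 1.0)

samples = torch.randn(8, 1, 28, 28)  # stand-in for the tensor returned by sample(...)
images = to_unit_range(samples)      # same shape, values in [0, 1]
```

The clipping matters because the generated values are not guaranteed to stay inside $[-1, 1]$.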
