# INSA, GMM Image
## Practical sessions: Introduction to Diffusion

Welcome to this practical session on diffusion models.  
Diffusion models are a class of generative models that have been used to generate images, videos, and audio. They are based on the idea of gradually adding noise to the data and then learning to remove that noise step by step, so that new samples can be generated starting from pure noise.  
Since training diffusion models is computationally expensive, we will first illustrate the concept on a toy dataset of 2D points, and then apply it to a pretrained image model.  
Let's start by downloading and visualizing the dataset.

In [None]:
!wget https://github.com/tanelp/tiny-diffusion/raw/refs/heads/master/static/DatasaurusDozen.tsv

In [None]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from typing import List

device = 'cuda' if torch.cuda.is_available() else 'cpu'

df = pd.read_csv('DatasaurusDozen.tsv',sep='\t')
datasaurus = (np.asarray(df[df['dataset']=='dino'][['x','y']].values, dtype=float))
datasaurus = (datasaurus - datasaurus.mean())/datasaurus.std()
plt.scatter(datasaurus[:,0],datasaurus[:,1])

We will use this dataset to illustrate the concept of diffusion.  
This data will be our initial distribution, the one we want to sample from.  
Since we don't know how to sample from this distribution directly, we will use a noise schedule to gradually add noise to the data until we reach a Gaussian distribution, from which we know how to sample.  
Then we will use the reverse process to sample from the initial distribution.  

### Noise Schedule
Let's begin by defining the noise schedule.  
The noise schedule is a function that defines the amount of noise to add at each timestep.  
We will use a linear schedule for this example.  

In [None]:
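T = 200
beta_min = 0.0001  # variance of the noise added at the first timestep
beta_max = 0.05    # variance of the noise added at the last timestep
betas = torch.linspace(beta_min, beta_max, T)
alphas = 1. - betas  # fraction of the signal variance kept at each timestep
alpha_bar = torch.cumprod(alphas, dim=0)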
plt.figure(figsize=[6, 6])
plt.plot(torch.arange(T), alpha_bar)
plt.xlabel('Timestep')
plt.ylim(0,1.05)
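To get a feel for the numbers behind this curve, the schedule can be recomputed in plain Python. This is a small sketch that assumes the same linear schedule as above (`T=200`, per-step noise variance going from `0.0001` to `0.05`); the helper name `linear_alpha_bar` is ours and not part of any library.

```python
def linear_alpha_bar(num_steps=200, beta_min=0.0001, beta_max=0.05):
    """Cumulative product of (1 - beta_t) for a linear schedule (hypothetical helper)."""
    alpha_bar_values, running = [], 1.0
    for t in range(num_steps):
        # same evenly spaced values as torch.linspace(beta_min, beta_max, num_steps)
        beta_t = beta_min + (beta_max - beta_min) * t / (num_steps - 1)
        running *= 1.0 - beta_t
        alpha_bar_values.append(running)
    return alpha_bar_values

alpha_bar_list = linear_alpha_bar()
for t in (0, 50, 100, 199):
    # alpha_bar_list[t] is the fraction of the original signal variance left at index t
    print(f"t={t:3d}  alpha_bar={alpha_bar_list[t]:.4f}")
```

The last value should be close to zero, which is what lets us treat the final step of the forward process as (almost) pure Gaussian noise.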

### Forward Pass
The forward pass is the process of adding noise to the data at each timestep.  
We will use the noise schedule to define the amount of noise to add at each timestep.  
We saw in class that the forward step can be written as:
$$x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}\epsilon_t$$
where $\alpha_t \in (0, 1)$ controls how much of the previous sample is kept, $1-\alpha_t$ is the variance of the noise added at timestep $t$, and $\epsilon_t$ is a standard Gaussian noise sample.  
Complete the following functions to perform the forward pass.

In [None]:
def forward_step(x_t_minus_1:torch.Tensor, alphas:torch.Tensor, t:int, eps:torch.Tensor) -> torch.Tensor:
    """
    Takes the previous step, the alphas and the timestep and returns the next step
    args:
        x_t_minus_1: the previous step
        alphas: the alphas of the noise schedule
        t: the timestep
        eps: the noise sample
    returns:
        x_t: the next step
    """
    x_t = alphas[t].sqrt()*x_t_minus_1 + (1-alphas[t]).sqrt()*eps
    return x_t

def forward_pass(x_0:torch.Tensor, alphas:torch.Tensor, T:int=200) -> List[torch.Tensor]:
    """
    Takes the initial data, the alphas and the number of timesteps and returns the list of forward steps
    args:
        x_0: the initial data
        alphas: the alphas of the noise schedule
        T: the number of timesteps
    returns:
        x_series: a list of the forward steps
    """
    x_series = [x_0]
    for t in range(T):
        eps = torch.randn_like(x_0)
        new_x = forward_step(x_series[-1],alphas,t,eps) # we use the alphas of the noise schedule
        #new_x = forward_step(x_series[-1],torch.cumprod(alphas,dim=0),t,eps) # alpha_bar ?
        x_series.append(new_x)
    return x_series

x_0 = torch.tensor(datasaurus).repeat(6, 1) # we repeat the data 6 times to have more points to visualize
x_series = forward_pass(x_0, alphas, T)

Now plot the different steps of the forward pass for t = [0, 6, 12, 25, 50]

In [None]:
figure = plt.figure(figsize=(20, 4))
for i, t in enumerate([0, 6, 12, 25, 50]):
    dataset, time_step = x_series[t], t
    figure.add_subplot(1,5,i+1)
    plt.title(time_step)
    plt.axis("off")
    plt.scatter(dataset[:,0],dataset[:,1],s=15,alpha=0.5)

You should observe that the data becomes more and more noisy as the timestep increases.  
The following code animates the forward pass and allows you to see the data evolve over time.


In [None]:
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
from IPython.display import HTML
from functools import partial

fig, ax = plt.subplots(figsize=(6, 6))

def animate(i:int, series:List[torch.Tensor]):
    ax.clear()
    data = series[i]
    ax.scatter(data[:, 0], data[:, 1], s=15, alpha=0.5)
    ax.set_axis_off()

animate_forward = partial(animate, series=x_series)

anim = FuncAnimation(fig, animate_forward, frames=len(x_series),
                    interval=250)  # 250 ms between frames

HTML(anim.to_jshtml())

For training, we would like to have diversity in the training batches, meaning different samples with different timesteps.  
We saw in class that it is possible to directly noise the data for a given timestep without having to go through the forward pass for all the previous timesteps:  
$$x_t = \sqrt{\bar{\alpha}_t}x_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon$$
where $\bar{\alpha}_t$ is the cumulative product of the $\alpha_s$ up to timestep $t$ and $\epsilon$ is a standard Gaussian noise sample.  
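As a reminder of where this closed form comes from, here is a short derivation sketch. It only uses the one-step rule of the forward pass and the fact that the sum of two independent Gaussian variables is Gaussian, with variances adding up; $\epsilon$ denotes a standard Gaussian sample that merges $\epsilon_{t-1}$ and $\epsilon_t$.

$$
\begin{aligned}
x_t &= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t \\
    &= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_{t-1}}\,\epsilon_{t-1}\right) + \sqrt{1-\alpha_t}\,\epsilon_t \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\epsilon
\end{aligned}
$$

The last line holds because the total noise variance is $\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1-\alpha_t\alpha_{t-1}$. Repeating the same argument down to $x_0$ replaces the product $\alpha_t\alpha_{t-1}$ with the full cumulative product of the alphas.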
Complete the following function to sample the data for a given timestep and verify that it seems correct by plotting the data for t in [0, 6, 12, 25, 50]

In [None]:
def sample_x_t(x_0:torch.Tensor, t:int, alpha_bar:torch.Tensor, eps:torch.Tensor) -> torch.Tensor:
    """
    Takes the initial data, a timestep, the cumulative alphas and a noise sample and returns a noisy version of the data for that timestep
    args:
        x_0: the initial data
        t: the timestep (an int, or a tensor with one timestep per sample)
        alpha_bar: the cumulative product of the alphas of the noise schedule
        eps: the noise sample, with the same shape as x_0
    returns:
        x_t: a noisy version of the data for a given timestep
    """
    x_t = alpha_bar[t,None].sqrt()*x_0 + (1.-alpha_bar[t,None]).sqrt()*eps
    return x_t


figure = plt.figure(figsize=(20, 4))
for i, t in enumerate([0, 6, 12, 25, 50]):
    figure.add_subplot(1,5,i+1)
    plt.title(t)
    plt.axis("off")
    eps = torch.randn_like(x_0)
    x_t = sample_x_t(x_0, t,alpha_bar, eps)
    plt.scatter(x_t[:,0],x_t[:,1],s=15,alpha=0.5)
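Besides the visual comparison, the agreement between the step-by-step forward pass and the direct formula can be checked numerically on the coefficients alone. The sketch below is plain Python and assumes the same linear schedule as above (200 steps, per-step noise variance from `0.0001` to `0.05`); the variable names are ours.

```python
import math

n_steps = 200
beta_list = [0.0001 + (0.05 - 0.0001) * t / (n_steps - 1) for t in range(n_steps)]
alpha_list = [1.0 - b for b in beta_list]

signal = 1.0       # coefficient multiplying x_0 after the steps applied so far
noise_var = 0.0    # variance of the accumulated Gaussian noise
alpha_bar_t = 1.0  # cumulative product of the alphas
for a in alpha_list:
    signal *= math.sqrt(a)                 # x_0 is rescaled by sqrt(alpha_t) at every step
    noise_var = a * noise_var + (1.0 - a)  # old noise is rescaled, fresh noise is added
    alpha_bar_t *= a
    # the direct formula predicts sqrt(alpha_bar_t) and 1 - alpha_bar_t
    assert abs(signal - math.sqrt(alpha_bar_t)) < 1e-12
    assert abs(noise_var - (1.0 - alpha_bar_t)) < 1e-12

print("iterative and direct coefficients agree at every timestep")
```

Note that `signal**2 + noise_var` stays equal to 1 throughout, so data with unit variance keeps unit variance along the whole forward process.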

## Training
We will now train a denoising model to learn the reverse process.  
First, we need to create a dataset with our data.  
### Dataset
We now define a torch dataset to load our data.  We then split the data into a train and test set and create dataloaders for each.  

In [None]:
from torch.utils.data import DataLoader, Dataset, random_split

class DinoDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx].float()

dataset = DinoDataset(x_0)

# Train/Test Split
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


### Model:
Now that the dataset and dataloaders are created, we can define the model.  
We will use a simple MLP with 5 linear layers: four hidden layers of 64 neurons each and a final layer to output the predicted noise.
We will use the GELU activation function between each layer.  
Remember that the output of the model will be the predicted noise which has the same dimension as the input.
Complete the following class to define the model.  
Since we will be training the model on the data for different timesteps, we will need to pass the timestep as an additional input to the model.  
We will do this by concatenating the timestep to the input data.  

In [None]:
class Denoisier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,2)
        )
    def forward(self, x:torch.Tensor, t:torch.Tensor) -> torch.Tensor:
        # Concatenate the timestep to the input data
        x = torch.cat((x,t.reshape(-1,1)),dim=1)
        return self.layers(x)

### Training loop:
At this point, we have all the components to train the model.  
The training loop of a denoising model is actually quite simple.  
Look at the training algorithm from the paper and implement the training loop.  
![training_loop](images/training.png)


In [None]:
from tqdm.notebook import tqdm # you can use tqdm to display a progress bar

def train(model:torch.nn.Module, train_dataloader:DataLoader, optimizer:torch.optim.Optimizer, alpha_bar:torch.Tensor, epochs:int=50, device:str='cpu'):
    # don't forget to move everything to the correct device
    alpha_bar = alpha_bar.to(device)
    num_timesteps = alpha_bar.shape[0]  # taken from the schedule rather than from a global variable
    model.train()
    progress_bar = tqdm(range(epochs), desc="Training")
    for epoch in progress_bar:
        total_loss = 0
        for x in train_dataloader:
            x = x.to(device)
            eps = torch.randn_like(x)  # the noise the model has to predict
            t = torch.randint(num_timesteps, (x.shape[0],), device=device)  # one random timestep per sample
            x_t = sample_x_t(x, t, alpha_bar, eps)
            eps_pred = model(x_t, t)
            loss = torch.nn.functional.mse_loss(eps_pred, eps)
            total_loss += loss.item()*x.shape[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # average loss per sample over the whole epoch
        progress_bar.set_postfix({"epoch": epoch, "loss": total_loss/len(train_dataloader.dataset)})
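When reading the loss in the progress bar, a useful reference point is a model that always predicts zero: since the target noise is standard normal, its mean squared error is about 1. The sketch below checks this in plain Python with simulated targets (the sample size and seed are arbitrary choices); a denoiser that has learned something is expected to report a loss below that value.

```python
import random

random.seed(0)
# Target noise eps ~ N(0, 1); a model that always predicts 0 has MSE = E[eps^2] = 1.
eps_samples = [random.gauss(0.0, 1.0) for _ in range(20_000)]
baseline_mse = sum(e * e for e in eps_samples) / len(eps_samples)
print(f"MSE of the all-zeros predictor: {baseline_mse:.3f}")
```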
    

### Training:
Now instantiate the model and optimizer (Adam with a learning rate of 1e-3) and train the model for 3000 epochs.  

In [None]:
model = Denoisier().to(device) # don't forget to move the model to the correct device
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, train_loader, optimizer, alpha_bar, epochs=3000, device=device)

### Sampling:
Here is the moment of truth!
We will now sample from the model and see if we did manage to learn the reverse process.  
Remember that the sampling algorithm is the following:  
![](images/sampling.png)  
Use it to complete the following function.  Once this is done, use it to sample 1000 points and plot them.  

In [None]:
@torch.no_grad()
def sample(num_samples:int, model:torch.nn.Module, alpha:torch.Tensor, alpha_bar:torch.Tensor, T:int=200, device:str='cpu') -> List[torch.Tensor]:
    """
    Runs the reverse process starting from pure noise and returns the sampled data after each step
    args:
        num_samples: the number of points to generate
        model: the denoising model
        alpha: the alphas of the noise schedule
        alpha_bar: the cumulative product of the alphas
        T: the number of timesteps
        device: the device on which to run the sampling
    returns:
        x_series: a list of the sampled data for each timestep
    """
    model.eval()
    alpha = alpha.to(device)
    alpha_bar = alpha_bar.to(device)
    x_series = []
    xt = torch.randn((num_samples,2), device=device)
    for t in reversed(range(T)):
        t_batch = torch.full((num_samples,), t, device=device)
        noise_pred = model(xt, t_batch)
        mu_hat_t = (xt - (1-alpha[t,None])/(1-alpha_bar[t,None]).sqrt()*noise_pred)/(alpha[t,None]).sqrt()

        # the DDPM sampling algorithm adds no noise on the final step (z = 0)
        z = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
        sigma = (1.-alpha[t]).sqrt()
        xt = mu_hat_t + sigma*z
        x_series.append(xt.clone().detach().to('cpu'))  # move the data to the cpu before appending
    return x_series

steps = sample(1000, model, alphas, alpha_bar, T=T, device=device)
plt.figure(figsize=[6, 6])
plt.scatter(steps[-1][:,0],steps[-1][:,1],s=15,alpha=0.5)
plt.axis('off')

We are getting there! But we can do better. The model would probably keep improving if we trained it for longer, but let's instead try to train faster and better.  
### Time embedding:
We will now try to improve the model by adding a time embedding.  
This is a common technique in diffusion models: it gives the network a richer representation of the timestep than a single raw number.  
We will use a simple sinusoidal position embedding.  

In [None]:
import math

class SinusoidalPositionEmbeddings(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
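To see what this module computes, here is the same formula written in plain Python for a single timestep. It is an illustrative sketch (the function name `sinusoidal_embedding` is ours) that mirrors the tensor code above: half of the dimensions hold sines and the other half cosines, with frequencies decreasing geometrically from 1 to 1/10000.

```python
import math

def sinusoidal_embedding(t, dim=32):
    """Plain-Python version of the sinusoidal embedding for one timestep (illustrative helper)."""
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]  # from 1 down to 1/10000
    angles = [t * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

emb_0 = sinusoidal_embedding(0)
emb_50 = sinusoidal_embedding(50)
print(len(emb_0), emb_0[:2], emb_0[16:18])
print(emb_50[:2], emb_50[16:18])
```

All entries stay in $[-1, 1]$ whatever the timestep, unlike the raw timestep value, which ranges from 0 to $T-1$.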



Now, complete the following class to define the model with the time embedding and train it for 2000 epochs.  

In [None]:
class DenoisierWithTimeEmbedding(torch.nn.Module):
    def __init__(self, t_emb_dim:int=32, device:str='cpu'):
        super().__init__()
        self.time_embedder = SinusoidalPositionEmbeddings(t_emb_dim)
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(2 + t_emb_dim,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,64),
            torch.nn.GELU(),
            torch.nn.Linear(64,2)
        )

    def forward(self, x, t):
        t_emb = self.time_embedder(t)
        x = torch.cat((x,t_emb), dim=1)
        return self.layers(x)

model = DenoisierWithTimeEmbedding().to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
train(model, train_loader, optimizer, alpha_bar, epochs=2000, device=device)

Sample 1000 new points and plot them.

In [None]:
steps = sample(1000, model, alphas, alpha_bar, T=T, device=device)
plt.figure(figsize=[6, 6])
plt.scatter(steps[-1][:,0],steps[-1][:,1],s=15,alpha=0.5)
plt.axis('off')

Using the previous code, make a small animation to visualize the sampling process.

In [None]:
animate_backward = partial(animate, series=steps)

anim = FuncAnimation(fig, animate_backward, frames=len(steps),
                    interval=250)  # 250 ms between frames


HTML(anim.to_jshtml())

## Image models
In the previous part, we have seen how to train an unconditional diffusion model to generate samples from a target distribution.  
We did it on a 2d example, but in practice, diffusion models are often used to generate more structured data like images or audio.  
In this part, we will see how to use a pretrained unconditional image diffusion model to generate images.  
We will use the [diffusers](https://github.com/huggingface/diffusers) library to load a pretrained model and generate images of faces.  

In [None]:
from diffusers import UNet2DModel

repo_id = "google/ddpm-celebahq-256" # "google/ddpm-church-256"
model = UNet2DModel.from_pretrained(repo_id)

The architecture of the neural network, referred to as **model**, commonly follows the UNet architecture as proposed in [this paper](https://arxiv.org/abs/1505.04597) and improved upon in the PixelCNN++ paper.

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/unet-model.png)

Let's check the configuration of the loaded model, in particular to find out its expected input shape.

In [None]:
model.config

Generate a random input for the model of shape (1, 3, 256, 256)

In [None]:
noisy_sample = torch.randn(...)
noisy_sample.shape

The following code predicts the noise for a given timestep and a given noisy image.

In [None]:
with torch.no_grad():
    noisy_residual = model(sample=noisy_sample, timestep=2).sample
noisy_residual.shape

diffusers pretrained models come with a scheduler, responsible for the noise schedule.  
Here is the corresponding scheduler

In [None]:
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(repo_id)
scheduler.config

Knowing that $\bar{\alpha}_t$ is the cumulative product of the $\alpha_t$, where $\alpha_t = 1 - \beta_t$, we can compute the $\bar{\alpha}_t$ and plot them.  

In [None]:
T=1000
beta_min=0.0001
beta_max=0.02
betas = torch.linspace(beta_min, beta_max, T)
alphas = 1. - betas
alpha_bar = torch.cumprod(alphas, dim=0)
plt.figure(figsize=[6, 6])
plt.plot(torch.arange(T), alpha_bar)
plt.xlabel('Timestep')
plt.ylim(0,1.05)

Let's visualize the noise schedule applied to an image.  

In [None]:
!wget https://efrosgans.eecs.berkeley.edu/SwappingAutoencoder/results_for_paper_with_new_ffhq2/ffhq/input_structure/00001__000.png -O image.jpg

In [None]:
from PIL import Image
from torchvision import transforms
x_0 = Image.open("image.jpg")

# the pretrained DDPM works with pixel values scaled to [-1, 1]
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
normalize = transforms.Normalize(mean, std)
inv_normalize = transforms.Normalize(
   mean= [-m/s for m, s in zip(mean, std)],
   std= [1/s for s in std]
)

transform = transforms.Compose([transforms.Resize((256, 256)),
                                transforms.ToTensor(),
                                normalize])
x_0 = transform(x_0)

def tensor_to_img(img):
    if len(img.shape) == 4:
        img = img.squeeze(0)
    img = inv_normalize(img)
    img = img.permute(1, 2, 0)
    return img.clamp(0, 1)  # keep values in the range imshow expects for float images

img = tensor_to_img(x_0)
plt.imshow(img)
plt.axis('off')
plt.show()
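The way `inv_normalize` is built may look surprising: it is itself a normalization, with mean $-m/s$ and standard deviation $1/s$. The following sketch checks on scalar values that this undoes the original normalization; the function names and the example statistics are hypothetical and only serve the illustration.

```python
def normalize_value(x, m, s):
    # what a per-channel normalization does to one pixel value
    return (x - m) / s

def inverse_normalize_value(y, m, s):
    # normalizing with mean -m/s and std 1/s gives (y + m/s) * s = y * s + m
    return normalize_value(y, -m / s, 1.0 / s)

m, s = 0.5, 0.25  # hypothetical per-channel statistics
for pixel in (0.0, 0.3, 1.0):
    restored = inverse_normalize_value(normalize_value(pixel, m, s), m, s)
    assert abs(restored - pixel) < 1e-12
print("round trip recovers the original pixel values")
```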

We need to slightly modify the `sample_x_t` function to make it compatible with image inputs.  

In [None]:
def sample_x_t(x_0, t, alpha_bar, eps):
    return alpha_bar[t, None, None, None].sqrt()*x_0 + (1.-alpha_bar[t, None, None , None]).sqrt()*eps

eps = torch.randn_like(x_0)
t = 500
x_t = sample_x_t(x_0, t, alpha_bar, eps)
img = tensor_to_img(x_t)
plt.imshow(img)
plt.axis('off')
plt.show()

Look at the mean and std of a noisy sample of the image at timestep 999. What would you expect, and is it what you observe?

In [None]:
...

Let's look at the evolution of the image over several timesteps.

In [None]:
plt.figure(figsize=[12, 6])
for i, t in enumerate([0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 999]):
    eps = torch.randn_like(x_0)
    x_t = sample_x_t(x_0, t, alpha_bar, eps)
    plt.subplot(1, 11, i+1)
    plt.imshow(tensor_to_img(x_t))
    plt.axis('off')
plt.tight_layout()  # adjust the spacing once, after all subplots are drawn
plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch

fig, ax = plt.subplots(figsize=(8, 8))
ax.axis('off')

time_steps = [0, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 999]
all_frames = []

for t in time_steps:
    eps = torch.randn_like(x_0)
    x_t = sample_x_t(x_0, t, alpha_bar, eps)
    noisy_img = inv_normalize(x_t)
    noisy_img = noisy_img.permute(1, 2, 0)
    all_frames.append(noisy_img.detach().cpu().numpy())

def update(frame):
    ax.clear()
    ax.imshow(all_frames[frame])
    ax.set_title(f't = {time_steps[frame]}')
    ax.axis('off')
    return ax,

anim = animation.FuncAnimation(
    fig,
    update,
    frames=len(time_steps),
    interval=200,
    blit=False,
    repeat=True,
    repeat_delay=1000
)

plt.close()
HTML(anim.to_jshtml())

Let's now try to generate an image with the pretrained model, using the noise schedule computed previously and the sampling algorithm from the first part.  
Let's define a new sampling function that is compatible with the `diffusers` UNet. There is no need to store the intermediate steps here.  

In [None]:
from tqdm.notebook import tqdm
import torch

model = model.to(device)
@torch.no_grad()
def sample(num_samples, model, alpha, alpha_bar, T=200, device='cuda'):
    alpha = alpha.to(device)
    alpha_bar = alpha_bar.to(device)
    xt = ... .to(device)
    for t in ...:
        t_batch = ... .to(device)
        noise_pred = model(xt, t_batch).sample
        mu_hat_t = (xt - (1-alpha[t,None])/(1-alpha_bar[t,None]).sqrt()*noise_pred)/(alpha[t,None]).sqrt()

        z = ... .to(device)
        sigma = ...
        xt = ...
    return xt

x_0 = sample(1, model, alphas, alpha_bar, T=1000, device=device)

Plot the generated image.   

In [None]:
img = tensor_to_img(x_0[0].cpu())
plt.imshow(img)
plt.axis('off')
plt.show()

This is it for this practical session.  I hope you gained more intuition about diffusion models.  
In the next session we will work with conditional diffusion models. If you finish early, you can try to run some text2image models using Fooocus.

In [None]:
del model
torch.cuda.empty_cache()

In [None]:
!pip install pygit2==1.15.1
%cd /content
!git clone https://github.com/lllyasviel/Fooocus.git
%cd /content/Fooocus
!python entry_with_update.py --share --always-high-vram