# Source: 
[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm.notebook import tqdm
from torch.distributions.laplace import Laplace
from model import DDPM
from model import MyTinyUNet
from etl import show_images
from etl import generate_image

In [None]:
import os

In [None]:
# Debugging switches for CUDA errors.
# CUDA_LAUNCH_BLOCKING=1 is the usual way to make CUDA kernel launches synchronous so an
# error is reported at the op that caused it; it slows training, so presumably this is
# meant for debugging only — remove it for real training runs.
# NOTE(review): both variables only matter if set before CUDA is initialised, and
# TORCH_USE_CUDA_DSA (device-side assertions) is, as far as I know, a PyTorch *build*
# option, so setting it at runtime likely has no effect — confirm.
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

`MyTinyUNet` is a very small implementation of a convolutional UNet in which a time embedding is added at each step. To keep things simple, the input image must have size $s\times s$ with $s$ divisible by 8 (this is why we resize the MNIST images from $28\times 28$ to $32 \times 32$).

In [None]:
# Smoke test for MyTinyUNet: a batch of 3 random 1x32x32 inputs, each paired with an
# independently drawn integer timestep in [0, 1000). Everything stays on the CPU here.
bs = 3
x = torch.randn(bs,1,32,32)
n_steps=1000
timesteps = torch.randint(0, n_steps, (bs,)).long()
unet = MyTinyUNet(in_c =1, out_c =1, size=32)

In [None]:
# Forward pass; presumably the output matches x's shape (in_c == out_c == 1) —
# check the displayed shape.
y = unet(x,timesteps)
y.shape

Laplace distribution PDF: $f(x \mid \mu, b) = \frac{1}{2b} \exp\left(-\frac{|x-\mu|}{b}\right)$

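Laplace negative log-likelihood: $\mathrm{NLL} = \frac{1}{N}\sum_{i=1}^{N}\left(\log(2b) + \frac{|x_i-\mu|}{b}\right)$, where $b$ is the scale of the Laplace distribution. (`LaplaceNLL` below sums this per-element term over the last tensor dimension and then averages over the remaining elements.)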

In [None]:
class LaplaceNLL(nn.Module):
    """Negative log-likelihood of ``target`` under a Laplace distribution centred on ``input``.

    The per-element NLL is ``log(2 * scale) + |input - target| / scale``. These terms
    are summed over the last dimension and then averaged over all remaining elements,
    so the result is a scalar tensor.

    Args:
        scale: Fixed Laplace scale ``b``; must be positive.

    Raises:
        ValueError: if ``scale`` is not positive (the NLL would be NaN/inf).
    """

    def __init__(self, scale = 1.0):
        super().__init__()
        if scale <= 0:
            raise ValueError(f"LaplaceNLL scale must be positive, got {scale!r}")
        self.scale = scale

    def forward(self, input, target):
        """Return the Laplace NLL of ``target`` given location ``input`` (scalar tensor)."""
        # Build the normalising constant log(2b) on the input's own device and dtype
        # rather than a notebook-global `device`, so CPU and GPU tensors both work.
        log_norm = torch.log(
            torch.tensor(2.0 * self.scale, dtype=input.dtype, device=input.device)
        )
        per_element = log_norm + torch.abs(input - target) / self.scale
        return per_element.sum(dim=-1).mean()

In [None]:
def generate_noise(noise_distribution, batch, params=None):
    """Sample a noise tensor with the same shape as ``batch``.

    Args:
        noise_distribution: ``"Gaussian"`` (standard normal), ``"Laplace"`` or
            ``"S&P"`` (salt & pepper).
        batch: Tensor whose shape and device the noise should match.
        params: Optional dict of distribution parameters:
            - ``"Laplace"``: ``"loc"`` (default 0.0) and ``"scale"`` (default 1.0);
            - ``"S&P"``: ``"prob"`` (default 0.01), the per-pixel probability of
              salt, and of pepper among the pixels that did not get salt.
            Ignored for ``"Gaussian"``.

    Returns:
        Float tensor shaped like ``batch`` on ``batch.device``. For ``"S&P"`` every
        value is +1 (salt), -1 (pepper) or 0 (untouched).

    Raises:
        ValueError: if ``noise_distribution`` is not one of the allowed names.
    """
    noise_distribution_allowed = {"Gaussian", "Laplace", "S&P"}
    if noise_distribution not in noise_distribution_allowed:
        raise ValueError(f"'noise_distribution' not of value {noise_distribution_allowed}")

    params = params or {}
    # Put the noise wherever the data already lives, so callers never mix devices
    # and the function does not depend on a notebook-global `device`.
    target_device = batch.device

    if noise_distribution == "Gaussian":
        # Standard normal: mean 0, variance 1.
        return torch.randn(batch.shape).to(target_device)

    if noise_distribution == "Laplace":
        loc = params.get("loc", 0.0)
        scale = params.get("scale", 1.0)
        laplace = Laplace(torch.tensor([loc]), torch.tensor([scale]))
        return laplace.expand(batch.shape).sample().to(target_device)

    # "S&P". torch.rand is uniform on [0, 1), so `< prob` selects about `prob` of the
    # pixels. A standard normal (torch.randn) is below 0.01 roughly half the time.
    prob = params.get("prob", 0.01)
    salt_mask = torch.rand(batch.shape) < prob
    # Exclude pixels already chosen as salt so that no pixel is both salt and pepper.
    pepper_mask = (torch.rand(batch.shape) < prob) & ~salt_mask
    return (salt_mask.float() - pepper_mask.float()).to(target_device)

In [None]:
def calc_loss(loss_f, pred, actual):
    """Compute the loss between the predicted and the actual noise.

    Args:
        loss_f: Criterion name: ``"LaplaceNLL"``, ``"L1"`` or ``"MSE"``.
        pred: Predicted tensor.
        actual: Target tensor, same shape as ``pred``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``loss_f`` is not a supported name.
    """
    if loss_f == "LaplaceNLL":
        criterion = LaplaceNLL()
    elif loss_f == "L1":
        criterion = nn.L1Loss()
    elif loss_f == "MSE":
        criterion = nn.MSELoss()
    else:
        # Fail with a clear message instead of an UnboundLocalError on `criterion`.
        raise ValueError(
            f"Unsupported loss_f {loss_f!r}; expected 'LaplaceNLL', 'L1' or 'MSE'"
        )

    return criterion(pred, actual)

In [None]:
def training_loop(model, dataloader, optimizer, num_epochs, num_timesteps, device=device, noise="Gaussian", loss_f="MSE"):
    """Training loop for DDPM.

    For every batch:
    1. Sample noise of the requested distribution.
    2. Draw one timestep per image uniformly from ``[0, num_timesteps)``.
    3. Noise the images with ``model.add_noise``.
    4. Predict that noise with ``model.reverse``.
    5. Take one optimizer step on the loss between the predicted and the actual noise.

    Args:
        model: DDPM wrapper; must provide ``train()``, ``add_noise`` and ``reverse``.
        dataloader: Iterable of batches whose first element is the image tensor
            (any further elements, e.g. labels, are ignored).
        optimizer: Optimizer over the model's parameters.
        num_epochs: Number of passes over ``dataloader``.
        num_timesteps: Number of diffusion steps; timesteps are drawn from
            ``[0, num_timesteps)``.
        device: Device the images and timesteps are moved to.
        noise: Noise distribution name, forwarded to ``generate_noise``.
        loss_f: Loss name, forwarded to ``calc_loss``.

    Returns:
        List with the loss value (Python float) of every optimizer step, in order.
    """
    global_step = 0
    losses = []

    for epoch in range(num_epochs):
        # Switch to training mode (standard ``nn.Module.train()``, which affects layers
        # such as dropout / batch norm). Presumably DDPM is an nn.Module — verify.
        model.train()
        # The context manager closes the progress bar even if the epoch is interrupted.
        with tqdm(total=len(dataloader)) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")
            for batch in dataloader:
                # Images only, e.g. torch.Size([4096, 1, 32, 32]) with this notebook's MNIST loader.
                images = batch[0].to(device)

                actual_noise = generate_noise(noise, images)

                # One independent timestep per image (sampling with replacement, so
                # repeats within a batch are expected).
                timesteps = torch.randint(0, num_timesteps, (images.shape[0],)).long().to(device)

                # Noisy images at the sampled timesteps.
                noisy = model.add_noise(images, actual_noise, timesteps)

                # The network predicts the noise that was added to `noisy`.
                noise_pred = model.reverse(noisy, timesteps)

                loss = calc_loss(loss_f, noise_pred, actual_noise)

                # Clear stale gradients, backpropagate, then update the parameters.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                losses.append(loss_value)
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss_value, step=global_step)
                global_step += 1

    return losses

In [None]:
root_dir = './data/'

# Resize the 28x28 MNIST digits to 32x32 (MyTinyUNet needs a side divisible by 8),
# convert them to tensors in [0, 1], then map them to [-1, 1] via (x - 0.5) / 0.5.
transform01 = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

dataset = MNIST(root=root_dir, train=True, transform=transform01, download=True)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4096,
    shuffle=True,
    num_workers=10,
)

In [None]:
# Display the DataLoader object as the cell output.
dataloader

# Running DDPM

In [None]:
# Training hyper-parameters.
learning_rate = 1e-3
num_epochs = 50
num_timesteps = 50

# The UNet denoiser (default constructor arguments) lives on `device`. DDPM wraps it
# with a linear beta schedule from 1e-4 to 0.02, and Adam optimises the UNet weights.
network = MyTinyUNet().to(device)
model = DDPM(network, num_timesteps, beta_start=0.0001, beta_end=0.02, device=device)
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

In [None]:
# Train with Laplace noise and an L1 loss. With a fixed scale b, LaplaceNLL is
# log(2b) + |x - mu| / b, so minimising it is equivalent to L1 up to a constant and a scale.
training_loop(model, dataloader, optimizer, num_epochs, num_timesteps, device=device, loss_f="L1", noise="Laplace")  

In [None]:
# Sample from the trained model. Presumably the arguments are (model, number of
# samples = 100, channels = 1, image size = 32, device) — confirm against etl.generate_image.
generated, generated_mid = generate_image(model, 100, 1, 32, device)

In [None]:
# Note that generated and generated_mid are list objects
len(generated), len(generated_mid)

In [None]:
# Intermediate denoising result, then the final samples.
show_images(generated_mid, "Mid result")
show_images(generated, "Final result")

In [None]:
# `bn` is otherwise first defined in the "Aside" section further down, so running the
# notebook top to bottom failed here with a NameError. Build it the same way here:
# the first 100 images of one (shuffled) batch, as a list of per-image tensors.
bn = list(next(iter(dataloader))[0][:100])
show_images(bn, "origin")

# Aside

In [None]:
for b in dataloader:
    # Only the first batch is needed: b[0] holds the images, b[1] the labels.
    batch, what_is_b = b[0], b
    break

print(batch.shape)
# Keep the first 100 images of the batch (4096 images with this loader) as a list
# of per-image tensors.
bn = list(batch[:100])

### Testing salt & pepper noise

In [None]:
# Show the first clean image; permute (C, H, W) -> (H, W, C) for matplotlib.
plt.imshow(bn[0].permute(1,2,0).numpy())

In [None]:
# Uniform masks: each pixel becomes salt with probability 0.01, and pepper with
# probability 0.01 among the pixels that did not get salt.
salt_mask = torch.rand_like(bn[0]) < 0.01
pepper_mask = (torch.rand_like(bn[0]) < 0.01) & ~salt_mask # exclude pixels already chosen as salt

In [None]:
# Numeric version of the masks: +1 for salt, -1 for pepper, 0 elsewhere.
salt_mask_num = salt_mask.float()
pepper_mask_num = pepper_mask.float()*-1

In [None]:
salt_mask_num + pepper_mask_num

In [None]:
# Work on a copy so the original image in `bn` is left untouched.
test = bn[0].clone().detach()

In [None]:
# Here pixels are *replaced*. Pixel values are in [-1, 1] after the Normalize transform,
# so 1 is white (salt) and -1 is black (pepper). generate_noise instead returns the
# +1/-1/0 tensor, which training_loop passes to model.add_noise.
test[salt_mask] = 1
test[pepper_mask] = -1

In [None]:
plt.imshow(test.permute(1,2,0).numpy(), cmap="gray")

#### Testing LaplaceNLL class

In [None]:
# Two independent Laplace(loc=0, scale=2) samples shaped like `batch`, moved to `device`.
laplace_noise1 = Laplace(torch.tensor([0.0]), torch.tensor([2.0])).expand(batch.shape).sample().to(device)
laplace_noise2 = Laplace(torch.tensor([0.0]), torch.tensor([2.0])).expand(batch.shape).sample().to(device)

In [None]:
# Sanity check of torch.log on an integer tensor.
torch.log(torch.tensor([100]))

In [None]:
# Uses the default scale=1.0, not the scale 2.0 used to draw the samples above.
LaplaceNLL()(laplace_noise1, laplace_noise2)

In [None]:
# L1 for comparison. LaplaceNLL sums over the last dimension before averaging, so the
# two values differ by a constant and a scale factor.
torch.nn.L1Loss()(laplace_noise1, laplace_noise2)

You can check that all the parameters of the UNet `network` are indeed parameters of the DDPM `model` like this:

In [None]:
# List every parameter of the DDPM `model` with its shape; the UNet `network` parameters should all appear.
# NOTE(review): the first dimension of one parameter presumably tracks `num_timesteps`
# (a time-embedding table?) — confirm from the printed shapes.
for n, p in model.named_parameters():
    print(n, p.shape)

In [None]:
# IPython shell escape: show GPU memory usage (requires the NVIDIA driver tools).
!nvidia-smi

In [None]:
# Same sampling as in training_loop: each sample independently draws a timestep from
# [0, num_timesteps), so repeated timesteps within a batch are expected.
timesteps = torch.randint(0, num_timesteps, (batch.shape[0], )).long().to(device) 
timesteps.shape

#### Why do we allow repeated timesteps?

Repeated timesteps within a batch are a direct consequence of sampling with replacement. `torch.randint` draws each sample's timestep independently and uniformly from $\{0, \dots, T-1\}$, which matches the DDPM training objective, an expectation over $t \sim \mathrm{Uniform}\{0, \dots, T-1\}$. Requiring distinct timesteps would gain nothing, and it would be impossible here anyway because the batch size (4096) exceeds the number of timesteps (50). The independent draws add stochasticity to training, and each sample in the batch is noised to its own random level. This diversity helps the model generalize and learn a robust representation of the underlying data distribution.

In the context of the Denoising Diffusion Probabilistic Model (DDPM), this means a single batch typically mixes many diffusion steps. Every gradient update therefore trains the model on several noise levels at once, which is what it needs, since at sampling time it must denoise at every step.

### Random stuff

In [None]:
# reshape(-1, 1, 1, 1) turns a length-4 vector into shape (4, 1, 1, 1), the pattern used
# to broadcast one scalar per sample (e.g. betas[timesteps]) over (B, C, H, W) images.
a = torch.arange(4.)
b = a.reshape(-1, 1, 1, 1)
b, b.shape

In [None]:
# With 4 elements, reshape(-1, 2, 2, 1) infers the leading dimension as 1: shape (1, 2, 2, 1).
a = torch.arange(4.)
b = a.reshape(-1, 2, 2, 1)
b, b.shape

In [None]:
# Element-wise division of two equal-length vectors keeps shape (4,).
a = torch.arange(4.)
b = torch.arange(4.)
((a+1)/(b+1)).shape

The commands below are here to help you and to test your code.

In [None]:
# Linear beta schedule with 1000 steps (rebinds the global `num_timesteps`, which was 50).
num_timesteps = 1000
betas = torch.linspace(0.0001, 0.02, num_timesteps, dtype=torch.float32).to(device)

In [None]:
# `timesteps` here is the batch sampled earlier with the old num_timesteps.
timesteps

In [None]:
betas.shape

In [None]:
# Gather one beta per sample.
betas[timesteps]

In [None]:
betas[10]

In [None]:
# One beta per sample, reshaped to (B, 1, 1, 1) so it broadcasts over (B, C, H, W).
betas[timesteps].reshape(-1,1,1,1).shape

In [None]:
# NOTE(review): this rebinds `network` and `model` (defined in "Running DDPM") to a fresh,
# untrained 1000-step model. Re-run that section before sampling again. Also, `network`
# is not moved to `device` here — presumably fine for the add_noise/step checks below; confirm.
network = MyTinyUNet(in_c =1, out_c =1, size=32)
model = DDPM(network, num_timesteps, beta_start=0.0001, beta_end=0.02, device=device)

In [None]:
# 5 random inputs, all at timestep 10 (the second .long() is redundant).
bs = 5
x = torch.randn(bs,1,32,32).to(device)
timesteps = 10*torch.ones(bs,).long().long().to(device)

In [None]:
x, timesteps

In [None]:
timesteps.shape

In [None]:
# Shape check only: x is passed as both the clean images and the noise.
y = model.add_noise(x,x,timesteps)
y.shape

In [None]:
# NOTE(review): DDPM.step's argument order is not visible here. This passes x for both
# tensor arguments with a single scalar timestep, presumably just to check the output shape.
y = model.step(x,timesteps[0],x)
y.shape

In [None]:
# Laplace(0, 1) noise for a full 4096-image batch, via expand + sample.
laplace_noise = Laplace(torch.tensor([0.0]), torch.tensor([1.0])).expand(torch.Size([4096, 1, 32, 32]))
laplace_noise.sample()

In [None]:
# Standard-normal noise of the same shape, for comparison.
noise = torch.randn(torch.Size([4096, 1, 32, 32]))
noise