## Diffusion Models from scratch in PyTorch

No one has probably missed the recent boom and hype around diffusion models and their amazing results: they can create the most unimaginable images, for example when given a text description. A couple of weeks ago Stability AI open-sourced the powerful Stable Diffusion model, which anyone can download and run. It is also fast enough to run locally on your own machine if you have a decent GPU.

Now, I have to admit that I have only barely scratched the surface of diffusion models and their details; I blame it on the impossible task of keeping up with the high pace of the current machine learning community. So this post is about implementing a diffusion model in PyTorch, training it on a dataset and at least touching on the details that make these models work. It will not be a dive to the bottom of the ocean; where that would be needed, I will link to resources I have found useful for understanding the math and the derivations behind the equations. Let's start.

### Intro and High-level Intuition
So the basic idea of diffusion models is to learn a model that can create samples from some distribution, for example natural images, starting from just noise. An analogy I like to keep in mind is a human forming a lump of clay into, for example, a cup: we start with something containing no information and slowly form it into something useful.

Training a diffusion model involves two processes. In the forward process, we add noise to a sample $\mathbf{x}_0$ over a sequence of timesteps $1, \dots, T$, so that the sample at time $T$ consists only of isotropic Gaussian noise with mean zero and variance one (isotropic meaning the same variance in every dimension and no correlation between dimensions). In the reverse (backward) process we do the opposite and go from the noise $\mathbf{x}_T$ back to the initial sample $\mathbf{x}_0$. It is this reverse process we want to learn. The forward process is usually fixed and consists only of a noise scheduler, which tells us how much noise to add at each timestep. But let's start with the forward process and take it from there.

#### Forward Diffusion Process

The noise is added according to a variance schedule $\{\beta_t\in(0,1)\}_{t=1}^{T}$; we will use a linear one. Let's create a function for that.

In [1]:
import torch

def linear_beta_schedule(timesteps: int, start: float = 0.0001, end: float = 0.02):
    """Creates a linear noise schedule"""
    return torch.linspace(start, end, timesteps)



Adding Gaussian noise from one timestep to the next is then defined by

$$
q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}\left(\mathbf{x}_t ; \sqrt{1 - \beta_t}\mathbf{x}_{t-1} , \beta_t\mathbf{I}\right)
$$

And the complete forward process can be written as

$$
q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0}) = \prod_{t=1}^{T}q(\mathbf{x}_t \vert \mathbf{x}_{t-1})
$$

If we then define $\alpha_t \doteq 1 - \beta_t$ and use the fact that a sample from any normal distribution can be written as the mean plus the standard deviation times $\epsilon$, where $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ (also known as the reparameterization trick), we can show that

$$
\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
\quad \Rightarrow \quad
q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_t;\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 , (1 - \bar{\alpha}_t)\mathbf{I}\right)
$$

where $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$. So we can sample $\mathbf{x}_t$ at any time $t$ given just $\mathbf{x}_0$, i.e. we can sample it in closed form. This works because the sum of independent Gaussians is also Gaussian, so the noise added over many steps collapses into a single Gaussian term. See [Lilian's excellent post](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) for the derivation and details.
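To convince ourselves that the closed form agrees with applying $q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$ step by step, here is a small numerical sketch. It is self-contained: it rebuilds the same linear schedule as `linear_beta_schedule` in float64 instead of reusing the notebook's globals, and the function name and the use of a single repeated pixel value are illustrative assumptions. Both routes should give the mean $\sqrt{\bar{\alpha}_t}\,x_0$ and variance $1 - \bar{\alpha}_t$ up to Monte Carlo error.

```python
import torch


def compare_iterative_and_closed_form(t: int = 100, num_samples: int = 200_000, x0_value: float = 0.5, seed: int = 0):
    """Noise one pixel value t times step by step, and once in closed form.

    Returns (iterative_mean, iterative_var), (closed_mean, closed_var), alpha_bar_t.
    Uses the same linear schedule as linear_beta_schedule (1e-4 to 0.02, 300 steps).
    """
    gen = torch.Generator().manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, 300, dtype=torch.float64)
    alphas = 1.0 - betas
    # product of the first t alphas (the schedule is stored 0-based)
    alpha_bar_t = torch.prod(alphas[:t])

    x_0 = torch.full((num_samples,), x0_value, dtype=torch.float64)

    # Step by step: x_s = sqrt(1 - beta_s) * x_{s-1} + sqrt(beta_s) * eps_s
    x = x_0.clone()
    for s in range(t):
        eps = torch.randn(num_samples, generator=gen, dtype=torch.float64)
        x = alphas[s].sqrt() * x + betas[s].sqrt() * eps

    # Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    eps = torch.randn(num_samples, generator=gen, dtype=torch.float64)
    x_closed = alpha_bar_t.sqrt() * x_0 + (1.0 - alpha_bar_t).sqrt() * eps

    return (x.mean().item(), x.var().item()), (x_closed.mean().item(), x_closed.var().item()), alpha_bar_t.item()


iterative, closed, alpha_bar = compare_iterative_and_closed_form()
print(f"step by step: mean={iterative[0]:.3f} var={iterative[1]:.3f}")
print(f"closed form:  mean={closed[0]:.3f} var={closed[1]:.3f}")
```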

**Our forward diffusion process can then be implemented as**

In [2]:
def get_val_from_t(vals, t, x_shape):
    """
    Helper that picks the entries of 'vals' at the timesteps 't' and reshapes them to (batch, 1, 1, ...)
    so that they broadcast against a tensor of shape 'x_shape'
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    
    # reshape and add correct dimension
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion(x_0, t, device="cpu"):
    """
    Takes an image 'x_0' and a time step 't' and returns the noisy version of the image at time 't'
    """
    noise = torch.randn_like(x_0).to(device)
    
    sqrt_alphas_cumprod_t = get_val_from_t(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    sqrt_one_minus_alphas_cumprod_t = get_val_from_t(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)

    return (
        sqrt_alphas_cumprod_t * x_0.to(device) + sqrt_one_minus_alphas_cumprod_t * noise,
        noise
        )

Note that we are using some variables here that we have not defined yet; they will be pre-calculated later. Otherwise, that's it for the forward process! Now that we have it in place, let us test it with some data.
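The reshaping in `get_val_from_t` is what lets one value per image broadcast over all of that image's channels and pixels. Here is a stand-alone sketch of the same two operations, `gather` followed by `reshape`, using a made-up lookup table in place of `sqrt_alphas_cumprod` and a batch of two images (the function name is hypothetical):

```python
import torch


def demo_gather_and_broadcast():
    vals = torch.linspace(0.0, 0.9, 10)   # stand-in lookup table: one value per timestep
    t = torch.tensor([2, 7])              # one timestep per image in a batch of 2
    x_shape = (2, 3, 32, 32)              # (batch, C, H, W)

    out = vals.gather(-1, t)              # picks vals[2] and vals[7], shape (2,)
    out = out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))  # shape (2, 1, 1, 1)

    x = torch.ones(x_shape)
    scaled = out * x                      # image i is scaled by vals[t[i]] in every channel and pixel
    return out, scaled


out, scaled = demo_gather_and_broadcast()
print(out.shape, scaled.shape)  # torch.Size([2, 1, 1, 1]) torch.Size([2, 3, 32, 32])
```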

### Dataset

We are going to use the StanfordCars dataset, and for our purposes we only need to know that it contains a lot of images of cars. Let us download it (if you are running this on your own machine, keep in mind that it is $\sim$1 GB).

In [3]:
import torchvision

dataset = torchvision.datasets.StanfordCars(root='data', download=True)

To get a feeling for the images, we plot some examples.

In [4]:
import matplotlib.pyplot as plt

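def show_images(dataset, num_samples=16, cols=8):
    """Plot the first 'num_samples' images of a dataset that yields (PIL image, label) pairs"""
    plt.figure(figsize=(30, 10))
    rows = (num_samples + cols - 1) // cols  # ceiling division, so no empty row is added
    for i, (img, _label) in enumerate(dataset):
        if i == num_samples:
            break
        plt.subplot(rows, cols, i + 1)
        plt.imshow(img)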

show_images(dataset)
plt.show()

#### Transform the data

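The images are in PIL format and we need them as torch tensors. Here we resize them to $32 \times 32$ pixels, add random horizontal flips as augmentation, convert them to tensors scaled to $[-1, 1]$ and put them into a DataLoader that will be used during training. We also only use the first 2000 images, which keeps training short.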

In [None]:
from torchvision import transforms 
from torch.utils.data import DataLoader, Subset

IMG_SIZE = 32
BATCH_SIZE = 128

dataset_transformed = torchvision.datasets.StanfordCars(root='data', transform=transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # convert values to be between [-1.0,  1.0]
    transforms.Lambda(lambda x: (x * 2) - 1)
]))

dataset_transformed = Subset(dataset_transformed, range(0, 2000))

# dimensions will be (batch, C, H, W)
dataloader = DataLoader(dataset_transformed, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

Now we get to the pre-calculated values I mentioned before. Values such as $\sqrt{\bar{\alpha}_t}$ and $\sqrt{1 - \bar{\alpha}_t}$ depend only on $t$ and not on the image, so instead of calculating them every time we want a sample $\mathbf{x}_t$, we compute them once for all $t$ here.

In [None]:
import torch.nn.functional as F

# Number of time steps in the diffusion process
T = 300

# Create the betas
betas = linear_beta_schedule(timesteps=T)

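# Pre-calculations (all tensors have shape (T,) and are indexed by t = 0, ..., T - 1)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)                          # alpha_bar_t
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)    # alpha_bar_{t-1}, padded with 1 at t = 0
sqrt_one_over_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# variance of the forward-process posterior q(x_{t-1} | x_t, x_0), used when sampling
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)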
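A quick sanity check of these tensors also shows what they mean. In the sketch below (a self-contained recomputation with the same schedule; the function name is hypothetical), `posterior_variance` is exactly zero at the first timestep, because $\bar{\alpha}$ is padded with 1 there, and it never exceeds $\beta_t$, since $(1 - \bar{\alpha}_{t-1}) / (1 - \bar{\alpha}_t) < 1$. The zero at the first step matches the sampling code later on, which adds no noise in the final denoising step.

```python
import torch
import torch.nn.functional as F


def schedule_tensors(T_steps: int = 300, start: float = 0.0001, end: float = 0.02):
    """Recompute betas, alphas_cumprod and posterior_variance without touching the notebook's globals."""
    betas = torch.linspace(start, end, T_steps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    return betas, alphas_cumprod, posterior_variance


b, ab, pv = schedule_tensors()
print(f"posterior variance at t=0: {pv[0].item():.1e}")   # no noise is added in the last sampling step
print(f"ratio to beta_t at t=1 and t=299: {(pv[1] / b[1]).item():.3f}, {(pv[-1] / b[-1]).item():.3f}")
```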

Now we are ready to test our forward diffusion process

In [None]:
import numpy as np

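def tensor_to_pil(image):
    """Reverse the [-1, 1] scaling and convert a (C, H, W) or (1, C, H, W) tensor to a PIL image"""
    if image.dim() == 4:
        image = image[0]  # drop a leading batch dimension
    reverse_transforms = transforms.Compose([
        # noisy images fall outside [-1, 1]; clamp so the uint8 conversion does not wrap around
        transforms.Lambda(lambda t: t.detach().cpu().clamp(-1.0, 1.0)),
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    return reverse_transforms(image)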

# take an image
image_0 = next(iter(dataloader))[0][0]

# since plotting 'T' number of images is a bit much we decrease it
num_images = 10
stepsize = T // (num_images - 1)

fig, ax = plt.subplots(1, num_images, figsize=(20, 20), constrained_layout=True)
for ind in range(0, num_images):

    t = stepsize * ind
    tt = torch.tensor([t], dtype=torch.int64)

    image_t, _ = forward_diffusion(image_0, tt)
    ax[ind].axis("off")
    ax[ind].set_title(f"t = {t}")
    ax[ind].imshow(tensor_to_pil(image_t))

That seems to work quite well. To get an indication that the last image, which is close to $\mathbf{x}_T$, is indeed similar to noise drawn from $\mathcal{N}(0, \mathbf{I})$, we can check the mean and standard deviation across the image:

In [None]:
std, mean = torch.std_mean(image_t)
print(f"Mean: {mean:.2f} Std: {std:.2f}")

Quite close. You can try changing $T$: as you increase it (keeping the same $\beta$ range), we get closer to the target distribution $\mathcal{N}(0, \mathbf{I})$.
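How close we get is governed by $\sqrt{\bar{\alpha}_T}$, the weight that $\mathbf{x}_0$ still has in $\mathbf{x}_T$. The sketch below uses a hypothetical helper that rebuilds the same linear schedule and computes this weight for a few values of $T$ with the default $\beta$ range. For $T = 300$ it is roughly 0.22, which is why the statistics above are close to, but not exactly, zero mean and unit standard deviation.

```python
import torch


def signal_left_at_T(T_steps: int, start: float = 0.0001, end: float = 0.02) -> float:
    """sqrt(alpha_bar_T) for a linear beta schedule: how much of x_0 is left in x_T."""
    betas = torch.linspace(start, end, T_steps, dtype=torch.float64)
    alpha_bar_T = torch.prod(1.0 - betas)
    return alpha_bar_T.sqrt().item()


for T_try in (300, 500, 1000):
    print(f"T = {T_try:4d}: sqrt(alpha_bar_T) = {signal_left_at_T(T_try):.4f}")
```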

### Reverse diffusion process

So how do we go back, from noise to the original image? It might at first seem trivial, since it was so easy to arrive at the noisy version. But remember that at each step of adding noise we lose information in a statistical fashion. In other words, there is an extremely large number of possible ways to go back. In fact, to reverse the process exactly we would have to integrate (marginalize) over all possible ways of arriving at the original image $\mathbf{x}_0$, including all the latent (unobserved) variables along the way. This is understandably intractable for natural images, so the only way we can manage it is by approximation!

Here neural networks come, once again, to the rescue. The idea is to train a neural network with parameters $\theta$ to approximate

$$ q(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}) \approx p_{\theta}(\mathbf{x}_{t-1} \vert \mathbf{x}_{t})$$

The most obvious way to do this would be to let the network estimate the mean and covariance of each denoising step:

$$p_{\theta}(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mu_{\theta}(\mathbf{x}_t, t), \Sigma_{\theta}(\mathbf{x}_t, t))$$

However, the authors of the [Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239v2.pdf) paper explored an alternative parameterization. Instead of estimating $\mu_{\theta}(\mathbf{x}_t, t)$ directly, we can make the network estimate the noise $\epsilon$ that was mixed into $\mathbf{x}_t$, $\epsilon_{\theta}(\mathbf{x}_t, t)$, while the variance is kept fixed instead of learned. This works because it can be shown that (jump into the paper for the derivation)

$$ \mu_{\theta}(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}} \epsilon_{\theta}(\mathbf{x}_t, t) \right)$$

This parameterization makes it possible to simplify the variational lower bound while maintaining performance. That sounds a bit scary, but let's leave it there for now.
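The noise parameterization can be checked numerically. If the network predicted the true noise exactly, the mean above should coincide with the mean of the forward-process posterior $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$, which the DDPM paper writes in terms of $\mathbf{x}_0$ as $\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t$. Below is a self-contained sketch comparing the two forms; the random stand-in image, the chosen $t$ and the function name are illustrative assumptions.

```python
import torch


def posterior_mean_two_ways(t: int = 150, seed: int = 0):
    """Return the posterior mean written with x_0 and written with the (true) noise."""
    gen = torch.Generator().manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, 300, dtype=torch.float64)   # same linear schedule as above
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]])

    x_0 = torch.randn(3, 8, 8, generator=gen, dtype=torch.float64)  # stand-in image
    eps = torch.randn(3, 8, 8, generator=gen, dtype=torch.float64)  # the noise that was added
    x_t = alphas_cumprod[t].sqrt() * x_0 + (1.0 - alphas_cumprod[t]).sqrt() * eps

    # Mean of q(x_{t-1} | x_t, x_0), expressed with x_0
    coef_x0 = alphas_cumprod_prev[t].sqrt() * betas[t] / (1.0 - alphas_cumprod[t])
    coef_xt = alphas[t].sqrt() * (1.0 - alphas_cumprod_prev[t]) / (1.0 - alphas_cumprod[t])
    mu_from_x0 = coef_x0 * x_0 + coef_xt * x_t

    # The same mean expressed with the noise, as in the equation for mu_theta
    mu_from_eps = (x_t - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
    return mu_from_x0, mu_from_eps


mu_a, mu_b = posterior_mean_two_ways()
print(torch.allclose(mu_a, mu_b))  # True
```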

### Training objective

So now there is only one thing missing: how do we know that our network's approximation of $q(\mathbf{x}_{t-1} \vert \mathbf{x}_{t})$ is actually good? What metric can we use to measure it? The answer lies in a field called variational inference, where the core idea is to infer something about unobserved (latent) variables by optimizing an approximating distribution (in our case $p_{\theta}(\mathbf{x}_{t-1} \vert \mathbf{x}_{t})$). Moreover, we are going to use something called the Variational Lower Bound (VLB), also known as the Evidence Lower Bound (ELBO). The derivation is a bit cumbersome, but let's try to get the gist of it. We start from the Kullback-Leibler divergence, which measures how much one probability distribution differs from another:

$$\mathbf{D}_{KL}(P \parallel Q) = \sum_{x\in\mathcal{X}}P(x)\log\left(\frac{P(x)}{Q(x)}\right)$$
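As a small worked example of the formula (the two distributions are made up, and `kl_divergence` is a hypothetical helper), the divergence is zero for identical distributions but generally changes if we swap $P$ and $Q$, which is why it is called a divergence rather than a distance:

```python
import torch


def kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Discrete D_KL(P || Q) = sum_x P(x) log(P(x) / Q(x)); assumes Q(x) > 0 wherever P(x) > 0."""
    mask = p > 0  # terms with P(x) = 0 contribute nothing
    return torch.sum(p[mask] * torch.log(p[mask] / q[mask]))


p_dist = torch.tensor([0.5, 0.4, 0.1])
q_dist = torch.tensor([0.3, 0.3, 0.4])
print(kl_divergence(p_dist, q_dist).item())  # about 0.23
print(kl_divergence(q_dist, p_dist).item())  # about 0.31, so the order of the arguments matters
print(kl_divergence(p_dist, p_dist).item())  # 0.0
```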

This sounds like the perfect metric, right? A measure of how similar two distributions are. Let's plug in our distributions, switch to the continuous setting and include the complete sequence:

$$\mathbf{D}_{KL}(q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0}) \parallel p_{\theta}(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})) = \int q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})\log\left(\frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}{p_{\theta}(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\right) d\mathbf{x}_{1:T}$$

Here we say that, given the initial image $\mathbf{x}_0$, we want the joint distribution of the latent variables $\mathbf{x}_{1:T}$ to be the same under both models, which is exactly the property we want. Let's work with this expression a bit more:

$$
\begin{aligned}
\mathbf{D}_{KL}(q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0}) \parallel p_{\theta}(\mathbf{x}_{1:T} \vert \mathbf{x}_{0}))
&=\int q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})\log\left(\frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}{p_{\theta}(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\right) d\mathbf{x}_{1:T} \\
&= \int q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})\log\left(\frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0}) p_{\theta}(\mathbf{x}_0)}{p_{\theta}(\mathbf{x}_{0:T})}\right) d\mathbf{x}_{1:T} & \scriptsize{\text{Use that } p(a \vert b) = p(a, b)/p(b)} \\
&=\mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\Big[ \log\left(\frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0}) p_{\theta}(\mathbf{x}_0)}{p_{\theta}(\mathbf{x}_{0:T})}\right) \Big] & \scriptsize{\text{From the definition of expectation}} \\
&=\mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\Big[ \log\left(\frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}{p_{\theta}(\mathbf{x}_{0:T})}\right) \Big] + \mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\Big[\log{p_{\theta}(\mathbf{x}_0)}\Big] & \scriptsize{\text{Split expectation into two}} \\
&=\mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\Big[ \log\left(\frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}{p_{\theta}(\mathbf{x}_{0:T})}\right) \Big] + \log{p_{\theta}(\mathbf{x}_0)}& \scriptsize{\text{The last term does not depend on } \mathbf{x}_{1:T} \text{, so the expectation can be dropped}} \\
&=-\mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\Big[ \log\left(\frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\right) \Big] + \log{p_{\theta}(\mathbf{x}_0)}& \scriptsize{\text{Invert the fraction inside the logarithm}} \\
\end{aligned}
$$

Look at the last term on the right-hand side, $\log{p_{\theta}(\mathbf{x}_0)}$: it is the log-likelihood of the model producing real images, which is exactly what we would like to maximize! The first term on the right-hand side, without its minus sign, is what is called the **Variational Lower Bound**, defined as

$$
\mathcal{L}_{VLB} = \mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\Big[ \log\left(\frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}\right) \Big]
$$

Let us write the equation a bit less verbosely:

$$
\mathbf{D}_{KL} = \log{p_{\theta}} - \mathcal{L}_{VLB}
$$

But why is it called the **Variational Lower Bound**, and why is it useful? Let us reason a bit about the terms. The KL divergence is always $\ge 0$: it is not a true distance, since it is not symmetric, but it is never negative. Rearranging the equation gives $\mathcal{L}_{VLB} = \log{p_{\theta}(\mathbf{x}_0)} - \mathbf{D}_{KL} \le \log{p_{\theta}(\mathbf{x}_0)}$, with equality only when $\mathbf{D}_{KL} = 0$, i.e. when we have a perfect approximation. So $\mathcal{L}_{VLB}$ is a lower limit, or **lower bound**, of $\log{p_{\theta}(\mathbf{x}_0)}$. It turns out that we can compute this lower bound, and by maximizing it we push up $\log{p_{\theta}(\mathbf{x}_0)}$ (or at least tighten the bound), which is our goal. I really recommend watching [this video](https://www.youtube.com/watch?v=HxQ94L8n0vU) to get a deeper intuition for the **Variational Lower Bound**.
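The inequality can be made concrete with a toy latent-variable model where everything is available in closed form. This sketch assumes, purely for illustration (it is not the diffusion model), $z \sim \mathcal{N}(0, 1)$ and $x \vert z \sim \mathcal{N}(z, \sigma^2)$, so $p(x) = \mathcal{N}(0, 1 + \sigma^2)$, and evaluates the bound for a Gaussian $q(z) = \mathcal{N}(m, s^2)$. The bound equals $\log p(x)$ when $q$ is the exact posterior and is strictly smaller otherwise; the gap is the KL divergence between $q$ and the posterior.

```python
import math


def log_evidence(x: float, sigma2: float) -> float:
    """Exact log p(x) when z ~ N(0, 1) and x | z ~ N(z, sigma2), i.e. x ~ N(0, 1 + sigma2)."""
    var = 1.0 + sigma2
    return -0.5 * math.log(2 * math.pi * var) - x**2 / (2 * var)


def elbo(x: float, sigma2: float, m: float, s2: float) -> float:
    """Closed-form E_q[log p(x, z) - log q(z)] for q(z) = N(m, s2)."""
    expected_log_likelihood = -0.5 * math.log(2 * math.pi * sigma2) - ((x - m) ** 2 + s2) / (2 * sigma2)
    expected_log_prior = -0.5 * math.log(2 * math.pi) - (m**2 + s2) / 2
    entropy_q = 0.5 * math.log(2 * math.pi * math.e * s2)
    return expected_log_likelihood + expected_log_prior + entropy_q


x_obs, sigma2 = 1.0, 1.0
post_mean, post_var = x_obs / (1 + sigma2), sigma2 / (1 + sigma2)  # exact posterior p(z | x)

print(log_evidence(x_obs, sigma2))                 # log p(x)
print(elbo(x_obs, sigma2, post_mean, post_var))   # the same value: q is the true posterior
print(elbo(x_obs, sigma2, 0.0, 1.0))              # smaller: q is not the posterior
```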
<br/>
<br/>
I leave out the last part of the derivation of the final loss function, but once again you can see [Lilian's great post](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/). After some massaging one can conclude that $-\mathcal{L}_{VLB}$, which we want to minimize, can be expanded into a sum of KL divergences, and the terms we need to care about are

$$
\mathcal{L}_{VLB,t} = D_\text{KL}(q(\mathbf{x}_t \vert \mathbf{x}_{t+1}, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_t \vert\mathbf{x}_{t+1})) \text{ for }1 \leq t \leq T-1
$$

which can be shown to equal

$$
\mathcal{L}_{VLB,t} = \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}} \Big[\frac{ (1 - \alpha_t)^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \boldsymbol{\Sigma}_\theta \|^2_2} \|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2 \Big]
$$

Yes, your eyes do not deceive you: after all this, we are almost back at the classic mean squared error. In fact, [Ho et al. (2020)](https://arxiv.org/pdf/2006.11239v2.pdf) found that the model works even better if we drop the weighting term, so our final training objective becomes simply

$$
\begin{aligned}
\mathcal{L}_{\text{simple},t}
&= \|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2 \\
&= \|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon}_t, t)\|^2
\end{aligned}
$$

We can sample this objective for different $t$ given an initial image $\mathbf{x}_0$ and update our network with the gradients as usual. Let's implement this objective:


In [None]:
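import torch.nn.functional as F

def get_loss(model, x_0, t, device):
    """Calculate the simplified DDPM loss for a batch of images 'x_0' at timesteps 't'"""
    x_t, noise = forward_diffusion(x_0, t, device)
    noise_pred = model(x_t, t)
    # mean squared error between the true and the predicted noise, as in the objective above
    # (some implementations use F.l1_loss here instead)
    return F.mse_loss(noise_pred, noise)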
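A useful reference point when reading the training loss: a hypothetical model that always predicts zero noise gets a mean squared error of about $\mathbb{E}[\epsilon^2] = 1$, and an L1 loss of about $\mathbb{E}\lvert\epsilon\rvert = \sqrt{2/\pi} \approx 0.80$ for implementations that use L1 instead. A quick self-contained check, with a made-up function name:

```python
import torch
import torch.nn.functional as F


def zero_predictor_losses(num_values: int = 1_000_000, seed: int = 0):
    """Losses of a hypothetical 'model' that always predicts zero noise."""
    gen = torch.Generator().manual_seed(seed)
    noise = torch.randn(num_values, generator=gen)
    zero_pred = torch.zeros_like(noise)
    return F.mse_loss(zero_pred, noise).item(), F.l1_loss(zero_pred, noise).item()


mse, l1 = zero_predictor_losses()
print(f"MSE: {mse:.3f}  L1: {l1:.3f}")
```

A model that has learned anything about the noise should end up clearly below these values.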

### Network Architecture

We will use almost the same architecture as in [Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239v2.pdf), and I have taken the great implementation from [Lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch) and simplified it a bit. At a high level, the architecture is a U-Net with built-in attention and ResNet blocks.
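The timestep $t$ enters the network through `SinusoidalPosEmb` in the cell below. As a stand-alone illustration of what that module computes (a hypothetical function repeating the same formula, not the module itself), each integer $t$ is mapped to sines and cosines at geometrically spaced frequencies, so nearby timesteps get similar vectors and the network receives a smooth, fixed-size encoding of the noise level:

```python
import math

import torch


def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Same formula as SinusoidalPosEmb: (batch,) integer timesteps -> (batch, dim) features."""
    half_dim = dim // 2
    # frequencies from 1 down to 1/10000, spaced geometrically
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10000) / (half_dim - 1)))
    args = t.float()[:, None] * freqs[None, :]          # (batch, half_dim)
    return torch.cat((args.sin(), args.cos()), dim=-1)  # (batch, dim)


emb = sinusoidal_embedding(torch.tensor([0, 1, 299]), dim=64)
print(emb.shape)  # torch.Size([3, 64])
```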

In [None]:
from functools import partial
import math

import torch
from torch import nn
import torch.nn.functional as F
from einops import reduce, rearrange


def l2norm(t):
    """L2-normalize along the last dimension"""
    return F.normalize(t, dim = -1)


class UNet(nn.Module):
    def __init__(self, img_channels: int, init_dim: int, time_emb_dim: int, num_res: int = 4):
        """Creates a UNet

        Args:
            img_channels (int): number of image channels
            init_dim (int): number of output channels in the first layer
            time_emb_dim (int): time embedding dimension
            num_res (int, optional): number of resolutions; the image size must be divisible by 2**(num_res - 1)
        """
        super().__init__()
        
        # initial conv
        self.init_conv = nn.Conv2d(img_channels, init_dim, kernel_size=7, padding=3)
        
        # create list of the different dimensions
        dims = [init_dim, *map(lambda m: init_dim * m, [2**res for res in range(0, num_res)])]

        # create convenient list of tuples with input and output channels for each resolution
        in_out_dims = list(zip(dims[:-1], dims[1:]))
        
        # time embedding block
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(init_dim),
            nn.Linear(init_dim, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        # downsample
        self.down_layers = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out_dims):
            is_last = ind >= num_res - 1
            
            self.down_layers.append(nn.ModuleList([
                ResNetBlock(dim_in, dim_in, time_emb_dim=time_emb_dim),
                ResNetBlock(dim_in, dim_in, time_emb_dim=time_emb_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
            ]))
        
        # middle block
        mid_dim = dims[-1]
        self.mid_block1 = ResNetBlock(mid_dim, mid_dim, time_emb_dim=time_emb_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = ResNetBlock(mid_dim, mid_dim, time_emb_dim=time_emb_dim)
        
        # upsample
        self.up_layers = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out_dims)):
            is_last = ind == num_res - 1
            
            self.up_layers.append(nn.ModuleList([
                ResNetBlock(dim_in + dim_out, dim_out, time_emb_dim=time_emb_dim),
                ResNetBlock(dim_in + dim_out, dim_out, time_emb_dim=time_emb_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1)
            ]))
        
        self.final_res_block = ResNetBlock(init_dim * 2, init_dim, time_emb_dim = time_emb_dim)
        self.final_conv = nn.Conv2d(init_dim, img_channels, 1)
    
    def forward(self, x, time):

        x = self.init_conv(x)

        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.down_layers:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)
        
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.up_layers:

            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)


class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
        var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

class ResNetBlock(nn.Module):
    def __init__(self, dim_in, dim_out, time_emb_dim, groups = 8):
        super().__init__()
        
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        )
        
        self.block1 = Block(dim_in, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim_in, dim_out, 1) if dim_in != dim_out else nn.Identity()
    
    def forward(self, x, time_emb = None):
        # time_emb: (batch, time_emb_dim) output of UNet.time_mlp, passed positionally as block(x, t)
        
        scale_shift = None
        if self.mlp is not None and time_emb is not None:
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
        
        
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x
        

def Upsample(dim_in, dim_out):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim_in, dim_out, 3, padding = 1)
    )

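def Downsample(dim_in, dim_out = None):
    # strided conv halves H and W; keeps the channel count if dim_out is not given
    return nn.Conv2d(dim_in, dim_out if dim_out is not None else dim_in, 4, 2, 1)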

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
    
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g
    
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        v = v / (h * w)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q, k = map(l2norm, (q, k))

        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)
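To see how `dims` and `in_out_dims` in `UNet.__init__` translate into feature-map sizes, here is a small pure-Python sketch (a hypothetical helper, not part of the model) that mirrors that bookkeeping. Every level except the last halves the resolution with the strided convolution in `Downsample`, and the up path mirrors this so that the skip connections line up. This is also why, with the default `num_res=4`, the image size must be divisible by $2^3 = 8$.

```python
def unet_level_plan(img_size: int = 32, init_dim: int = 64, num_res: int = 4):
    """Return (resolution, dim_in, dim_out) per down level and the bottleneck resolution."""
    dims = [init_dim, *[init_dim * 2**res for res in range(num_res)]]
    in_out_dims = list(zip(dims[:-1], dims[1:]))

    plan, size = [], img_size
    for ind, (dim_in, dim_out) in enumerate(in_out_dims):
        plan.append((size, dim_in, dim_out))
        is_last = ind >= num_res - 1
        if not is_last:
            # Conv2d(kernel_size=4, stride=2, padding=1) maps H to (H + 2 - 4) // 2 + 1, i.e. H // 2 for even H
            size = (size + 2 - 4) // 2 + 1
    return plan, size


plan, bottleneck = unet_level_plan()
for size, dim_in, dim_out in plan:
    print(f"{size:>2}x{size:<2}  {dim_in:>3} -> {dim_out} channels")
print(f"bottleneck: {bottleneck}x{bottleneck} with {plan[-1][2]} channels")
```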

Create a model and print the number of parameters:

In [None]:
model = UNet(img_channels=3, init_dim=64, time_emb_dim=32)
print(f"Num params: {sum(p.numel() for p in model.parameters()):,}")

Around 34 million parameters should be small enough to train in a reasonable amount of time.

Finally, we create some helper functions to sample from our model and to plot the samples for different $t$'s.

In [None]:
@torch.no_grad()
def sample_timestep(x, t, model, device):
    """
    Calls the model to predict the noise in the image and returns 
    the denoised image. 
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_val_from_t(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_val_from_t(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_one_over_alphas_t = get_val_from_t(sqrt_one_over_alphas, t, x.shape)
    
    # Call model (current image - noise prediction)
    model_mean = sqrt_one_over_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_val_from_t(posterior_variance, t, x.shape)
    
    if t == 0:
        return model_mean
    
    noise = torch.randn_like(x)
    return model_mean + torch.sqrt(posterior_variance_t) * noise 

@torch.no_grad()
def sample_images(model, device, num_samples = 5, images_per_sample = 10):

    images = torch.empty((num_samples, num_images, 1, 3, IMG_SIZE, IMG_SIZE))

    stepsize = int(T/num_images)

    for sample in range(num_samples):
        # Sample noise
        img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)

        for i in range(0, T)[::-1]:
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = sample_timestep(img, t, model, device=device)

            if i % stepsize == 0:
                col = num_images - i // stepsize - 1
                images[sample, col] = img.detach().cpu()
    
    return images

def plot_sampled_images(images):
    """Plot the output of sample_images: one row per sample, one column per snapshot"""
    num_samples, num_images, *_ = images.shape
    fig, ax = plt.subplots(num_samples, num_images, figsize=(20, 20), constrained_layout=True, squeeze=False)

    for sample in range(num_samples):
        for image in range(num_images):
            ax[sample, image].axis("off")
            # images[sample, image] has shape (1, 3, H, W); index 0 drops the batch dimension
            ax[sample, image].imshow(tensor_to_pil(images[sample, image, 0]))
    
    plt.show()
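One way to check the update in `sample_timestep` without training anything is to feed it a hypothetical oracle that always returns the true noise, i.e. the $\epsilon$ that maps a fixed, known $\mathbf{x}_0$ to the current $\mathbf{x}_t$ under $q(\mathbf{x}_t \vert \mathbf{x}_0)$. With a perfect noise prediction, each step samples from the forward posterior $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)$, and the last step (which adds no noise) lands exactly on $\mathbf{x}_0$. The sketch re-implements the same update stand-alone in float64, with an invented target image:

```python
import torch


def oracle_reverse_process(T_steps: int = 300, seed: int = 0):
    """Run the sample_timestep update with a hypothetical oracle that returns the true noise.

    Returns (final sample, target x_0). Everything is float64 and independent of the notebook's globals.
    """
    gen = torch.Generator().manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, T_steps, dtype=torch.float64)   # same linear schedule
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]])
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    x_0 = torch.rand(3, 8, 8, generator=gen, dtype=torch.float64) * 2 - 1   # made-up target "image" in [-1, 1]
    x = torch.randn(3, 8, 8, generator=gen, dtype=torch.float64)            # start from pure noise

    for t in reversed(range(T_steps)):
        # oracle: the noise that explains the current x given x_0 under q(x_t | x_0)
        eps = (x - alphas_cumprod[t].sqrt() * x_0) / (1.0 - alphas_cumprod[t]).sqrt()
        # same mean as in sample_timestep: 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
        mean = (x - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t == 0:
            x = mean
        else:
            noise = torch.randn(3, 8, 8, generator=gen, dtype=torch.float64)
            x = mean + posterior_variance[t].sqrt() * noise
    return x, x_0


x_final, x_target = oracle_reverse_process()
print(torch.allclose(x_final, x_target))  # True
```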


### Training

Now it is finally time to start training, so we define a simple training loop using our dataloader. We track the loss in TensorBoard and store samples from our model along the way.

In [None]:
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid


device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# gradient scaler, enabled only on the GPU; note that the loss is not computed under
# torch.autocast, so the model itself still trains in float32
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
optimizer = Adam(model.parameters(), lr=0.001)
writer = SummaryWriter()
global_step = 0

EPOCHS = 500 # Try more!

for epoch in range(EPOCHS):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # sample one random timestep per image in the batch
        x_0 = batch[0]
        t = torch.randint(0, T, (x_0.shape[0],), device=device).long()

        loss = get_loss(model, x_0, t, device=device)
        writer.add_scalar("Loss", loss.item(), global_step=global_step)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        global_step += 1

    if epoch % 5 == 0:
        print(f"epoch: {epoch:4} | loss {loss.item():.4f}")
        images = sample_images(model=model, device=device)
        image_grid = make_grid(images.reshape(-1, 3, IMG_SIZE, IMG_SIZE), nrow=images.shape[1], normalize=True)
        writer.add_image('images', image_grid, global_step=epoch)
        #plot_sampled_images(images[0:2])

writer.flush()
writer.close()

### Results and Samples

### Conclusions
