# Diffusion models


<!-- [Click here to open this notebook in Colab](https://colab.research.google.com/github/williamgilpin/cphy/blob/main/talks/svd_decomp.ipynb) -->
Open this notebook in Google Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/williamgilpin/cphy/blob/main/talks/svd_decomp.ipynb)

We start by importing the necessary Python packages.
<!-- *This notebook created by William Gilpin. Consult the [course website](https://www.wgilpin.com/cphy) for all content and [GitHub repository](https://github.com/williamgilpin/cphy) for raw files and runnable online code.* --

In [1]:
import numpy as np

# Wipe all outputs from this notebook
from IPython.display import Image, clear_output, display
clear_output(True)

# Import local plotting functions and in-notebook display functions
import matplotlib.pyplot as plt
%matplotlib inline


In [2]:
# Minimal DDPM on MNIST (pure PyTorch, ~120 lines)
import math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# ----------------------- tiny UNet -----------------------
class Block(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1), nn.GroupNorm(8, c_out), nn.SiLU(),
            nn.Conv2d(c_out, c_out, 3, padding=1), nn.GroupNorm(8, c_out), nn.SiLU(),
        )
    def forward(self, x): return self.net(x)

class TinyUNet(nn.Module):
    """
    Args:
        n_steps (int): Number of diffusion steps for sinusoidal time embedding.
    """
    def __init__(self, n_steps=200):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(64, 128), nn.SiLU(),
            nn.Linear(128, 128),
        )
        def tembed(t, dim=64):
            half = dim//2
            freqs = torch.exp(torch.linspace(0, math.log(10000), half, device=t.device))
            args = t[:,None]/n_steps
            emb = torch.cat([torch.sin(args*freqs), torch.cos(args*freqs)], dim=-1)
            return emb
        self.tembed = tembed

        self.inp  = nn.Conv2d(1, 32, 3, padding=1)
        self.b1   = Block(32, 64)
        self.down = nn.Conv2d(64, 64, 4, 2, 1)
        self.b2   = Block(64, 128)
        self.mid  = Block(128, 128)
        self.up   = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.b3   = Block(64+64, 64)
        self.out  = nn.Conv2d(64, 1, 1)

    def forward(self, x, t):
        """
        Args:
            x (Tensor): Noisy images (N,1,28,28).
            t (Tensor): Integer timesteps (N,).
        Returns:
            Tensor: Predicted noise ε_θ(x_t,t) with same shape as x.
        """
        temb = self.time_mlp(self.tembed(t))[:, :, None, None]
        h0 = F.silu(self.inp(x))
        h1 = self.b1(h0 + temb)
        h2 = self.down(h1)
        h3 = self.b2(h2 + temb)
        m  = self.mid(h3 + temb)
        u  = self.up(m)
        u  = self.b3(torch.cat([u, h1], 1) + temb)
        return self.out(u)

# ----------------------- diffusion utils -----------------------
class Diffusion:
    """
    Args:
        n_steps (int): Number of diffusion steps.
        beta_start (float): Start of linear beta schedule.
        beta_end (float): End of linear beta schedule.
        device (str): 'cuda' or 'cpu'.
    Attributes:
        betas (Tensor): β_t schedule.
        alphas_cumprod (Tensor): ∏_{s<=t} (1-β_s).
        alphas_cumprod_prev (Tensor): ∏_{s<=t-1} (1-β_s).
        sqrt_alphas_cumprod (Tensor)
        sqrt_one_minus_alphas_cumprod (Tensor)
        posterior_variance (Tensor): q(x_{t-1}|x_t,x_0) variance.
    """
    def __init__(self, n_steps=200, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.n_steps = n_steps
        self.betas = torch.linspace(beta_start, beta_end, n_steps, device=device)
        alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), self.alphas_cumprod[:-1]])
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

    def q_sample(self, x0, t, noise=None):
        """Forward diffusion q(x_t|x_0): add noise at step t."""
        if noise is None: noise = torch.randn_like(x0)
        s1 = self.index(self.sqrt_alphas_cumprod, t, x0.shape)
        s2 = self.index(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        return s1 * x0 + s2 * noise

    def p_sample(self, model, x, t):
        """Single reverse step using ε-prediction parameterization."""
        betat = self.index(self.betas, t, x.shape)
        ac_t  = self.index(self.alphas_cumprod, t, x.shape)
        sqrt_one_minus_ac_t = self.index(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        eps = model(x, t)
        x0_hat = (x - sqrt_one_minus_ac_t * eps) / torch.sqrt(ac_t)
        mean = (1/torch.sqrt(1 - betat))*(x - betat/torch.sqrt(1 - ac_t) * eps)
        var = self.index(self.posterior_variance, t, x.shape)
        if (t == 0).all(): return mean
        return mean + torch.sqrt(var) * torch.randn_like(x)

    def sample(self, model, n, shape):
        """Draw n samples by iterating t=T-1..0."""
        model.eval()
        x = torch.randn(n, *shape, device=self.device)
        for t in reversed(range(self.n_steps)):
            tt = torch.full((n,), t, device=self.device, dtype=torch.long)
            with torch.no_grad():
                x = self.p_sample(model, x, tt)
        return x

    @staticmethod
    def index(a, t, x_shape):
        return a.gather(-1, t).reshape(-1, *([1]*(len(x_shape)-1)))

# # ----------------------- training script -----------------------
# def main():
#     """
#     Minimal training loop. Produces samples in ./ddpm_samples.png.

#     Returns:
#         None
#     """
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     T = 200
#     bs = 128
#     epochs = 1  # bump to 5–10 for better quality
#     lr = 2e-4

#     tfm = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x*2-1)])  # [-1,1]
#     ds  = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
#     dl  = DataLoader(ds, batch_size=bs, shuffle=True, num_workers=2, drop_last=True)

#     net = TinyUNet(n_steps=T).to(device)
#     diff = Diffusion(n_steps=T, device=device)
#     opt = torch.optim.AdamW(net.parameters(), lr=lr)

#     for epoch in range(epochs):
#         for x,_ in dl:
#             x = x.to(device)
#             t = torch.randint(0, T, (x.size(0),), device=device).long()
#             noise = torch.randn_like(x)
#             x_t = diff.q_sample(x, t, noise)
#             pred = net(x_t, t)
#             loss = F.mse_loss(pred, noise)
#             opt.zero_grad(); loss.backward(); opt.step()
#         print(f"epoch {epoch+1} | loss={loss.item():.4f}")

#     with torch.no_grad():
#         imgs = diff.sample(net, n=64, shape=(1,28,28))
#         imgs = (imgs.clamp(-1,1)+1)/2
#     return imgs
# #         utils.save_image(imgs, "ddpm_samples.png", nrow=8)
# #     print("Wrote ddpm_samples.png")

# # if __name__ == "__main__":
# #     main()


In [None]:
T = 200 # number of diffusion steps
bs = 128 # batch size 
epochs = 1  # bump to 5–10 for better quality
lr = 2e-4 # learning rate
## If torch is able to find a GPU, use it. Otherwise, use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

ds  = datasets.MNIST(root="./", train=True, download=True, transform=tfm)
dl  = DataLoader(ds, batch_size=bs, shuffle=True, num_workers=2, drop_last=True)

net = TinyUNet(n_steps=T).to(device)
diff = Diffusion(n_steps=T, device=device)
opt = torch.optim.AdamW(net.parameters(), lr=lr)

for epoch in range(epochs):
    for x,_ in dl:
        x = x.to(device)
        t = torch.randint(0, T, (x.size(0),), device=device).long()
        noise = torch.randn_like(x)
        x_t = diff.q_sample(x, t, noise)
        pred = net(x_t, t)
        loss = F.mse_loss(pred, noise)
        opt.zero_grad(); loss.backward(); opt.step()
    print(f"epoch {epoch+1} | loss={loss.item():.4f}")

with torch.no_grad():
    imgs = diff.sample(net, n=64, shape=(1,28,28))
    imgs = (imgs.clamp(-1,1)+1)/2

100%|██████████| 9.91M/9.91M [00:00<00:00, 23.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 722kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.15MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.94MB/s]


PicklingError: Can't pickle <function <lambda> at 0x34fd898a0>: attribute lookup <lambda> on __main__ failed