# Diffusion models


<!-- [Click here to open this notebook in Colab](https://colab.research.google.com/github/williamgilpin/cphy/blob/main/talks/svd_decomp.ipynb) -->
Open this notebook in Google Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/williamgilpin/cphy/blob/main/talks/svd_decomp.ipynb)

We start by importing the necessary Python packages.
<!-- *This notebook created by William Gilpin. Consult the [course website](https://www.wgilpin.com/cphy) for all content and [GitHub repository](https://github.com/williamgilpin/cphy) for raw files and runnable online code.* --

In [2]:
import numpy as np

# Wipe all outputs from this notebook
from IPython.display import Image, clear_output, display
clear_output(True)

# Import local plotting functions and in-notebook display functions
import matplotlib.pyplot as plt
%matplotlib inline


In [3]:
# minimal_ddpm_mnist.py
import math, torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms, utils

# ----------------------- diffusion utilities -----------------------


def sinusoidal_embedding(t, dim=64):
    """
    Args:
        t (torch.LongTensor): Timesteps (B,) in [0, T-1].
        dim (int): Embedding dimension (even).

    Returns:
        (torch.Tensor): Positional embeddings (B, dim).
    """
    half = dim // 2
    freqs = torch.exp(torch.linspace(math.log(1.0), math.log(10000.0), half, device=t.device))
    ang = t[:, None] * freqs[None]
    return torch.cat([torch.sin(ang), torch.cos(ang)], dim=1)

def q_sample(x0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
    """
    Args:
        x0 (torch.Tensor): Clean images in [-1,1], shape (B,1,H,W).
        t (torch.LongTensor): Timesteps (B,).
        sqrt_alphas_cumprod (torch.Tensor): Precomputed coefficients (T,).
        sqrt_one_minus_alphas_cumprod (torch.Tensor): Precomputed coefficients (T,).
        noise (torch.Tensor, optional): Noise to add.

    Returns:
        (torch.Tensor): Noised x_t with same shape as x0.
    """
    if noise is None: noise = torch.randn_like(x0)
    sa = sqrt_alphas_cumprod[t].view(-1,1,1,1)
    som = sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
    return sa * x0 + som * noise

# ----------------------- tiny U-Net (pedagogical) -----------------------

class ResidualBlock(nn.Module):
    def __init__(self, ch, tdim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, ch)
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.time = nn.Sequential(nn.SiLU(), nn.Linear(tdim, ch))
    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.time(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h

class TinyUNet(nn.Module):
    """
    A very small U-Net for 32x32 grayscale.

    Args:
        ch (int): Base channel width.
        tdim (int): Timestep embedding dim.
    """
    def __init__(self, ch=64, tdim=64):
        super().__init__()
        self.inp = nn.Conv2d(1, ch, 3, padding=1)
        self.rb1 = ResidualBlock(ch, tdim)
        self.down = nn.Conv2d(ch, ch*2, 4, stride=2, padding=1)
        self.rb2 = ResidualBlock(ch*2, tdim)
        self.mid = ResidualBlock(ch*2, tdim)
        self.up = nn.ConvTranspose2d(ch*2, ch, 4, stride=2, padding=1)
        self.rb3 = ResidualBlock(ch, tdim)
        self.out = nn.Conv2d(ch, 1, 3, padding=1)
        self.tproj = nn.Sequential(nn.Linear(tdim, tdim*4), nn.SiLU(), nn.Linear(tdim*4, tdim))
    def forward(self, x, t):
        """
        Args:
            x (torch.Tensor): Noisy images (B,1,32,32).
            t (torch.LongTensor): Timesteps (B,).

        Returns:
            (torch.Tensor): Predicted noise (B,1,32,32).
        """
        t_emb = self.tproj(sinusoidal_embedding(t, self.tproj[0].in_features))
        h1 = self.inp(x); h1 = self.rb1(h1, t_emb)
        h2 = self.down(h1); h2 = self.rb2(h2, t_emb)
        h  = self.mid(h2, t_emb)
        h  = self.up(h); h = h + h1
        h  = self.rb3(h, t_emb)
        return self.out(h)

# ----------------------- loss and sampling -----------------------

def diffusion_loss(model, x0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    """
    Args:
        model (nn.Module): Noise predictor ε_θ(x_t,t).
        x0 (torch.Tensor): Clean images in [-1,1], (B,1,H,W).
        t (torch.LongTensor): Timesteps (B,).
        sqrt_alphas_cumprod (torch.Tensor): Precomputed (T,).
        sqrt_one_minus_alphas_cumprod (torch.Tensor): Precomputed (T,).

    Returns:
        (torch.Tensor): MSE loss between true and predicted noise.
    """
    noise = torch.randn_like(x0)
    xt = q_sample(x0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise)
    pred = model(xt, t)
    return F.mse_loss(pred, noise)

@torch.no_grad()
def sample_ddpm(model, shape, betas, alphas, alphas_cumprod, alphas_cumprod_prev, device):
    """
    Args:
        model (nn.Module): Trained noise predictor.
        shape (tuple): (B,1,32,32).
        betas, alphas, alphas_cumprod, alphas_cumprod_prev (torch.Tensor): Precomputed schedules.
        device (torch.device): Device.

    Returns:
        (torch.Tensor): Samples in [-1,1], shape (B,1,32,32).
    """
    T = betas.shape[0]
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas).to(device)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod).to(device)
    posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

    x = torch.randn(shape, device=device)
    for t in reversed(range(T)):
        tt = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, tt)
        x0_pred = (x - eps * sqrt_one_minus_alphas_cumprod[t]) / torch.sqrt(alphas_cumprod[t])
        coef1 = betas[t] * torch.sqrt(alphas_cumprod_prev[t]) / (1 - alphas_cumprod[t])
        coef2 = (torch.sqrt(alphas[t]) * (1 - alphas_cumprod_prev[t])) / (1 - alphas_cumprod[t])
        mean = coef1 * x0_pred + coef2 * x
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(posterior_variance[t]) * noise
        else:
            x = mean
    return x.clamp(-1, 1)




In [4]:
# ----------------------- training script -----------------------

# def main():
#     """
#     Minimal training on MNIST (resized to 32x32) and sampling a grid.

#     Saves:
#         samples.png: A 8x8 grid of generated digits.
#         ckpt.pt: Model checkpoint after training.
#     """
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

# data
tfm = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
ds = datasets.MNIST(root='./data', train=True, download=True, transform=tfm)
dl = torch.utils.data.DataLoader(ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)

# schedule (shorter T for pedagogy)
T = 300
# betas = linear_beta_schedule(T).to(device)
beta_start=1e-4
beta_end=0.02
betas = torch.linspace(beta_start, beta_end, T).to(device) # linear schedule
alphas = (1.0 - betas)
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1, device=device), alphas_cumprod[:-1]])

sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

# model & optim
model = TinyUNet(ch=64, tdim=64).to(device)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

# train (few epochs; adjust as desired)
model.train()
for epoch in range(5):
    for x,_ in dl:
        x = x.to(device)  # already in [-1,1] via Normalize(0.5,0.5)
        t = torch.randint(0, T, (x.size(0),), device=device, dtype=torch.long)
        loss = diffusion_loss(model, x, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
        opt.zero_grad(); loss.backward(); opt.step()
    print(f"epoch {epoch+1}: loss={loss.item():.4f}")

torch.save(model.state_dict(), "ckpt.pt")

# sample and save
model.eval()
samples = sample_ddpm(model, (64,1,32,32), betas, alphas, alphas_cumprod, alphas_cumprod_prev, device)
# utils.save_image((samples+1)/2, "samples.png", nrow=8)  # back to [0,1]
# print("wrote samples.png")

# if __name__ == "__main__":
#     main()

100%|██████████| 9.91M/9.91M [00:00<00:00, 31.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 820kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 7.08MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.34MB/s]


epoch 1: loss=0.0518
epoch 2: loss=0.0395
epoch 3: loss=0.0501
epoch 4: loss=0.0406


libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1279d3420>
Traceback (most recent call last):
  File "/Users/william/mamba/envs/gene/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1662, in __del__
    def __del__(self):
  File "/Users/william/mamba/envs/gene/lib/python3.13/site-packages/torch/utils/data/_utils/signal_handling.py", line 73, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 66376) is killed by signal: Abort trap: 6. 


KeyboardInterrupt: 

In [2]:
# Minimal DDPM on MNIST (pure PyTorch, ~120 lines)
import math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# ----------------------- tiny UNet -----------------------
class Block(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1), nn.GroupNorm(8, c_out), nn.SiLU(),
            nn.Conv2d(c_out, c_out, 3, padding=1), nn.GroupNorm(8, c_out), nn.SiLU(),
        )
    def forward(self, x): return self.net(x)

class TinyUNet(nn.Module):
    """
    Args:
        n_steps (int): Number of diffusion steps for sinusoidal time embedding.
    """
    def __init__(self, n_steps=200):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(64, 128), nn.SiLU(),
            nn.Linear(128, 128),
        )
        def tembed(t, dim=64):
            half = dim//2
            freqs = torch.exp(torch.linspace(0, math.log(10000), half, device=t.device))
            args = t[:,None]/n_steps
            emb = torch.cat([torch.sin(args*freqs), torch.cos(args*freqs)], dim=-1)
            return emb
        self.tembed = tembed

        self.inp  = nn.Conv2d(1, 32, 3, padding=1)
        self.b1   = Block(32, 64)
        self.down = nn.Conv2d(64, 64, 4, 2, 1)
        self.b2   = Block(64, 128)
        self.mid  = Block(128, 128)
        self.up   = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.b3   = Block(64+64, 64)
        self.out  = nn.Conv2d(64, 1, 1)

    def forward(self, x, t):
        """
        Args:
            x (Tensor): Noisy images (N,1,28,28).
            t (Tensor): Integer timesteps (N,).
        Returns:
            Tensor: Predicted noise ε_θ(x_t,t) with same shape as x.
        """
        temb = self.time_mlp(self.tembed(t))[:, :, None, None]
        h0 = F.silu(self.inp(x))
        h1 = self.b1(h0 + temb)
        h2 = self.down(h1)
        h3 = self.b2(h2 + temb)
        m  = self.mid(h3 + temb)
        u  = self.up(m)
        u  = self.b3(torch.cat([u, h1], 1) + temb)
        return self.out(u)

# ----------------------- diffusion utils -----------------------
class Diffusion:
    """
    Args:
        n_steps (int): Number of diffusion steps.
        beta_start (float): Start of linear beta schedule.
        beta_end (float): End of linear beta schedule.
        device (str): 'cuda' or 'cpu'.
    Attributes:
        betas (Tensor): β_t schedule.
        alphas_cumprod (Tensor): ∏_{s<=t} (1-β_s).
        alphas_cumprod_prev (Tensor): ∏_{s<=t-1} (1-β_s).
        sqrt_alphas_cumprod (Tensor)
        sqrt_one_minus_alphas_cumprod (Tensor)
        posterior_variance (Tensor): q(x_{t-1}|x_t,x_0) variance.
    """
    def __init__(self, n_steps=200, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.n_steps = n_steps
        self.betas = torch.linspace(beta_start, beta_end, n_steps, device=device)
        alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), self.alphas_cumprod[:-1]])
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

    def q_sample(self, x0, t, noise=None):
        """Forward diffusion q(x_t|x_0): add noise at step t."""
        if noise is None: noise = torch.randn_like(x0)
        s1 = self.index(self.sqrt_alphas_cumprod, t, x0.shape)
        s2 = self.index(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        return s1 * x0 + s2 * noise

    def p_sample(self, model, x, t):
        """Single reverse step using ε-prediction parameterization."""
        betat = self.index(self.betas, t, x.shape)
        ac_t  = self.index(self.alphas_cumprod, t, x.shape)
        sqrt_one_minus_ac_t = self.index(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        eps = model(x, t)
        x0_hat = (x - sqrt_one_minus_ac_t * eps) / torch.sqrt(ac_t)
        mean = (1/torch.sqrt(1 - betat))*(x - betat/torch.sqrt(1 - ac_t) * eps)
        var = self.index(self.posterior_variance, t, x.shape)
        if (t == 0).all(): return mean
        return mean + torch.sqrt(var) * torch.randn_like(x)

    def sample(self, model, n, shape):
        """Draw n samples by iterating t=T-1..0."""
        model.eval()
        x = torch.randn(n, *shape, device=self.device)
        for t in reversed(range(self.n_steps)):
            tt = torch.full((n,), t, device=self.device, dtype=torch.long)
            with torch.no_grad():
                x = self.p_sample(model, x, tt)
        return x

    @staticmethod
    def index(a, t, x_shape):
        return a.gather(-1, t).reshape(-1, *([1]*(len(x_shape)-1)))

# # ----------------------- training script -----------------------
# def main():
#     """
#     Minimal training loop. Produces samples in ./ddpm_samples.png.

#     Returns:
#         None
#     """
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     T = 200
#     bs = 128
#     epochs = 1  # bump to 5–10 for better quality
#     lr = 2e-4

#     tfm = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x*2-1)])  # [-1,1]
#     ds  = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
#     dl  = DataLoader(ds, batch_size=bs, shuffle=True, num_workers=2, drop_last=True)

#     net = TinyUNet(n_steps=T).to(device)
#     diff = Diffusion(n_steps=T, device=device)
#     opt = torch.optim.AdamW(net.parameters(), lr=lr)

#     for epoch in range(epochs):
#         for x,_ in dl:
#             x = x.to(device)
#             t = torch.randint(0, T, (x.size(0),), device=device).long()
#             noise = torch.randn_like(x)
#             x_t = diff.q_sample(x, t, noise)
#             pred = net(x_t, t)
#             loss = F.mse_loss(pred, noise)
#             opt.zero_grad(); loss.backward(); opt.step()
#         print(f"epoch {epoch+1} | loss={loss.item():.4f}")

#     with torch.no_grad():
#         imgs = diff.sample(net, n=64, shape=(1,28,28))
#         imgs = (imgs.clamp(-1,1)+1)/2
#     return imgs
# #         utils.save_image(imgs, "ddpm_samples.png", nrow=8)
# #     print("Wrote ddpm_samples.png")

# # if __name__ == "__main__":
# #     main()


In [None]:
T = 200 # number of diffusion steps
bs = 128 # batch size 
epochs = 1  # bump to 5–10 for better quality
lr = 2e-4 # learning rate
## If torch is able to find a GPU, use it. Otherwise, use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

ds  = datasets.MNIST(root="./", train=True, download=True, transform=tfm)
dl  = DataLoader(ds, batch_size=bs, shuffle=True, num_workers=2, drop_last=True)

net = TinyUNet(n_steps=T).to(device)
diff = Diffusion(n_steps=T, device=device)
opt = torch.optim.AdamW(net.parameters(), lr=lr)

for epoch in range(epochs):
    for x,_ in dl:
        x = x.to(device)
        t = torch.randint(0, T, (x.size(0),), device=device).long()
        noise = torch.randn_like(x)
        x_t = diff.q_sample(x, t, noise)
        pred = net(x_t, t)
        loss = F.mse_loss(pred, noise)
        opt.zero_grad(); loss.backward(); opt.step()
    print(f"epoch {epoch+1} | loss={loss.item():.4f}")

with torch.no_grad():
    imgs = diff.sample(net, n=64, shape=(1,28,28))
    imgs = (imgs.clamp(-1,1)+1)/2

100%|██████████| 9.91M/9.91M [00:00<00:00, 23.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 722kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.15MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.94MB/s]


PicklingError: Can't pickle <function <lambda> at 0x34fd898a0>: attribute lookup <lambda> on __main__ failed