# Denoising Diffusion Probabilistic Model (DDPM) Implementation

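This notebook first implements a Denoising Diffusion Probabilistic Model (DDPM; Ho et al., 2020, arXiv:2006.11239) and then, in the final section, the Elucidated Diffusion Model (EDM) described in the paper (2206.00364v2.pdf; Karras et al., 2022). It will guide you through the process of building, training, and sampling from both models using PyTorch.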

## 1. Install and Import Required Libraries

We will use PyTorch, torchvision, numpy, matplotlib, and tqdm. If running locally, ensure these packages are installed.

In [None]:
# Install required packages (uncomment if running in a new environment)
# !pip install torch torchvision matplotlib tqdm numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

## 2. Load and Preprocess Dataset

We will use the MNIST dataset for demonstration. The images will be normalized to [-1, 1] as required by most diffusion models.

In [None]:
# Load and preprocess MNIST dataset
# ToTensor yields values in [0, 1]; Normalize((0.5,), (0.5,)) computes
# (x - 0.5) / 0.5 and therefore rescales them to [-1, 1].
# A named transform is used instead of transforms.Lambda(lambda ...) because
# a lambda cannot be pickled, which breaks DataLoader worker processes
# (num_workers > 0) on platforms that start workers with "spawn"
# (e.g. Windows and macOS).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # Scale to [-1, 1]
])

batch_size = 128
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Preview the first eight images of one training batch.
# `examples` is kept as a module-level name because later cells reuse it.
first_batch = next(iter(train_loader))
examples = first_batch[0][:8]
fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img in zip(axes, examples):
    # Each image is (1, H, W); squeeze drops the channel axis for imshow.
    ax.imshow(img.squeeze().numpy(), cmap='gray')
    ax.axis('off')
plt.show()

## 3. Define the Diffusion Model Architecture

We will use a simple U-Net-like architecture suitable for MNIST. For more complex datasets, a deeper U-Net or transformer-based model may be used.

In [None]:
# Simple U-Net-like model for MNIST, conditioned on the diffusion timestep
class SimpleUNet(nn.Module):
    """Small convolutional encoder-decoder that predicts noise for 1-channel images.

    Two 2x max-pool stages are mirrored by two stride-2 transposed
    convolutions, so the output has the same spatial size as the input when
    its height and width are divisible by 4 (e.g. 28x28 MNIST). There are no
    skip connections between encoder and decoder.

    The timestep ``t`` is encoded with a sinusoidal embedding, passed through
    a small MLP and added to the bottleneck input. Without this conditioning
    the network cannot tell which noise level it is denoising.
    """

    # Width of the sinusoidal timestep embedding fed to ``time_mlp``.
    time_dim = 64

    def __init__(self):
        super().__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU()
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()
        )
        self.out = nn.Conv2d(32, 1, 1)
        # Maps the timestep embedding to one bias per bottleneck-input channel
        # (64 = number of channels produced by down2).
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_dim, 128), nn.ReLU(),
            nn.Linear(128, 64)
        )

    def _time_embedding(self, t, device):
        """Return a (len(t), time_dim) sinusoidal embedding of the timesteps."""
        # Accept Python numbers, 0-dim tensors and tensors on another device.
        t = torch.as_tensor(t, device=device).reshape(-1).float()
        half = self.time_dim // 2
        k = torch.arange(half, device=device, dtype=torch.float32)
        freqs = torch.pow(10000.0, -k / half)
        args = t[:, None] * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=1)

    def forward(self, x, t):
        """Predict the noise contained in ``x``.

        Args:
            x: image batch of shape (B, 1, H, W).
            t: timestep(s); a number, a tensor with one element (shared by the
                whole batch) or a tensor with B elements.

        Returns:
            Tensor with the same shape as ``x``.
        """
        d1 = self.down1(x)
        d2 = self.down2(d1)
        # Broadcast the per-sample timestep features over the spatial axes;
        # a single timestep also broadcasts over the batch axis.
        t_feat = self.time_mlp(self._time_embedding(t, x.device))
        d2 = d2 + t_feat.view(-1, 64, 1, 1)
        m = self.middle(d2)
        u2 = self.up2(m)
        u1 = self.up1(u2)
        out = self.out(u1)
        return out

model = SimpleUNet()
print(model)

## 4. Implement the Forward Diffusion Process

The forward process gradually adds Gaussian noise to the data over a fixed number of timesteps.

In [None]:
# Forward diffusion process
T = 200  # Number of diffusion steps
beta_start = 1e-4
beta_end = 0.02
# Linear variance schedule and its derived quantities (all created on the CPU).
betas = torch.linspace(beta_start, beta_end, T)
alphas = 1. - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # cumulative product of alphas

def q_sample(x_start, t, noise=None):
    """Sample a noised version of ``x_start`` at timestep(s) ``t``.

    Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.

    Args:
        x_start: clean batch; the schedule factors are reshaped to
            (-1, 1, 1, 1), so a 4-D (B, C, H, W) batch is expected.
        t: integer timestep indices into the schedule, one per sample
            (an int or tensor with a single index is shared by the batch).
        noise: optional noise tensor shaped like ``x_start``; standard normal
            noise is drawn when omitted.

    Returns:
        Tensor with the same shape and device as ``x_start``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    # The schedule lives on the CPU while the training loop creates ``t`` on
    # the model's device; gather on x_start's device so the indexing and the
    # arithmetic below work on both CPU and GPU.
    target = x_start.device
    alpha_bar_t = alpha_bars.to(target)[torch.as_tensor(t, device=target)]
    sqrt_alpha_bar = alpha_bar_t.sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_bar = (1 - alpha_bar_t).sqrt().view(-1, 1, 1, 1)
    return sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise

# Show how the first example degrades at four points along the schedule
x = examples[:4]
timesteps = torch.tensor([0, T//4, T//2, T-1])
noisy_imgs = []
for step in timesteps:
    # Same timestep for every image in the mini-batch.
    step_batch = torch.full((x.size(0),), step, dtype=torch.long)
    noisy_imgs.append(q_sample(x, step_batch))
fig, axes = plt.subplots(1, 4, figsize=(10, 2))
for ax, step, img in zip(axes, timesteps, noisy_imgs):
    # Only the first image of each noised batch is displayed.
    ax.imshow(img[0].squeeze().numpy(), cmap='gray')
    ax.set_title(f"t={step.item()}")
    ax.axis('off')
plt.show()

## 5. Implement the Reverse (Denoising) Process

The reverse process uses the model to predict the noise at each timestep and denoise the image step by step.

In [None]:
# Reverse (denoising) process for sampling
def p_sample(model, x, t):
    """Run one reverse-diffusion step from timestep ``t`` to ``t - 1``.

    Args:
        model: callable ``model(x, t_batch)`` returning the predicted noise,
            shaped like ``x``.
        x: current noisy batch; its first dimension is the batch size.
        t: integer timestep index into the module-level schedule
            (``p_sample_loop`` passes a Python int).

    Returns:
        The model mean plus ``sqrt(beta_t)``-scaled Gaussian noise for
        ``t > 0``; the mean alone for ``t == 0``.
    """
    beta = betas[t]
    sqrt_one_minus_alpha_bar = (1 - alpha_bars[t]).sqrt()
    sqrt_recip_alpha = (1. / alphas[t]).sqrt()
    # One timestep entry per sample, created on x's device, so that a
    # time-conditioned model running on the GPU receives a usable tensor
    # (a CPU tensor of shape (1,) only works while the model ignores it).
    t_batch = torch.full((x.size(0),), t, device=x.device, dtype=torch.long)
    predicted_noise = model(x, t_batch)
    model_mean = sqrt_recip_alpha * (x - beta / sqrt_one_minus_alpha_bar * predicted_noise)
    if t > 0:
        noise = torch.randn_like(x)
        return model_mean + beta.sqrt() * noise
    # Final step: return the mean without adding fresh noise.
    return model_mean

def p_sample_loop(model, shape):
    """Generate samples by running every reverse step from T-1 down to 0.

    Starts from standard normal noise of the given ``shape``, created on the
    device of the model's first parameter, and applies ``p_sample`` once per
    timestep.
    """
    x = torch.randn(shape, device=next(model.parameters()).device)
    for t in range(T - 1, -1, -1):
        x = p_sample(model, x, t)
    return x

## 6. Train the Diffusion Model

Set up the training loop, loss function, and optimizer. The model is trained to predict the noise added at each timestep.

In [None]:
# Training loop for DDPM: the model learns to predict the noise that
# q_sample mixed into each image.
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-4)
epochs = 1  # For demonstration; increase for better results

for epoch in range(epochs):
    pbar = tqdm(train_loader)
    for x, _ in pbar:
        x = x.to(device)
        # Draw one timestep per image, uniformly from [0, T).
        t = torch.randint(0, T, (x.size(0),), device=device, dtype=torch.long)
        noise = torch.randn_like(x)
        x_noisy = q_sample(x, t, noise)
        # Plain MSE between the predicted and the true noise.
        loss = F.mse_loss(model(x_noisy, t), noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(f"Epoch {epoch+1} Loss: {loss.item():.4f}")

## 7. Generate Samples with the Trained Model

Use the trained model to generate new samples by running the reverse diffusion process starting from random noise.

In [None]:
# Draw eight images from the trained model and plot them
model.eval()
with torch.no_grad():
    generated = p_sample_loop(model, (8, 1, 28, 28))
# Shift/scale as for data in [-1, 1] -> [0, 1]; values are not clamped.
samples = (generated.cpu() + 1) / 2

fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img in zip(axes, samples):
    ax.imshow(img.squeeze().numpy(), cmap='gray')
    ax.axis('off')
plt.show()

## 8. Visualize Generated Samples

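The plotting code at the end of the previous cell displays the generated samples; inspect them to assess the performance of the diffusion model. With only one training epoch (the demonstration setting above), expect rough results and increase `epochs` for better samples.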

# Elucidated Diffusion Model (EDM) Implementation

This section implements the EDM as described in the paper (2206.00364v2.pdf), including its unique noise schedule, loss weighting, and sampling procedure.

In [None]:
# EDM noise schedule, loss weighting, and sampling procedure
# Reference: 2206.00364v2.pdf (Karras et al., 2022)

# EDM uses a continuous noise schedule and loss weighting
# We use sigma_min, sigma_max, and rho as in the paper
sigma_min = 0.002
sigma_max = 80
rho = 7

# EDM noise schedule (sampling)
def edm_sigma_schedule(t):
    """Map ``t`` in [0, 1] to a noise level sigma.

    Interpolates linearly between ``sigma_max ** (1/rho)`` (at ``t = 0``) and
    ``sigma_min ** (1/rho)`` (at ``t = 1``) and raises the result to the power
    ``rho``, so sigma decreases from ``sigma_max`` to ``sigma_min`` as ``t``
    grows. Works for Python floats and tensors alike.
    """
    inv_rho = 1 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    return (hi + t * (lo - hi)) ** rho

# EDM loss weighting
# w(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
def edm_loss_weight(sigma, sigma_data=1.0):
    """Return the per-noise-level loss weight.

    Args:
        sigma: noise level(s); float or tensor, must be non-zero.
        sigma_data: assumed standard deviation of the data. The default of
            1.0 reduces the formula to ``(sigma^2 + 1) / sigma^2``.

    Returns:
        Weight with the same type/shape as ``sigma``.
    """
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2

# Plot sigma(t) over 100 evenly spaced values of t in [0, 1]
ts = torch.linspace(0, 1, 100)
sigmas = edm_sigma_schedule(ts)
schedule_ax = plt.gca()
schedule_ax.plot(ts.numpy(), sigmas.numpy())
schedule_ax.set_xlabel('t')
schedule_ax.set_ylabel('sigma(t)')
schedule_ax.set_title('EDM Noise Schedule')
plt.show()

In [None]:
# EDM training loop (for MNIST, using the same U-Net model)
# This uses the EDM noise schedule and loss weighting

edm_epochs = 1  # For demonstration; increase for better results
model_edm = SimpleUNet().to(device)
optimizer_edm = optim.Adam(model_edm.parameters(), lr=2e-4)

for epoch in range(edm_epochs):
    pbar = tqdm(train_loader)
    for x, _ in pbar:
        x = x.to(device)
        t = torch.rand(x.size(0), device=device)  # Uniform in [0, 1]
        sigma = edm_sigma_schedule(t)
        sigma_b = sigma.view(-1, 1, 1, 1)  # broadcastable over (C, H, W)
        noise = torch.randn_like(x)
        x_noisy = x + sigma_b * noise
        # Scale the network input by c_in = 1 / sqrt(sigma^2 + 1), exactly as
        # edm_ancestral_sampling does at inference time. Feeding the unscaled
        # x_noisy (magnitude up to ~sigma_max) here would train the model on
        # inputs it never sees during sampling.
        c_in = 1 / (sigma_b ** 2 + 1).sqrt()
        noise_pred = model_edm(c_in * x_noisy, (sigma * 1000).long())  # Pass scaled sigma as timestep
        # NOTE(review): this weight grows like 1 / sigma^2, i.e. roughly
        # 250000 at sigma_min = 0.002, and is applied to a noise-prediction
        # error. Confirm this is the intended weighting.
        weight = edm_loss_weight(sigma).view(-1, 1, 1, 1)
        loss = ((noise_pred - noise) ** 2 * weight).mean()
        optimizer_edm.zero_grad()
        loss.backward()
        optimizer_edm.step()
        pbar.set_description(f"EDM Epoch {epoch+1} Loss: {loss.item():.4f}")

In [None]:
# EDM sampling procedure (ancestral sampling)
def edm_ancestral_sampling(model, num_steps=18, batch_size=8, img_shape=(1, 28, 28)):
    """Sample images with an Euler ancestral sampler over the EDM sigma schedule.

    Args:
        model: noise-prediction network called as ``model(c_in * x, c_noise)``
            where ``c_in = 1 / sqrt(sigma^2 + 1)`` and ``c_noise`` is
            ``sigma * 1000`` truncated to an integer, one entry per sample.
        num_steps: number of model evaluations (must be >= 1).
        batch_size: number of images to generate.
        img_shape: shape of a single image, e.g. ``(1, 28, 28)``.

    Returns:
        Tensor of shape ``(batch_size,) + img_shape`` on the model's device.
    """
    device = next(model.parameters()).device
    # edm_sigma_schedule maps t=0 -> sigma_max and t=1 -> sigma_min, so the
    # grid must run from 0 to 1 to obtain *decreasing* noise levels. A final
    # sigma of 0 is appended so the last step lands on a noise-free image.
    sigmas = edm_sigma_schedule(torch.linspace(0, 1, num_steps, device=device))
    sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
    x = torch.randn((batch_size,) + img_shape, device=device) * sigma_max
    for i in range(num_steps):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        c_noise = (sigma * 1000).long()  # Scale for timestep embedding
        noise_pred = model(c_in * x, c_noise.expand(batch_size))
        # With a noise-predicting model the denoised estimate is
        # x - sigma * noise_pred, so the ODE direction
        # d = (x - denoised) / sigma is the predicted noise itself.
        d = noise_pred
        # Ancestral step: move deterministically to sigma_down < sigma_next,
        # then add fresh noise of scale sigma_up such that
        # sigma_down^2 + sigma_up^2 == sigma_next^2.
        sigma_up = sigma_next * (1 - (sigma_next / sigma) ** 2).sqrt()
        sigma_down = sigma_next ** 2 / sigma
        x = x + (sigma_down - sigma) * d
        if sigma_next > 0:
            x = x + torch.randn_like(x) * sigma_up
    return x

In [None]:
# Sample eight images from the trained EDM model and plot them in one row
model_edm.eval()
with torch.no_grad():
    edm_raw = edm_ancestral_sampling(model_edm, batch_size=8, img_shape=(1, 28, 28))
# Shift/scale as for data in [-1, 1] -> [0, 1]; values are not clamped.
edm_samples = (edm_raw.cpu() + 1) / 2

fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img in zip(axes, edm_samples):
    ax.imshow(img.squeeze().numpy(), cmap='gray')
    ax.axis('off')
plt.suptitle('EDM Generated Samples')
plt.show()