# Exercise: Training a Simple DDPM on MNIST

**Objective:** Implement and train a Denoising Diffusion Probabilistic Model (DDPM) on MNIST with smooth MSE convergence.

**Key Learning:**
- Understand noise schedules and forward diffusion
- Implement time embeddings for diffusion conditioning
- Build a U-Net for noise prediction
- Train with MSE loss (demonstrating stable convergence vs GANs)
- Analyze convergence curves



## Section 1: Import Required Libraries

Import necessary libraries for building and training the diffusion model.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# Device selection
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")


## Section 2: Load and Prepare MNIST Dataset

Load MNIST, normalize it to [-1, 1], create the training data loader, and visualize a few samples.

In [None]:
# Data preprocessing: Normalize to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Load MNIST
train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)

print(f"Train dataset: {len(train_dataset)} samples")

# Visualize a few samples
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i in range(4):
    img, label = train_dataset[i]
    axes[i].imshow(img.squeeze(), cmap="gray")
    axes[i].set_title(f"Label: {label}")
    axes[i].axis("off")
plt.tight_layout()
plt.show()
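`ToTensor()` scales 8-bit pixels to [0, 1]. `Normalize((0.5,), (0.5,))` then computes (x − 0.5) / 0.5, which maps 0 → −1 and 1 → 1. The NumPy sketch below reproduces that arithmetic and the inverse mapping needed to display a [-1, 1] image in [0, 1]. The helper names are illustrative, not part of torchvision.

```python
import numpy as np

def to_model_range(pixels_uint8):
    """Mimic ToTensor() + Normalize((0.5,), (0.5,)) on raw 0-255 pixel values."""
    x = pixels_uint8.astype(np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    return (x - 0.5) / 0.5                       # Normalize: [0, 1] -> [-1, 1]

def to_display_range(x):
    """Invert the normalization: [-1, 1] -> [0, 1] (clipped for plotting)."""
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)

pixels = np.array([0, 128, 255], dtype=np.uint8)
x = to_model_range(pixels)
print(x)                    # approximately [-1.0, 0.0039, 1.0]
print(to_display_range(x))  # approximately [0.0, 0.502, 1.0]
```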


## Section 3: TODO 1 - Define Noise Scheduler

**TODO 1:** Complete the `NoiseScheduler.__init__` method.

Implement a linear β schedule, then compute α_t = 1 − β_t, the cumulative products ᾱ_t, and the square roots √ᾱ_t and √(1 − ᾱ_t) used by forward diffusion.

In [None]:
class NoiseScheduler:
    """Noise scheduler for forward diffusion process."""

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """
        TODO 1: Initialize noise scheduler.

        Requirements:
        1. Create linear beta schedule from beta_start to beta_end
        2. Compute alphas (1 - beta)
        3. Compute cumulative products (alphas_cumprod)
        4. Pre-compute square roots for efficiency

        Mathematical formulas:
        - β_t: Linear schedule from beta_start to beta_end
        - α_t = 1 - β_t
        - ᾱ_t = ∏_{i=0}^t α_i (cumulative product)
        - √ᾱ_t and √(1-ᾱ_t) are used in forward diffusion formula
        """
        self.num_timesteps = num_timesteps

        # TODO 1: Replace pass with implementation
        pass

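    def get_coefficients(self, timesteps):
        """Get pre-computed coefficients for a batch of timesteps.

        The schedule tensors are moved to the device of `timesteps` before
        indexing, so this also works when the scheduler was built on the CPU
        and training runs on CUDA or MPS.
        """
        sqrt_alphas = self.sqrt_alphas_cumprod.to(timesteps.device)[timesteps]
        sqrt_one_minus_alphas = self.sqrt_one_minus_alphas_cumprod.to(
            timesteps.device
        )[timesteps]

        # Reshape (batch,) -> (batch, 1, 1, 1) to broadcast over (C, H, W)
        if sqrt_alphas.dim() == 1:
            sqrt_alphas = sqrt_alphas[:, None, None, None]
            sqrt_one_minus_alphas = sqrt_one_minus_alphas[:, None, None, None]

        return sqrt_alphas, sqrt_one_minus_alphas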


# Test the scheduler
try:
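    scheduler = NoiseScheduler()
    # Raises AttributeError (caught below) until TODO 1 is implemented
    assert scheduler.sqrt_alphas_cumprod.shape == (scheduler.num_timesteps,)
    print("✓ NoiseScheduler initialized")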
    print(f"  Timesteps: {scheduler.num_timesteps}")
except Exception as e:
    print(f"✗ Error: {e}")
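For reference, the quantities TODO 1 asks for can be computed directly in NumPy. The sketch below is an illustrative, framework-free version (the function name `linear_schedule` is hypothetical); inside `NoiseScheduler` you would build the same arrays with `torch.linspace` and `torch.cumprod`. It also shows why T = 1000 with these β endpoints suffices: ᾱ_T ends up around 4 × 10⁻⁵, so x_T is almost pure noise.

```python
import numpy as np

def linear_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule and the derived quantities used by forward diffusion."""
    betas = np.linspace(beta_start, beta_end, num_timesteps)  # beta_t
    alphas = 1.0 - betas                                      # alpha_t = 1 - beta_t
    alphas_cumprod = np.cumprod(alphas)                       # alpha_bar_t
    return {
        "betas": betas,
        "alphas_cumprod": alphas_cumprod,
        "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
    }

sched = linear_schedule()
for t in (0, 250, 500, 999):
    print(t, f"alpha_bar={sched['alphas_cumprod'][t]:.5f}",
          f"signal={sched['sqrt_alphas_cumprod'][t]:.3f}")
```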


## Section 4: TODO 2 - Forward Diffusion Process

**TODO 2:** Implement the `add_noise` function for forward diffusion.

Use the formula: $x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon$

In [None]:
def add_noise(x_0, timestep, scheduler, noise=None):
    """
    TODO 2: Add noise to images (forward diffusion).

    Formula: x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε

    Args:
        x_0: Original images (batch, 1, 28, 28)
        timestep: Timesteps (batch,)
        scheduler: NoiseScheduler instance
        noise: Optional pre-generated noise

    Returns:
        x_t: Noisy images (batch, 1, 28, 28)
        noise: The noise added (for training)

    Implementation steps:
    1. If noise is None, create random noise using torch.randn_like()
    2. Get coefficients from scheduler using get_coefficients()
    3. Apply forward diffusion formula
    4. Return noisy image and noise
    """

    # TODO 2: Replace pass with implementation
    pass


# Test forward diffusion
try:
    x_0 = torch.randn(4, 1, 28, 28)
    t = torch.tensor([0, 250, 500, 999])
    x_t, noise = add_noise(x_0, t, scheduler)
    print("✓ Forward diffusion working")
    print(f"  x_0 shape: {x_0.shape}, range: [{x_0.min():.2f}, {x_0.max():.2f}]")
    print(f"  x_t shape: {x_t.shape}, range: [{x_t.min():.2f}, {x_t.max():.2f}]")
except Exception as e:
    print(f"✗ Error: {e}")
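The forward formula has a useful property. If x₀ has unit variance and ε is independent standard normal noise, then x_t = √ᾱ_t x₀ + √(1 − ᾱ_t) ε also has unit variance, because ᾱ_t + (1 − ᾱ_t) = 1. Meanwhile the correlation between x_t and x₀ is √ᾱ_t. The NumPy sketch below checks both properties on standard-normal stand-in images (as in the test cell above) and shows the `(batch,) → (batch, 1, 1, 1)` reshape that `get_coefficients` performs. The helper `add_noise_np` is hypothetical and only mirrors what TODO 2 describes.

```python
import numpy as np

def add_noise_np(x_0, t, sqrt_ab, sqrt_1m_ab, rng):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, per sample."""
    noise = rng.standard_normal(x_0.shape)
    a = sqrt_ab[t][:, None, None, None]     # (batch,) -> (batch, 1, 1, 1)
    b = sqrt_1m_ab[t][:, None, None, None]
    return a * x_0 + b * noise, noise

# Linear schedule with the same defaults as NoiseScheduler
alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
sqrt_ab, sqrt_1m_ab = np.sqrt(alphas_cumprod), np.sqrt(1.0 - alphas_cumprod)

rng = np.random.default_rng(0)
x_0 = rng.standard_normal((4, 1, 28, 28))  # unit-variance stand-in "images"
t = np.array([0, 250, 500, 999])
x_t, noise = add_noise_np(x_0, t, sqrt_ab, sqrt_1m_ab, rng)

corrs = []
for i, ti in enumerate(t):
    corr = np.corrcoef(x_0[i].ravel(), x_t[i].ravel())[0, 1]
    corrs.append(corr)
    # Empirical correlation should track sqrt(alpha_bar_t); variance stays near 1
    print(f"t={int(ti):3d}  corr={corr:.3f}  "
          f"sqrt(alpha_bar)={sqrt_ab[ti]:.3f}  var={x_t[i].var():.2f}")
```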


## Section 5: TODO 3-5 - Build U-Net Architecture

- **TODO 3:** Implement `TimeEmbedding` with sinusoidal encoding
- **TODO 4:** Implement `ResidualBlock` with FiLM time conditioning
- **TODO 5:** Implement the `SimpleUNet` forward pass
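Both parts of the time conditioning reduce to array arithmetic. The NumPy sketch below computes the sinusoidal encoding with half_dim = embedding_dim // 2 frequencies exp(−ln(10000)·k/half_dim) for k = 0, 1, ..., half_dim − 1, concatenating the sines and cosines. It then applies a scale-only FiLM step: a per-channel scale of shape (batch, channels) is reshaped to (batch, channels, 1, 1) and broadcast over the feature map. A random matrix stands in for the learned `time_mlp`, and the helper names are illustrative, not the exercise's API.

```python
import math
import numpy as np

def sinusoidal_embedding(timesteps, embedding_dim=128):
    """(batch,) integer timesteps -> (batch, embedding_dim) sin/cos features."""
    half_dim = embedding_dim // 2
    k = np.arange(half_dim)
    freqs = np.exp(-math.log(10000.0) * k / half_dim)   # from 1 down toward 1e-4
    args = timesteps[:, None].astype(np.float64) * freqs[None, :]  # (batch, half_dim)
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)

def film_scale(h, time_emb, weight):
    """Scale-only FiLM: one scale per channel, broadcast over H and W."""
    scale = time_emb @ weight              # (batch, channels); stand-in for time_mlp
    return h * scale[:, :, None, None]     # (batch, C, 1, 1) * (batch, C, H, W)

rng = np.random.default_rng(0)
t = np.array([0, 250, 500, 999])
emb = sinusoidal_embedding(t)                            # (4, 128)
weight = rng.standard_normal((128, 64)) / np.sqrt(128)   # hypothetical time_mlp stand-in
h = rng.standard_normal((4, 64, 28, 28))                 # feature map after conv1
out = film_scale(h, emb, weight)
print(emb.shape, out.shape)  # (4, 128) (4, 64, 28, 28)
```

At t = 0 every sine is 0 and every cosine is 1. As t grows, the low-index (high-frequency) dimensions rotate fastest, which gives nearby timesteps distinguishable embeddings.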

In [None]:
class TimeEmbedding(nn.Module):
    """TODO 3: Encode timestep into sinusoidal embedding."""

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timestep):
        """
        TODO 3: Implement sinusoidal time embedding.

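        Frequency schedule (half_dim = embedding_dim // 2):
        freqs[k] = exp(-ln(10000) * k / half_dim) for k = 0, ..., half_dim - 1

        Then embedding = concat([sin(t * freqs), cos(t * freqs)], dim=-1),
        which has embedding_dim entries per timestep.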

        Args:
            timestep: LongTensor of shape (batch,)

        Returns:
            embedding: FloatTensor of shape (batch, embedding_dim)
        """
        device = timestep.device
        half_dim = self.embedding_dim // 2

        # TODO 3: Replace pass with implementation
        pass


class ResidualBlock(nn.Module):
    """TODO 4: Residual block with FiLM time conditioning."""

    def __init__(self, in_channels, out_channels, embedding_dim=128):
        super().__init__()

        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # FiLM (Feature-wise Linear Modulation), scale-only variant: maps the
        # time embedding to one multiplicative scale per output channel
        self.time_mlp = nn.Sequential(
            nn.Linear(embedding_dim, out_channels * 2),
            nn.SiLU(),
            nn.Linear(out_channels * 2, out_channels),
        )

        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, time_embedding):
        """
        TODO 4: Implement FiLM-modulated residual block.

        This is a scale-only form of FiLM: h = h * time_scale
        (full FiLM also adds a per-channel shift: h = gamma * h + beta).

        Implementation:
        1. Apply norm1 → SiLU → conv1
        2. Compute time_scale via time_mlp
        3. Reshape time_scale to (batch, channels, 1, 1)
        4. Apply FiLM: h = h * time_scale
        5. Apply norm2 → SiLU → conv2
        6. Add skip connection: h + skip(x)
        """
        # TODO 4: Replace pass with implementation
        pass


class SimpleUNet(nn.Module):
    """TODO 5: Simple U-Net for noise prediction."""

    def __init__(self, image_channels=1, base_channels=64):
        super().__init__()
        self.time_embedding = TimeEmbedding(embedding_dim=128)

        self.init_conv = nn.Conv2d(
            image_channels, base_channels, kernel_size=3, padding=1
        )

        # Encoder (downsampling)
        self.down_res1 = ResidualBlock(base_channels, base_channels, embedding_dim=128)
        self.down_conv1 = nn.Conv2d(
            base_channels, base_channels, kernel_size=4, stride=2, padding=1
        )

        self.down_res2 = ResidualBlock(
            base_channels, base_channels * 2, embedding_dim=128
        )
        self.down_conv2 = nn.Conv2d(
            base_channels * 2, base_channels * 2, kernel_size=4, stride=2, padding=1
        )

        # Bottleneck
        self.middle_res = ResidualBlock(
            base_channels * 2, base_channels * 2, embedding_dim=128
        )

        # Decoder (upsampling)
        self.up_conv2 = nn.ConvTranspose2d(
            base_channels * 2, base_channels, kernel_size=4, stride=2, padding=1
        )
        self.up_res2 = ResidualBlock(base_channels, base_channels, embedding_dim=128)

        self.up_conv1 = nn.ConvTranspose2d(
            base_channels, base_channels, kernel_size=4, stride=2, padding=1
        )
        self.up_res1 = ResidualBlock(base_channels, base_channels, embedding_dim=128)

        # Final
        self.final_norm = nn.GroupNorm(32, base_channels)
        self.final_conv = nn.Conv2d(
            base_channels, image_channels, kernel_size=3, padding=1
        )

    def forward(self, x, timestep):
        """
        TODO 5: Implement U-Net forward pass.

        Architecture:
        Input (28×28) → init_conv (64 channels)
        → Encoder: down_res1 → down_conv1 (28→14)
                 → down_res2 → down_conv2 (14→7)
        → Middle: middle_res (7×7)
        → Decoder: up_conv2 (7→14)
                 → up_res2 → up_conv1 (14→28)
                 → up_res1
        → Final: final_norm → SiLU → final_conv (1 channel)

        Pass time_embedding to each ResidualBlock.
        """
        time_emb = self.time_embedding(timestep)

        # TODO 5: Replace pass with implementation
        pass


# Test U-Net
try:
    model = SimpleUNet().to(device)
    x = torch.randn(4, 1, 28, 28).to(device)
    t = torch.tensor([0, 250, 500, 999]).to(device)
    out = model(x, t)
    print("✓ U-Net forward pass working")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {out.shape}")
except Exception as e:
    print(f"✗ Error: {e}")
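The 28 → 14 → 7 → 14 → 28 path in the `SimpleUNet` docstring follows from the standard output-size formulas (dilation 1): for `Conv2d`, out = ⌊(in + 2p − k)/s⌋ + 1, and for `ConvTranspose2d` without output padding, out = (in − 1)·s − 2p + k. With k = 4, s = 2, p = 1 these halve and double the resolution exactly. The plain-Python check below evaluates those formulas (the helper names are illustrative). Note that, as written, the docstring architecture has no encoder-to-decoder skip connections; a classic U-Net adds or concatenates encoder features at matching resolutions.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

after_init = conv2d_out(28, kernel=3, padding=1)                    # init_conv: 28 -> 28
after_down1 = conv2d_out(after_init, kernel=4, stride=2, padding=1)  # down_conv1: 28 -> 14
bottleneck = conv2d_out(after_down1, kernel=4, stride=2, padding=1)  # down_conv2: 14 -> 7
after_up2 = conv_transpose2d_out(bottleneck, kernel=4, stride=2, padding=1)  # up_conv2: 7 -> 14
after_up1 = conv_transpose2d_out(after_up2, kernel=4, stride=2, padding=1)   # up_conv1: 14 -> 28
print(after_init, after_down1, bottleneck, after_up2, after_up1)  # 28 14 7 14 28
```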


## Section 6: TODO 6 - Training Step

**TODO 6:** Implement single training iteration with MSE loss.

In [None]:
def train_step(model, optimizer, x_0, scheduler, device):
    """
    TODO 6: Single training iteration.

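    Steps:
    1. Move data to device (done below)
    2. Sample one timestep per image, uniformly from [0, num_timesteps), on `device`
    3. Sample random noise with the same shape as x_0
    4. Apply forward diffusion (add_noise) to get x_t
    5. Predict the noise with model(x_t, t)
    6. Compute the MSE between predicted and true noise
    7. Zero gradients, backpropagate, and take an optimizer step

    Returns:
        loss_value: MSE loss for this batch as a Python float (e.g. loss.item())
    """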
    model.train()
    x_0 = x_0.to(device)
    batch_size = x_0.shape[0]

    # TODO 6: Replace pass with implementation
    pass


# Test train_step
try:
    model = SimpleUNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    batch = torch.randn(4, 1, 28, 28)
    loss = train_step(model, optimizer, batch, scheduler, device)
    print("✓ Training step working")
    print(f"  Loss: {loss:.6f}")
except Exception as e:
    print(f"✗ Error: {e}")
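A useful sanity check when reading the printed loss: the target ε is standard normal, so a network that always predicts zero noise scores an MSE of E[ε²] = 1. Losses on the order of 1 early in training are therefore unsurprising, and values well below 1 mean the model is genuinely predicting the noise. The NumPy sketch below samples timesteps uniformly as in step 2 of TODO 6 and evaluates that zero-prediction baseline. It illustrates the objective only; it is not the training step itself.

```python
import numpy as np

rng = np.random.default_rng(0)
num_timesteps, batch_size = 1000, 64

# Step 2 of TODO 6: one uniformly sampled timestep per image in the batch
timesteps = rng.integers(0, num_timesteps, size=batch_size)

# Step 3: the target noise epsilon ~ N(0, I), same shape as the images
noise = rng.standard_normal((batch_size, 1, 28, 28))

def mse(pred, target):
    """Mean squared error over every element, as F.mse_loss does by default."""
    return float(np.mean((pred - target) ** 2))

# Baseline "model" that always predicts zero noise: MSE = mean(eps^2), close to 1
zero_pred_loss = mse(np.zeros_like(noise), noise)
# A perfect noise predictor would reach 0
perfect_loss = mse(noise, noise)
print(f"zero-prediction MSE: {zero_pred_loss:.3f}, perfect MSE: {perfect_loss:.1f}")
```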


## Section 7: TODO 7 - Training Loop

**TODO 7:** Implement training for one epoch (loop over all batches).

In [None]:
def train_epoch(model, optimizer, train_loader, scheduler, device):
    """
    TODO 7: Train for one epoch.

    Loop through all batches:
    1. Call train_step for each batch
    2. Accumulate losses
    3. Return average loss

    Returns:
        avg_loss: Average MSE loss for the epoch
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # TODO 7: Replace pass with implementation
    pass


print("train_epoch function defined (TODO 7)")
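One detail behind "return average loss": with 60,000 MNIST training images and `batch_size=64`, the loader yields 938 batches, and the last one holds only 32 images (60,000 − 937 × 64). Dividing `total_loss` by `num_batches` gives that short batch the same weight as a full one, while weighting each batch loss by its size gives the exact per-sample mean. The difference is small here. The pure-Python sketch below (hypothetical helper `epoch_average`, toy loss values for illustration only) computes both.

```python
import math

def epoch_average(batch_losses, batch_sizes):
    """Return (mean over batches, mean over samples) for one epoch."""
    per_batch = sum(batch_losses) / len(batch_losses)
    weighted = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    per_sample = weighted / sum(batch_sizes)
    return per_batch, per_sample

n_train, batch_size = 60_000, 64
num_batches = math.ceil(n_train / batch_size)          # 938 batches
last_batch = n_train - (num_batches - 1) * batch_size  # 32 images in the last one
print(num_batches, last_batch)

# Toy example: two full batches and one half-size batch
per_batch, per_sample = epoch_average([0.10, 0.08, 0.20], [64, 64, 32])
print(f"{per_batch:.4f} vs {per_sample:.4f}")  # 0.1267 vs 0.1120
```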


## Section 8: TODO 8 - Complete Training Script

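**TODO 8:** Implement the epoch loop in `main()`: call `train_epoch` once per epoch and append each average loss to `epoch_losses`. Device selection, model and optimizer setup, and the plotting and analysis code are already provided.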

In [None]:
def main():
    """
    TODO 8: Complete training script.

    Steps:
    1. Device selection (already done, use global 'device')
    2. Create model
    3. Setup optimizer (learning rate = 0.001)
    4. Training loop for 10 epochs
    5. Plot and analyze convergence

    Key point: Observe SMOOTH convergence (unlike cGAN)
    """

    # Model and optimizer
    model = SimpleUNet(image_channels=1, base_channels=64).to(device)
    scheduler_obj = NoiseScheduler(num_timesteps=1000)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10
    epoch_losses = []

    print("\nStarting DDPM Training (10 epochs)...")
    print("-" * 50)

    # TODO 8: Replace pass with implementation
    pass

    # Plot results
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(
        range(1, num_epochs + 1), epoch_losses, marker="o", linewidth=2, markersize=8
    )
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("MSE Loss", fontsize=12)
    plt.title("DDPM: Smooth Convergence\n(Compare with cGAN volatility)", fontsize=12)
    plt.grid(True, alpha=0.3)

    # Analysis
    plt.subplot(1, 2, 2)
    diffs = [
        abs(epoch_losses[i + 1] - epoch_losses[i]) for i in range(len(epoch_losses) - 1)
    ]
    plt.bar(range(1, len(diffs) + 1), diffs, alpha=0.7)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Absolute Loss Change", fontsize=12)
    plt.title("Loss Stability\n(Small changes = Smooth convergence)", fontsize=12)
    plt.grid(True, axis="y", alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print analysis
    print("\n" + "=" * 50)
    print("Convergence Analysis")
    print("=" * 50)
    print(f"Initial Loss: {epoch_losses[0]:.6f}")
    print(f"Final Loss:   {epoch_losses[-1]:.6f}")
    improvement = (epoch_losses[0] - epoch_losses[-1]) / epoch_losses[0] * 100
    print(f"Improvement:  {improvement:.1f}%")
    print(f"Max jump:     {max(diffs):.6f}")
    print(f"Avg jump:     {sum(diffs)/len(diffs):.6f}")

    if max(diffs) < 0.1 * epoch_losses[0]:
        print("\n✓ Convergence is SMOOTH!")
    else:
        print("\n✗ High volatility detected")


# Run training
# Note: Uncomment the line below to run training
# main()


## Summary

**Key Learning Points:**

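1. **Stable Convergence:** The MSE noise-prediction loss is a plain regression objective with no adversarial game, so it typically decreases smoothly (though not strictly monotonically, since timesteps and noise are resampled for every batch)
2. **Noise Schedules:** The β schedule is fixed rather than learned, so the only trainable component is the noise-prediction network
3. **Time Conditioning:** The timestep embedding lets a single network denoise at every noise level
4. **U-Net Architecture:** The encoder-decoder processes the image at several resolutions (28×28, 14×14, 7×7) and restores a full 28×28 noise prediction
5. **Comparison with cGAN:** Diffusion training avoids the generator-discriminator balancing that makes cGAN training volatile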

**Next Steps:**

- Train for 50+ epochs for better quality
- Implement conditional generation (add class labels)
- Try different noise schedules (cosine, sigmoid)
- Compare with Module 13's cGAN volatility
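As a starting point for the noise-schedule experiment, one widely used alternative is the cosine schedule of Nichol and Dhariwal (2021). It defines ᾱ directly as f(t)/f(0) with f(t) = cos²(((t/T + s)/(1 + s))·π/2), derives β_t = 1 − ᾱ_t/ᾱ_{t−1}, and clips β at 0.999. The NumPy sketch below uses the offset s = 0.008 from that paper (the function name is illustrative) and produces `betas` you could substitute for the linear `torch.linspace` schedule. Compared with the linear schedule, ᾱ_t decays more slowly through the early and middle timesteps.

```python
import math
import numpy as np

def cosine_schedule(num_timesteps=1000, s=0.008, max_beta=0.999):
    """Cosine schedule: alpha_bar(t) = f(t) / f(0), with betas derived and clipped."""
    steps = np.arange(num_timesteps + 1, dtype=np.float64)
    f = np.cos(((steps / num_timesteps + s) / (1 + s)) * math.pi / 2) ** 2
    alphas_cumprod = f / f[0]                                # alpha_bar_0 .. alpha_bar_T
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]   # beta_t from consecutive ratios
    return np.clip(betas, 0.0, max_beta)

cos_betas = cosine_schedule()
cos_ab = np.cumprod(1.0 - cos_betas)
lin_ab = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
for t in (250, 500, 750):
    print(t, f"cosine={cos_ab[t]:.3f}", f"linear={lin_ab[t]:.3f}")
```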

