# Exercise: Comparing Sampling Schedules


In this notebook, you'll implement and compare two different sampling schedules for reverse diffusion:
- **Linear Schedule**: Uniform timestep spacing
- **Cosine Schedule**: Non-linear spacing focusing on high-noise steps

By the end, you'll understand how scheduling affects generation quality without changing computational cost.

**Learning Objectives:**
+ Implement timestep scheduling strategies
+ Understand reverse diffusion step-by-step  
+ Compare quality-speed tradeoffs
+ Analyze sampling schedule impact



## Section 1: Import Required Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import torchvision
from torchvision import transforms, datasets
import math
import json
from pathlib import Path

# Pick the compute device once; the model and all sampling tensors live here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed both RNGs this notebook draws from so re-runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Fashion-MNIST test split, rescaled from [0, 1] to [-1, 1]; the plotting
# helpers undo this with (x + 1) / 2.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
fashion_mnist = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

first_image = fashion_mnist[0][0]
print(f"Fashion MNIST loaded: {len(fashion_mnist)} images")
print(f"Image shape: {first_image.shape}")


In [None]:
# Visualization helpers
def plot_sample_comparison(linear_samples, cosine_samples, num_cols=4):
    """Display side-by-side comparison of linear vs cosine sampled images.

    Args:
        linear_samples: Indexable batch of image tensors from the linear
            schedule, assumed normalized to [-1, 1]; each item must squeeze
            to a 2-D image.
        cosine_samples: Same, for the cosine schedule.
        num_cols: Maximum number of samples shown per row. Clipped to the
            number of samples available in both batches.

    Raises:
        ValueError: If either batch is empty or num_cols < 1.
    """
    num_cols = min(num_cols, len(linear_samples), len(cosine_samples))
    if num_cols < 1:
        raise ValueError(
            "plot_sample_comparison needs at least one sample per schedule"
        )

    # squeeze=False keeps `axes` 2-D even when num_cols == 1.
    fig, axes = plt.subplots(2, num_cols, figsize=(16, 8), squeeze=False)

    # Top row: linear schedule, bottom row: cosine schedule.
    rows = (("Linear", linear_samples), ("Cosine", cosine_samples))
    for row, (label, samples) in enumerate(rows):
        for col in range(num_cols):
            # Map [-1, 1] back to [0, 1]; detach so tensors that still
            # require grad can be converted to numpy by imshow.
            img = torch.clamp((samples[col] + 1) / 2, 0, 1)
            ax = axes[row, col]
            ax.imshow(img.detach().cpu().squeeze(), cmap="gray")
            ax.set_title(f"{label} #{col + 1}", fontsize=11, fontweight="bold")
            ax.axis("off")

    fig.suptitle(
        "Sampling Schedule Comparison\nTop: Linear Schedule | Bottom: Cosine Schedule",
        fontsize=14,
        fontweight="bold",
        y=0.98,
    )
    fig.tight_layout()
    plt.show()


def plot_denoising_trajectory(trajectory, schedule_name="Schedule", num_steps=8):
    """Plot denoising trajectory showing progression from noise to image.

    Args:
        trajectory: Sequence of ``(step_idx, image)`` pairs in sampling order.
            Each image is assumed normalized to [-1, 1] and must squeeze to a
            2-D image. ``step_idx`` may be a Python int, numpy scalar or 0-d
            tensor.
        schedule_name: Label used in the figure title.
        num_steps: Maximum number of evenly spaced snapshots to show. Clipped
            to ``len(trajectory)`` so no snapshot is shown twice.

    Raises:
        ValueError: If trajectory is empty or num_steps < 1.
    """
    num_steps = min(num_steps, len(trajectory))
    if num_steps < 1:
        raise ValueError("plot_denoising_trajectory needs a non-empty trajectory")

    # Select evenly spaced snapshots, always including the first and last.
    indices = np.linspace(0, len(trajectory) - 1, num_steps, dtype=int)
    selected_steps = [trajectory[i] for i in indices]

    # squeeze=False keeps `axes` 2-D even when only one snapshot is shown.
    fig, axes = plt.subplots(1, num_steps, figsize=(16, 3), squeeze=False)

    for idx, (step_idx, img) in enumerate(selected_steps):
        img_display = torch.clamp((img + 1) / 2, 0, 1)  # Denormalize
        # Show "Step 999" rather than "Step tensor(999)" for tensor/numpy indices.
        step_label = step_idx.item() if hasattr(step_idx, "item") else step_idx
        ax = axes[0, idx]
        ax.imshow(img_display.detach().cpu().squeeze(), cmap="gray")
        ax.set_title(f"Step {step_label}", fontsize=10, fontweight="bold")
        ax.axis("off")

    fig.suptitle(
        f"Denoising Trajectory: {schedule_name}\nProgression from noise to sample",
        fontsize=12,
        fontweight="bold",
    )
    fig.tight_layout()
    plt.show()


def plot_schedule_metrics(linear_metrics, cosine_metrics):
    """Draw a grouped bar chart of quality metrics for both schedules.

    Args:
        linear_metrics: Mapping from metric name ("Variance", "Sharpness",
            "Mean") to value for the linear schedule. Missing names plot as 0.
        cosine_metrics: Same mapping for the cosine schedule.
    """
    metric_names = ["Variance", "Sharpness", "Mean"]
    positions = np.arange(len(metric_names))
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(10, 6))

    # Linear bars sit left of each tick, cosine bars to the right.
    series = (
        ("Linear Schedule", linear_metrics, -bar_width / 2),
        ("Cosine Schedule", cosine_metrics, bar_width / 2),
    )
    for label, values_by_name, offset in series:
        heights = [values_by_name.get(name, 0) for name in metric_names]
        ax.bar(positions + offset, heights, bar_width, label=label, alpha=0.8)

    ax.set_ylabel("Metric Value", fontsize=12)
    ax.set_title(
        "Sampling Schedule Comparison: Quality Metrics", fontsize=13, fontweight="bold"
    )
    ax.set_xticks(positions)
    ax.set_xticklabels(metric_names)
    ax.legend(fontsize=11)
    ax.grid(True, axis="y", alpha=0.3)

    plt.tight_layout()
    plt.show()


print(" Visualization helpers loaded")


## Section 2: Load Pre-trained U-Net Model

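Define a placeholder U-Net and try to load the trained checkpoint from Module 16. If the checkpoint file is missing, the notebook falls back to untrained weights. In that case the generated samples will be mostly noise, and the schedule comparison is only a demonstration of the pipeline.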

In [None]:
# TODO 1: Load pre-trained model
# For now, we'll create a simple placeholder model structure
# Replace with actual checkpoint loading from Module 16


# Define a simple UNet for demonstration
class SimpleUNet(nn.Module):
    """Minimal three-level conv encoder/decoder with additive skip connections.

    Placeholder denoiser for single-channel 28x28 images. The timestep ``t`` is
    accepted for interface compatibility but is NOT used, so this network
    cannot distinguish noise levels; a real diffusion U-Net embeds ``t`` and
    injects it into every block.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become state_dict keys, and creation order fixes
        # the RNG draw order for weight init; keep both stable.

        # Encoder (for a 28x28 input): 28x28 -> 14x14 -> 7x7
        self.enc1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

        # Bottleneck keeps the 7x7 resolution
        self.mid = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Decoder: 7x7 -> 14x14 -> 28x28, then project back to 1 channel
        self.dec3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict noise for ``x``; ``t`` is ignored by this placeholder."""
        h_full = F.relu(self.enc1(x))
        h_half = F.relu(self.enc2(h_full))
        h_quarter = F.relu(self.enc3(h_half))

        bottleneck = F.relu(self.mid(h_quarter))

        # Additive skips: bottleneck + enc3 features, then up-sampled + enc2.
        up_half = F.relu(self.dec3(bottleneck + h_quarter))
        up_full = F.relu(self.dec2(up_half + h_half))
        return self.dec1(up_full)


# Load or create model: start from random SimpleUNet weights, then try to
# replace them with the trained weights saved by Module 16.
model = SimpleUNet().to(device)
model.eval()

# Try to load checkpoint if it exists
checkpoint_path = "../../lesson-16-Implementing-Simple-Diffusion-Model/checkpoint.pt"
try:
    # NOTE(review): torch.load unpickles the file, which can execute code;
    # only point this at checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
except FileNotFoundError:
    print(f" Could not find checkpoint at {checkpoint_path}")
    print("  Using untrained model (for demonstration only)")
else:
    # Accept both {"model_state_dict": ...} and a bare state_dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Snapshot the random init: a failed strict load may already have copied
    # some matching tensors, so restore it to keep the fallback honest.
    initial_state = {name: tensor.clone() for name, tensor in model.state_dict().items()}
    try:
        model.load_state_dict(state_dict)
    except (RuntimeError, TypeError) as err:
        # Architecture mismatch (e.g. a U-Net with time embeddings) or a
        # checkpoint that is not a state_dict at all.
        model.load_state_dict(initial_state)
        reason = str(err).splitlines()[0] if str(err) else type(err).__name__
        print(f" Checkpoint at {checkpoint_path} is incompatible with SimpleUNet: {reason}")
        print("  Using untrained model (for demonstration only)")
    else:
        print(f" Loaded checkpoint from {checkpoint_path}")

print(f"Model loaded: {type(model).__name__}")
print(f"Device: {device}")


## Section 3: Implement DDPM Reverse Sampling Function

In [None]:
# DDPM linear beta schedule, precomputed once so the sampler can simply
# index these tensors by timestep.
num_timesteps = 1000
betas = torch.linspace(0.0001, 0.02, num_timesteps, device=device)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Shift right by one: alphas_cumprod_prev[t] == alphas_cumprod[t - 1], with
# alphas_cumprod_prev[0] == 1 (nothing has been noised before step 0).
alphas_cumprod_prev = torch.cat((alphas_cumprod.new_ones(1), alphas_cumprod[:-1]))

summary_lines = [
    "Noise schedule configured:",
    f"  Total timesteps: {num_timesteps}",
    f"  Alpha cumprod shape: {alphas_cumprod.shape}",
    f"  Alphas cumprod range: [{alphas_cumprod[-1]:.6f}, {alphas_cumprod[0]:.6f}]",
]
print("\n".join(summary_lines))


# TODO 2: Implement reverse_diffusion_step
def reverse_diffusion_step(
    x_t,
    t,
    predicted_noise,
    alphas_cumprod,
    alphas_cumprod_prev,
    betas,
    device,
    t_prev=None,
):
    """
    Perform one reverse diffusion step: x_t → x_{t_prev}

    Args:
        x_t: Current noisy image (batch, 1, 28, 28)
        t: Current timestep index (int)
        predicted_noise: U-Net prediction (batch, 1, 28, 28)
        alphas_cumprod: Pre-computed cumulative alphas
        alphas_cumprod_prev: Cumulative alphas at t-1
        betas: Beta schedule
        device: Compute device
        t_prev: Timestep to step to (int or None). None means t - 1, the
            standard one-step DDPM update. When a schedule skips timesteps
            (e.g. 50 of 1000), pass the schedule's next timestep instead;
            stepping only to t - 1 would leave the sample far noisier than
            the next iteration assumes. A negative value means "step to the
            clean image".

    Returns:
        x_{t_prev}: Denoised image, same shape as x_t

    Raises:
        NotImplementedError: Until TODO 2 is completed.

    TODO: Implement reverse diffusion formula
    1. Get ᾱ_t = alphas_cumprod[t] and ᾱ_prev: alphas_cumprod_prev[t] when
       t_prev is None, alphas_cumprod[t_prev] otherwise, or 1.0 if t_prev < 0.
    2. Predict the original image from x_t and the predicted noise ε̂:
       x0_hat = (x_t - sqrt(1 - ᾱ_t) * ε̂) / sqrt(ᾱ_t)
    3. Compute the posterior mean of q(x_prev | x_t, x0_hat). With
       β'_t = 1 - ᾱ_t / ᾱ_prev (this equals betas[t] only when t_prev == t - 1):
       mean = sqrt(ᾱ_prev) * β'_t / (1 - ᾱ_t) * x0_hat
            + sqrt(1 - β'_t) * (1 - ᾱ_prev) / (1 - ᾱ_t) * x_t
    4. Add Gaussian noise with variance (1 - ᾱ_prev) / (1 - ᾱ_t) * β'_t
       (stochastic sampling). When ᾱ_prev == 1 (the last step) this variance
       is 0, so return the mean directly.
    """

    # TODO: Your implementation here
    # Get alpha_t and alpha_t_prev
    # Compute posterior mean and variance
    # Return denoised sample

    # Fail loudly: a silent `pass` returns None and breaks callers far away.
    raise NotImplementedError(
        "TODO 2: implement reverse_diffusion_step (see its docstring)"
    )


# Nothing is exercised here yet; the sampling loop calls this once implemented.
print("reverse_diffusion_step defined")


## Section 4: Implement Linear and Cosine Schedules

In [None]:
# TODO 3: Implement LinearSchedule
class LinearSchedule:
    """Linear (uniform) timestep schedule.

    After construction, ``self.timesteps`` must hold ``num_inference_steps``
    integer timesteps in descending order, evenly spaced from
    ``num_train_timesteps - 1`` down to 0.
    """

    def __init__(self, num_train_timesteps=1000, num_inference_steps=50):
        """
        Initialize linear timestep schedule.

        Args:
            num_train_timesteps: Number of timesteps the model was trained with.
            num_inference_steps: Number of timesteps to visit while sampling.

        Raises:
            NotImplementedError: Until TODO 3 is completed.

        TODO: Create linearly spaced timesteps from num_train_timesteps-1 to 0
        and store them in self.timesteps (e.g. torch.linspace, then round to
        integers).
        """
        # TODO: Your implementation here
        # Fail loudly instead of leaving self.timesteps unset (which would
        # only surface later as an AttributeError from __len__).
        raise NotImplementedError(
            "TODO 3: set self.timesteps in LinearSchedule.__init__"
        )

    def __len__(self):
        """Number of sampling steps in the schedule."""
        return len(self.timesteps)


# TODO 4: Implement CosineSchedule
class CosineSchedule:
    """Cosine (non-linear) timestep schedule focusing on high-noise steps.

    After construction, ``self.timesteps`` must hold ``num_inference_steps``
    integer timesteps in descending order from ``num_train_timesteps - 1``
    to 0, spaced more densely near the high-noise end.
    """

    def __init__(self, num_train_timesteps=1000, num_inference_steps=50):
        """
        Initialize cosine timestep schedule.

        Args:
            num_train_timesteps: Number of timesteps the model was trained with (N).
            num_inference_steps: Number of timesteps to visit while sampling.

        Raises:
            NotImplementedError: Until TODO 4 is completed.

        TODO: Create non-linear timesteps using a cosine curve.
        Reference (Nichol & Dhariwal, 2021) cosine noise level:
            ᾱ(t) = f(t) / f(0),  f(t) = cos²( (t/N + s) / (1 + s) · π/2 ),  s = 0.008
        One way to get high-noise-focused spacing: with u evenly spaced in
        [0, 1], use t = (N - 1) · cos(π/2 · u). cos is flat near u = 0, so the
        timesteps cluster near N - 1. Round to integers, and drop duplicates
        if num_inference_steps is large.
        """
        # TODO: Your implementation here
        # Fail loudly instead of leaving self.timesteps unset.
        raise NotImplementedError(
            "TODO 4: set self.timesteps in CosineSchedule.__init__"
        )

    def __len__(self):
        """Number of sampling steps in the schedule."""
        return len(self.timesteps)


# Both classes are only defined here; they are instantiated in sample_with_schedule.
print("Schedule classes defined (TODOs to complete)")


## Section 5: Implement Complete Sampling Loop

In [None]:
# TODO 5: Implement sample_with_schedule
def sample_with_schedule(
    model, schedule_type, num_steps, batch_size, device, return_trajectory=False
):
    """
    Complete sampling pipeline with chosen schedule.

    Args:
        model: U-Net model
        schedule_type: "linear" or "cosine"
        num_steps: Number of sampling steps
        batch_size: Number of images to generate
        device: Compute device
        return_trajectory: Whether to save intermediate steps

    Returns:
        samples: Generated images (batch, 1, 28, 28), in the same [-1, 1]
            range as the normalized training data.
        When return_trajectory is True, return the tuple
        ``(samples, trajectory)`` instead, where trajectory is a list of
        ``(timestep, image)`` pairs in sampling order and ``image`` is a
        single (1, 28, 28) tensor (e.g. the first sample of the batch). This
        is the format ``plot_denoising_trajectory`` unpacks.

    Raises:
        NotImplementedError: Until TODO 5 is completed.

    TODO: Implement complete sampling loop
    1. Create schedule object (LinearSchedule or CosineSchedule); raise
       ValueError for any other schedule_type
    2. Initialize with pure Gaussian noise, e.g.
       torch.randn(batch_size, 1, 28, 28, device=device)
    3. For each timestep, with gradients disabled (torch.no_grad()):
       a. Use model to predict noise (pass t e.g. as a (batch_size,) tensor)
       b. Call reverse_diffusion_step(). If the schedule skips timesteps,
          make sure the step lands on the schedule's next timestep, not t - 1
       c. Save intermediate if needed
    4. Return final samples
    """

    # TODO: Your implementation here
    # Fail loudly: returning None would only surface later as
    # "cannot unpack non-iterable NoneType" in the comparison cell.
    raise NotImplementedError(
        "TODO 5: implement sample_with_schedule (see its docstring)"
    )


print("sample_with_schedule defined (TODO to complete)")


## Section 6: Implement Quality Metrics

In [None]:
# TODO 6: Implement compute_sample_variance
def compute_sample_variance(samples):
    """
    Measure diversity of generated samples.

    Args:
        samples: (batch, 1, 28, 28) tensor

    Returns:
        variance: Scalar indicating sample diversity

    Raises:
        NotImplementedError: Until TODO 6 is completed.

    TODO: Flatten each sample to a vector, compute the per-pixel variance
    across the batch dimension, then average those variances to a scalar.
    High variance = diverse samples (good)
    Low variance = similar samples (mode collapse)
    """
    # TODO: Your implementation here
    raise NotImplementedError("TODO 6: implement compute_sample_variance")


# TODO 7: Implement compute_sharpness
def compute_sharpness(samples):
    """
    Estimate image clarity using Sobel edge detection.

    Args:
        samples: (batch, 1, 28, 28) tensor

    Returns:
        sharpness: Scalar indicating edge strength

    Raises:
        NotImplementedError: Until TODO 7 is completed.

    TODO: Apply Sobel filters and compute average edge magnitude.
    The 3x3 Sobel kernels are Gx = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]] and
    Gy = Gx transposed. Convolve with F.conv2d (padding=1, kernels shaped
    (1, 1, 3, 3) on the samples' device), then average sqrt(gx² + gy²).
    """
    # TODO: Your implementation here
    raise NotImplementedError("TODO 7: implement compute_sharpness")


print("Metric functions defined (TODOs to complete)")


## Section 7: Visualize Denoising Trajectories

In [None]:
# TODO 8: Generate samples with both schedules, score them, and compare them
# with the visualization helpers. Requires TODOs 2-7 above to be implemented.

NUM_SAMPLING_STEPS = 50  # same step budget for both -> same compute cost
NUM_SAMPLES = 8  # must be >= num_cols passed to plot_sample_comparison
COMPARISON_SEED = 42

# Re-seed before each run so both schedules start from the same random state
# (identical initial noise, assuming sample_with_schedule draws it from
# torch's global RNG). Differences then come from the schedule, not the draw.
# no_grad keeps returned tensors detached for plotting and saves memory.
schedule_results = {}
for schedule_type in ("linear", "cosine"):
    torch.manual_seed(COMPARISON_SEED)
    with torch.no_grad():
        schedule_results[schedule_type] = sample_with_schedule(
            model,
            schedule_type,
            NUM_SAMPLING_STEPS,
            NUM_SAMPLES,
            device,
            return_trajectory=True,
        )

linear_samples, linear_trajectory = schedule_results["linear"]
cosine_samples, cosine_trajectory = schedule_results["cosine"]

# float() accepts both Python numbers and 0-d tensors from the metric functions.
linear_variance = float(compute_sample_variance(linear_samples))
linear_sharpness = float(compute_sharpness(linear_samples))
cosine_variance = float(compute_sample_variance(cosine_samples))
cosine_sharpness = float(compute_sharpness(cosine_samples))

# 1. Compare final samples from both schedules
plot_sample_comparison(linear_samples, cosine_samples, num_cols=4)

# 2. Show denoising trajectory for each schedule
plot_denoising_trajectory(
    linear_trajectory, schedule_name="Linear Schedule", num_steps=8
)
plot_denoising_trajectory(
    cosine_trajectory, schedule_name="Cosine Schedule", num_steps=8
)

# 3. Compare quality metrics
linear_metrics = {
    "Variance": linear_variance,
    "Sharpness": linear_sharpness,
    "Mean": linear_samples.mean().item(),
}
cosine_metrics = {
    "Variance": cosine_variance,
    "Sharpness": cosine_sharpness,
    "Mean": cosine_samples.mean().item(),
}
plot_schedule_metrics(linear_metrics, cosine_metrics)


## Summary and Next Steps

### What You've Implemented

In this notebook, you've built the complete reverse diffusion sampling pipeline:

1. **LinearSchedule**: Uniform timestep spacing across all noise levels
2. **CosineSchedule**: Non-linear spacing that concentrates steps on the hard (high-noise) part of the trajectory
3. **reverse_diffusion_step()**: Core algorithm for one denoising step
4. **sample_with_schedule()**: Complete sampling loop with trajectory tracking
5. **compute_sample_variance()**: Measure generation diversity
6. **compute_sharpness()**: Measure image clarity

### Key Insights

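- **Linear schedule**: Same cost as cosine for a given number of steps, but it spreads those steps evenly, so fewer of them land where denoising is hardest
- **Cosine schedule**: Often gives better quality for the same computational cost, because more steps are spent at high noise levels
- **Sampling schedule** is an inference-time design choice that doesn't change model complexity
- **Quality-speed tradeoff**: For a fixed step budget, smarter scheduling can improve quality. Check your metrics above to see whether it does here.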