# Demo: Reverse Diffusion Sampling and Image Generation
## Creating Images from Pure Noise with Fashion MNIST

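In this demo, we'll explore the complete reverse diffusion sampling process. Starting from pure Gaussian noise, we'll progressively denoise to create coherent Fashion MNIST images. (Coherent images require the trained checkpoint from Module 16. If it cannot be loaded, the notebook falls back to an untrained placeholder network, whose outputs will look like noise.)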

**What You'll Learn:**
- How reverse diffusion recovers images from noise
- Comparing sampling schedules (linear vs cosine)
- Visualizing the denoising trajectory
- Understanding quality-speed tradeoffs



## Section 1: Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import torchvision
from torchvision import transforms, datasets
import math
from pathlib import Path
import sys

# Setup device: prefer CUDA when it is available, otherwise run on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Seed NumPy and PyTorch so random draws repeat from run to run
np.random.seed(42)
torch.manual_seed(42)

print(" Imports complete")


## Section 2: Load Pre-trained Model and Noise Schedule

In [None]:
# Initialize noise schedule (from training): a linear beta ramp over
# num_timesteps steps. NOTE(review): confirm these constants match the trainer.
num_timesteps = 1000
betas = torch.linspace(0.0001, 0.02, num_timesteps, device=device)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# alpha_bar shifted by one step; index 0 holds 1.0 so that entry t is alpha_bar_{t-1}.
alphas_cumprod_prev = torch.cat((torch.ones_like(alphas_cumprod[:1]), alphas_cumprod[:-1]))

print(f"✓ Noise Schedule Configured:")
print(f"  Total timesteps: {num_timesteps}")
print(f"  Alpha range: [{alphas_cumprod[-1]:.6f}, {alphas_cumprod[0]:.6f}]")


# Load pre-trained U-Net model from Module 16
# Create a simple placeholder model
class SimpleUNet(nn.Module):
    """Tiny convolutional encoder/decoder used as a stand-in noise predictor.

    For a 1x28x28 input the activations go: 64 channels at 28x28 (enc1),
    128 channels at 14x14 (stride-2 enc2), 128 channels at 14x14 (mid),
    64 channels back at 28x28 (transposed-conv dec), then 1 output channel.

    NOTE(review): ``forward`` accepts a timestep ``t`` but never uses it; a
    real DDPM denoiser would condition on it.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so keep them stable.
        self.enc1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.mid = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.dec = nn.ConvTranspose2d(128, 64, kernel_size=4, padding=1, stride=2)
        self.out = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Return a noise prediction with the same spatial size as ``x`` (``t`` is ignored)."""
        hidden = x
        for layer in (self.enc1, self.enc2, self.mid, self.dec):
            hidden = F.relu(layer(hidden))
        return self.out(hidden)


model = SimpleUNet().to(device)
model.eval()

# Try to load the Module 16 checkpoint; otherwise keep the untrained network.
checkpoint_path = "../../lesson-16-Implementing-Simple-Diffusion-Model/checkpoint.pt"
try:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept a training checkpoint ({"model_state_dict": ...}) or a bare state_dict.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get("model_state_dict", checkpoint)
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
except FileNotFoundError:
    print(" Using untrained model (checkpoint not found)")
except Exception as err:  # notebook boundary: report the cause instead of hiding it
    # load_state_dict copies tensors whose keys/shapes match before it raises
    # on a mismatch, so rebuild the model to get a clean untrained fallback.
    model = SimpleUNet().to(device)
    model.eval()
    print(f" Using untrained model (checkpoint could not be loaded: {err})")
else:
    print(" Loaded pre-trained model")

print(f" Model ready on {device}")


## Section 3: Implement Reverse Diffusion Core Algorithm

In [None]:
def reverse_diffusion_step(x_t, t, pred_noise):
    """Single reverse diffusion step: x_t → x_{t-1}"""
    alpha_t = alphas_cumprod[t]
    alpha_t_prev = alphas_cumprod_prev[t]
    beta_t = betas[t]

    # Predict original image
    x_0_pred = (x_t - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)

    # Posterior mean
    variance = (1 - alpha_t_prev) * beta_t / (1 - alpha_t)
    coef1 = torch.sqrt(alpha_t_prev) * beta_t / (1 - alpha_t)
    coef2 = (1 - beta_t) * torch.sqrt(1 - alpha_t_prev) / (1 - alpha_t)

    mean = coef1 * x_0_pred + coef2 * x_t

    # Add noise
    if t > 0:
        z = torch.randn_like(x_t)
        x_t_minus_1 = mean + torch.sqrt(variance) * z
    else:
        x_t_minus_1 = mean

    return x_t_minus_1


print("✓ Reverse diffusion step implemented")


In [None]:
class LinearSchedule:
    """Linear timestep schedule: uniform sampling across timesteps"""

    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps

    def get_timesteps(self, num_steps):
        """Get linearly spaced timesteps"""
        return torch.linspace(0, self.num_timesteps - 1, num_steps).long()


class CosineSchedule:
    """Cosine timestep schedule: emphasize early/mid timesteps"""

    def __init__(self, num_timesteps, s=0.008):
        self.num_timesteps = num_timesteps
        self.s = s

    def get_timesteps(self, num_steps):
        """Get cosine-weighted timesteps"""
        steps = torch.linspace(0, 1, num_steps + 1)
        alphas = torch.cos((steps + self.s) / (1 + self.s) * np.pi * 0.5) ** 2
        alphas = alphas / alphas[0]
        timesteps = (1 - alphas) * (self.num_timesteps - 1)
        return torch.floor(timesteps[1:]).long()


linear_schedule = LinearSchedule(num_timesteps)
cosine_schedule = CosineSchedule(num_timesteps)

print(" Sampling schedules implemented")


In [None]:
def sample_with_schedule(num_samples, schedule, schedule_name):
    """Generate samples using specified timestep schedule"""
    x_t = torch.randn(num_samples, 1, 28, 28, device=device)
    timesteps = schedule.get_timesteps(50)  # 50 steps

    with torch.no_grad():
        for i, t in enumerate(timesteps):
            t_tensor = torch.full((num_samples,), t, dtype=torch.long, device=device)

            # Predict noise
            pred_noise = model(x_t, t_tensor.view(-1))

            # Reverse step
            x_t = reverse_diffusion_step(x_t, t.item(), pred_noise)

    # Clip to valid range
    x_t = torch.clamp(x_t, -1, 1)
    return x_t


# Generate samples with both schedules. Linear runs before cosine so the
# seeded noise draws happen in the same order every run.
print("Generating samples...")
linear_samples, cosine_samples = (
    sample_with_schedule(16, schedule, label)
    for schedule, label in ((linear_schedule, "linear"), (cosine_schedule, "cosine"))
)
print("✓ Samples generated")


In [None]:
def plot_samples_grid(samples, title, nrows=4, ncols=4, value_range=(-1.0, 1.0)):
    """Show up to ``nrows * ncols`` single-channel images in a grid.

    Args:
        samples: Tensor of images indexed along the first axis. Each image is
            detached, moved to the CPU and squeezed to 2-D for ``imshow``.
        title: Figure title.
        nrows: Number of grid rows.
        ncols: Number of grid columns; cells beyond ``len(samples)`` stay blank.
        value_range: ``(vmin, vmax)`` shared by every panel so brightness is
            comparable across images. The default matches the [-1, 1] clamp
            applied by the sampler. Pass ``None`` to autoscale each image
            independently.
    """
    vmin, vmax = value_range if value_range is not None else (None, None)
    # squeeze=False keeps `axes` a 2-D array even for 1x1 or 1xN grids.
    fig, axes = plt.subplots(nrows, ncols, figsize=(8, 8), squeeze=False)

    for idx, ax in enumerate(axes.flat):
        if idx < len(samples):
            img = samples[idx].detach().cpu().squeeze().numpy()
            ax.imshow(img, cmap="gray", vmin=vmin, vmax=vmax)
        ax.axis("off")

    fig.suptitle(title, fontsize=14, fontweight="bold")
    fig.tight_layout()
    plt.show()


def compute_image_statistics(samples):
    """Summarise a batch of images as plain Python floats.

    Keys (in this order):
      * "mean": average of the per-image pixel means,
      * "std":  average of the per-image (unbiased) pixel standard deviations,
      * "min" / "max": smallest / largest pixel value over the whole batch.
    """
    per_image = samples.reshape(samples.shape[0], -1)
    avg_mean = per_image.mean(dim=1).mean()
    avg_std = per_image.std(dim=1).mean()
    return dict(
        mean=avg_mean.item(),
        std=avg_std.item(),
        min=samples.min().item(),
        max=samples.max().item(),
    )


# Visualize both schedules
plot_samples_grid(linear_samples, "Linear Schedule (50 steps)")
plot_samples_grid(cosine_samples, "Cosine Schedule (50 steps)")

# Summary statistics per schedule (kept as module-level names for later cells)
lin_stats = compute_image_statistics(linear_samples)
cos_stats = compute_image_statistics(cosine_samples)

for header, stats in (
    ("\nLinear Schedule Statistics:", lin_stats),
    ("\nCosine Schedule Statistics:", cos_stats),
):
    print(header)
    print("\n".join(f"  {key}: {val:.4f}" for key, val in stats.items()))


## Section 4: Analysis and Key Insights

### What We Learned

1. **Timestep Schedules Matter**: Different schedules (linear vs cosine) spread the same step budget differently across noise levels, which can change output quality
2. **Reverse Process**: Starting from pure noise, we iteratively denoise, stepping from the highest noise level down to a clean image
3. **Conditioning (not covered in this demo)**: Either schedule can be combined with class labels for conditional generation

### Performance Comparison

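- **Linear Schedule**: Evenly spaced timesteps from 0 to T−1. It is simple and predictable, and gives every noise level the same share of the step budget.
- **Cosine Schedule**: Timesteps follow a cosine-squared curve, so they are packed more densely near both ends (lowest and highest noise) and more sparsely in the middle. This can allocate denoising effort more efficiently.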

### Next Steps

- Experiment with different numbers of sampling steps (10, 25, 50, 100)
- Try varying schedules (exponential, polynomial)
- Implement classifier-free guidance for better sample control
- Combine with other techniques (VAE-augmented sampling, progressive refinement)