# Lab 2: Mastering the Forward Diffusion Process - Mathematical Implementation
**Course: Diffusion Models: Theory and Applications**  
**Duration: 90 minutes**  
**Team Size: 2 students (same teams from Lab 1)**

---

## Learning Objectives
By the end of this lab, students will be able to:
1. **Implement** the reparameterization trick and understand its crucial role in making diffusion trainable
2. **Build** different noise schedules (linear, cosine, exponential) and analyze their effects
3. **Derive and implement** the forward jump formula from first principles
4. **Create** efficient training data generation pipelines using Gaussian arithmetic
5. **Analyze** signal-to-noise ratios and variance evolution throughout the diffusion process
6. **Compare** computational complexity with and without forward jumps

---

## Prerequisites
- Completion of Lab 1 (basic diffusion implementation)
- Understanding of multivariate Gaussian distributions
- Familiarity with the reparameterization trick concept

---

## Lab Setup and Environment

### Part 1: Team Reunion & Mathematical Setup (10 minutes)

#### Task 1.1: Reconnect with Your Lab Partner
- Same teams as Lab 1 - compare your experiences since last lab
- **Today's Mission**: Build mathematically rigorous diffusion components
- **Success Criteria**: Implement forward jumps that reach any timestep t in a single step instead of t sequential ones (1000x fewer noise steps at t = T = 1000)

#### Task 1.2: Mathematical Environment Setup

In [None]:
# Mathematical imports for rigorous implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from typing import Tuple, Optional
import time
from torch.distributions import Normal
from scipy import stats

# Set seeds for reproducible mathematics
torch.manual_seed(42)
np.random.seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load MNIST for testing our mathematical implementations
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Get a test image for our mathematical experiments
test_image = next(iter(train_loader))[0][:1].to(device)
print(f"Test image shape: {test_image.shape}")
print(f"Test image range: [{test_image.min():.3f}, {test_image.max():.3f}]")

---

## Part 2: Implement the Reparameterization Trick (20 minutes)

### Task 2.1: Understanding Why We Need Reparameterization

**Your Mission**: Implement both sampling approaches and see why one enables gradients while the other doesn't.

In [None]:
class ReparameterizationDemo:
    """
    This class will demonstrate why the reparameterization trick is essential
    for training neural networks with stochastic operations.
    """
    
    def __init__(self):
        # We'll use a simple linear transformation to show gradient flow
        self.linear = nn.Linear(1, 1).to(device)
        
    def direct_sampling_approach(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement direct sampling from N(mu, sigma^2)
        
        This approach samples directly from the Gaussian distribution.
        
        Implementation steps:
        1. Create a Normal distribution with given mu and sigma
        2. Sample from it using .sample()
        3. Return the sample
        
        Args:
            mu: Mean of the distribution (can be parameterized by neural network)
            sigma: Standard deviation (should be positive)
            
        Returns:
            sample: A sample from N(mu, sigma^2)
            
        Note: `.sample()` runs under `torch.no_grad()`, so this approach breaks gradient flow!
        """
        # TODO: Implement direct sampling
        # Hint: Use torch.distributions.Normal
        pass
    
    def reparameterized_approach(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement the reparameterization trick
        
        This approach separates the randomness from the parameters.
        
        Implementation steps:
        1. Sample epsilon from standard normal: eps ~ N(0, 1)
        2. Transform: z = mu + sigma * eps
        3. Return z
        
        Args:
            mu: Mean (can have gradients)
            sigma: Standard deviation (can have gradients)
            
        Returns:
            sample: Equivalent sample but with gradient flow preserved
        """
        # TODO: Implement reparameterization
        # Hint: Sample eps ~ N(0,1), then transform to z = mu + sigma * eps
        pass
    
    def test_gradient_flow(self):
        """
        Test which approach allows gradient computation
        """
        print("=== Testing Gradient Flow ===\n")
        
        # Create input with gradients
        x = torch.tensor([[1.0]], requires_grad=True, device=device)
        
        # Get parameterized mean
        mu = self.linear(x)
        sigma = torch.tensor(0.5, device=device)  # Fixed sigma for simplicity
        
        print(f"Input x: {x.item():.3f}")
        print(f"Parameterized mu: {mu.item():.3f}")
        print(f"Fixed sigma: {sigma.item():.3f}\n")
        
        # Test direct sampling approach
        try:
            print("Testing direct sampling approach...")
            sample1 = self.direct_sampling_approach(mu, sigma)
            loss1 = sample1.mean()
            print(f"Sample: {sample1.item():.3f}, Loss: {loss1.item():.3f}")
            
            # Try to compute gradients
            self.linear.zero_grad()
            loss1.backward()
            
            # Check if gradients exist
            grad1 = self.linear.weight.grad
            print(f"Gradient for linear.weight: {grad1}")
            
        except Exception as e:
            print(f"Error with direct sampling: {e}")
        
        print()
        
        # Test reparameterized approach
        try:
            print("Testing reparameterized approach...")
            sample2 = self.reparameterized_approach(mu, sigma)
            loss2 = sample2.mean()
            print(f"Sample: {sample2.item():.3f}, Loss: {loss2.item():.3f}")
            
            # Try to compute gradients
            self.linear.zero_grad()
            loss2.backward()
            
            # Check if gradients exist
            grad2 = self.linear.weight.grad
            print(f"Gradient for linear.weight: {grad2}")
            
        except Exception as e:
            print(f"Error with reparameterization: {e}")

# Test statistical equivalence (provided function)
def test_reparameterization_equivalence(reparam_demo, x0, num_trials=100):
    """Test that both approaches produce statistically equivalent results"""
    print("\n=== Testing Statistical Equivalence ===\n")
    
    mu = torch.zeros_like(x0)
    sigma = torch.ones_like(x0) * 0.5
    
    # Generate samples with both methods
    direct_samples = []
    reparam_samples = []
    
    for _ in range(num_trials):
        try:
            direct_sample = reparam_demo.direct_sampling_approach(mu, sigma)
            if direct_sample is not None:  # None means the TODO is not implemented yet
                direct_samples.append(direct_sample)
        except Exception:
            pass  # Skip if the implementation raises
            
        try:
            reparam_sample = reparam_demo.reparameterized_approach(mu, sigma)
            if reparam_sample is not None:
                reparam_samples.append(reparam_sample)
        except Exception:
            pass  # Skip if the implementation raises
    
    if direct_samples and reparam_samples:
        direct_samples = torch.stack(direct_samples)  # shape: (num_trials, *x0.shape)
        reparam_samples = torch.stack(reparam_samples)
        
        # Compare statistics
        direct_mean = direct_samples.mean(dim=0)
        reparam_mean = reparam_samples.mean(dim=0)
        direct_std = direct_samples.std(dim=0)
        reparam_std = reparam_samples.std(dim=0)
        
        print(f"Direct approach - Mean: {direct_mean.mean():.4f}, Std: {direct_std.mean():.4f}")
        print(f"Reparam approach - Mean: {reparam_mean.mean():.4f}, Std: {reparam_std.mean():.4f}")
        print(f"Mean difference: {(direct_mean - reparam_mean).abs().mean():.6f}")
        print(f"Std difference: {(direct_std - reparam_std).abs().mean():.6f}")
        
        # Visualize distributions
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        ax1.hist(direct_samples.flatten().cpu(), bins=50, alpha=0.7, label='Direct', density=True)
        ax1.hist(reparam_samples.flatten().cpu(), bins=50, alpha=0.7, label='Reparameterized', density=True)
        ax1.set_title('Sample Distributions')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Q-Q plot for distribution comparison
        sample_data = reparam_samples.flatten().cpu().numpy()
        stats.probplot(sample_data, dist="norm", plot=ax2)
        ax2.set_title('Q-Q Plot: Reparameterized Samples vs Normal')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Test your implementation
demo = ReparameterizationDemo()
demo.test_gradient_flow()

# Test statistical equivalence (uncomment after implementing both approaches)
# test_reparameterization_equivalence(demo, test_image)
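
For reference, `torch.distributions.Normal` already exposes both behaviours: `.sample()` draws under `torch.no_grad()`, while `.rsample()` uses the reparameterization z = mu + sigma * eps internally. The sketch below, assuming scalar `mu` and `sigma` (illustrative names, not part of the lab code), shows the gradients the reparameterized form produces: dz/dmu = 1 and dz/dsigma = eps.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Scalar parameters that we want gradients for
mu = torch.tensor(1.0, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)

# .sample() draws under torch.no_grad(), so the result is cut off from mu and sigma
direct = Normal(mu, sigma).sample()
print("direct.requires_grad:", direct.requires_grad)

# Manual reparameterization: all randomness lives in eps, which has no parameters
eps = torch.randn(())
z = mu + sigma * eps
z.backward()

# dz/dmu = 1 and dz/dsigma = eps, so both gradients are well defined
print("dz/dmu:", mu.grad.item())
print("dz/dsigma:", sigma.grad.item(), "| eps:", eps.item())

# .rsample() applies the same transformation internally and keeps the graph
z_r = Normal(mu, sigma).rsample()
print("rsample.requires_grad:", z_r.requires_grad)
```

In other words, your `reparameterized_approach` should behave like `.rsample()`, and your `direct_sampling_approach` like `.sample()`.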

### Task 2.2: Implement Diffusion-Specific Reparameterization

In [None]:
class DiffusionReparameterization:
    """
    Implement reparameterization specifically for diffusion steps.
    This class implements the core mathematical transformation that makes
    diffusion models trainable.
    """
    
    def forward_step_sampling(self, x_prev: torch.Tensor, beta_t: float) -> torch.Tensor:
        """
        TODO: Implement the WRONG way (direct sampling)
        
        Direct sampling from: q(x_t | x_{t-1}) = N(sqrt(1-beta_t) * x_{t-1}, beta_t * I)
        
        This approach will not allow gradients to flow through the sampling operation.
        
        Steps:
        1. Compute mean: sqrt(1-beta_t) * x_prev
        2. Compute std: sqrt(beta_t)
        3. Create Normal distribution and sample
        
        Args:
            x_prev: Previous state x_{t-1}
            beta_t: Noise schedule value at timestep t
            
        Returns:
            x_t: Next state (without gradient flow)
        """
        # TODO: Implement direct sampling
        # Hint: Use math.sqrt and torch.distributions.Normal
        pass
    
    def forward_step_reparameterized(self, x_prev: torch.Tensor, beta_t: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        TODO: Implement the RIGHT way (reparameterization trick)
        
        Reparameterized form: x_t = sqrt(1-beta_t) * x_{t-1} + sqrt(beta_t) * epsilon
        where epsilon ~ N(0, I)
        
        This preserves gradient flow and, just as importantly, exposes the noise epsilon,
        which becomes the regression target when training the denoiser.
        
        Steps:
        1. Sample epsilon from standard normal
        2. Compute signal coefficient: sqrt(1-beta_t)
        3. Compute noise coefficient: sqrt(beta_t)
        4. Apply transformation: signal * x_prev + noise * epsilon
        5. Return both x_t and the epsilon used (important for training!)
        
        Args:
            x_prev: Previous state x_{t-1}
            beta_t: Noise schedule value at timestep t
            
        Returns:
            x_t: Next state (with gradient flow)
            epsilon: The noise that was added (needed for training targets)
        """
        # TODO: Implement reparameterized version
        # Hint: epsilon = torch.randn_like(x_prev), then apply the transformation
        pass
    
    def compare_approaches(self, x_input: torch.Tensor, beta_t: float = 0.01):
        """
        Compare both approaches and verify they produce same statistics
        """
        print(f"=== Comparing Sampling Approaches (beta_t = {beta_t}) ===\n")
        
        # Theoretical values
        theoretical_mean = math.sqrt(1 - beta_t) * x_input
        theoretical_var = beta_t * torch.ones_like(x_input)
        
        print(f"Theoretical mean: {theoretical_mean.mean().item():.4f}")
        print(f"Theoretical variance: {theoretical_var.mean().item():.4f}\n")
        
        # Test approaches if implemented
        try:
            direct_result = self.forward_step_sampling(x_input, beta_t)
            print(f"Direct sampling result shape: {direct_result.shape}")
            print(f"Direct sampling mean: {direct_result.mean().item():.4f}")
            # Residual around the theoretical mean should have variance close to beta_t
            print(f"Direct sampling residual variance: {(direct_result - theoretical_mean).var().item():.4f}")
        except Exception:
            print("Direct sampling not yet implemented")
        
        try:
            reparam_result, epsilon = self.forward_step_reparameterized(x_input, beta_t)
            print(f"Reparameterized result shape: {reparam_result.shape}")
            print(f"Reparameterized mean: {reparam_result.mean().item():.4f}")
            print(f"Reparameterized residual variance: {(reparam_result - theoretical_mean).var().item():.4f}")
            print(f"Epsilon shape: {epsilon.shape}")
        except Exception:
            print("Reparameterized approach not yet implemented")

# Test diffusion reparameterization
diffusion_reparam = DiffusionReparameterization()
diffusion_reparam.compare_approaches(test_image, beta_t=0.1)
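
A minimal standalone sketch of the reparameterized diffusion step, assuming a random tensor in [-1, 1] as a stand-in for an MNIST batch (`forward_step` is a hypothetical helper, not the method you are asked to fill in). It also checks the property `compare_approaches` looks for: the residual around the mean sqrt(1 - beta_t) * x_{t-1} equals sqrt(beta_t) * epsilon and therefore has variance beta_t.

```python
import math
import torch

def forward_step(x_prev: torch.Tensor, beta_t: float):
    """One reparameterized forward step: x_t = sqrt(1 - beta_t) * x_prev + sqrt(beta_t) * eps."""
    eps = torch.randn_like(x_prev)  # eps ~ N(0, I) with the same shape as x_prev
    x_t = math.sqrt(1.0 - beta_t) * x_prev + math.sqrt(beta_t) * eps
    return x_t, eps  # return eps too: it is the regression target during training

torch.manual_seed(0)
beta_t = 0.1
x_prev = torch.rand(512, 1, 28, 28) * 2 - 1  # random stand-in for a batch in [-1, 1]
x_t, eps = forward_step(x_prev, beta_t)

# The residual around the theoretical mean sqrt(1 - beta_t) * x_prev should have variance beta_t
residual = x_t - math.sqrt(1.0 - beta_t) * x_prev
print(f"empirical variance: {residual.var().item():.4f} (theory: {beta_t})")
```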

---

## Part 3: Implement and Compare Noise Schedules (25 minutes)

### Task 3.1: Build Different Noise Schedule Types

**Your Mission**: Implement the three main types of noise schedules and understand their mathematical properties.
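
The linear and exponential formulas in the docstrings of the cell below reduce to interpolation in linear space and in log space, respectively. A sketch as standalone functions (the names `linear_betas` and `exponential_betas` are hypothetical; they compute in float64 and return CPU float32 tensors, so the lab's methods would also need to move the result to `self.device`):

```python
import math
import torch

def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    # Index i = t - 1 runs from 0 to T - 1, so beta_1 = beta_start and beta_T = beta_end
    i = torch.arange(T, dtype=torch.float64)
    betas = beta_start + i / (T - 1) * (beta_end - beta_start)
    return betas.float()

def exponential_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    # Interpolate linearly in log space, then exponentiate (a geometric progression)
    i = torch.arange(T, dtype=torch.float64)
    log_betas = math.log(beta_start) + i / (T - 1) * math.log(beta_end / beta_start)
    return torch.exp(log_betas).float()

lin = linear_betas(1000)
expo = exponential_betas(1000)
print(f"linear:      [{lin[0].item():.6f}, {lin[-1].item():.6f}], beta_500 = {lin[499].item():.6f}")
print(f"exponential: [{expo[0].item():.6f}, {expo[-1].item():.6f}], beta_500 = {expo[499].item():.6f}")
```

The linear version matches `torch.linspace(beta_start, beta_end, T)`. Since a geometric interpolation never exceeds the arithmetic one between the same positive endpoints (weighted AM-GM), the exponential schedule adds less noise than the linear one at every intermediate step.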

In [None]:
class NoiseScheduler:
    """
    Implement mathematically rigorous noise schedules.
    This class generates the beta and alpha schedules that control
    how aggressively we corrupt data at each timestep.
    """
    
    def __init__(self, num_timesteps: int = 1000, device: str = 'cpu'):
        self.num_timesteps = num_timesteps
        self.device = device
        
        # Will store our schedules
        self.betas = None
        self.alphas = None
        self.alpha_cumprod = None
        self.sqrt_alpha_cumprod = None
        self.sqrt_one_minus_alpha_cumprod = None
        
    def linear_schedule(self, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
        """
        TODO: Implement linear noise schedule
        
        Formula: beta_t = beta_start + (t-1)/(T-1) * (beta_end - beta_start), for t = 1, ..., T
        
        This creates a linearly increasing schedule from beta_start to beta_end.
        
        Steps:
        1. Create indices i = t-1 from 0 to T-1
        2. Apply linear interpolation formula
        3. Ensure all betas are in valid range (0, 1)
        
        Args:
            beta_start: Starting noise level (should be very small)
            beta_end: Ending noise level (should be substantial but < 1)
            
        Returns:
            betas: Tensor of shape (num_timesteps,) with beta values
        """
        # TODO: Implement linear schedule
        # Hint: Use torch.arange and linear interpolation
        pass
    
    def cosine_schedule(self, s: float = 0.008) -> torch.Tensor:
        """
        TODO: Implement cosine noise schedule
        
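        This is the schedule from "Improved Denoising Diffusion Probabilistic Models"
        (Nichol & Dhariwal, 2021). Compared with the linear schedule, it corrupts data
        more gradually: alpha_cumprod falls roughly linearly through the middle of the
        process and changes slowly near t = 0 and t = T, so fewer late timesteps are
        spent on almost pure noise.
        
        The schedule is defined through alpha_cumprod rather than through beta directly: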
        1. Define f(t) = cos((t/T + s)/(1 + s) * π/2)^2
        2. alpha_cumprod_t = f(t) / f(0)
        3. alpha_t = alpha_cumprod_t / alpha_cumprod_{t-1}
        4. beta_t = 1 - alpha_t
        
        Steps:
        1. Create timestep array including 0 to T
        2. Compute f(t) values using cosine function
        3. Normalize by f(0)
        4. Compute alphas from cumulative products
        5. Convert to betas: beta_t = 1 - alpha_t
        6. Clip betas to at most 0.999 to avoid a singularity near t = T
        
        Args:
            s: Small offset that keeps beta_t from being too small near t = 0
            
        Returns:
            betas: Tensor of shape (num_timesteps,) with beta values
        """
        # TODO: Implement cosine schedule
        # This is more challenging: follow the four formula steps above carefully
        # Hint: Use torch.cos and handle the cumulative product carefully
        pass
    
    def exponential_schedule(self, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
        """
        TODO: Implement exponential noise schedule
        
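        This schedule grows geometrically (linearly in log space): beta_t stays near
        beta_start for longer than the linear schedule and rises steeply only near the
        end, reaching the same beta_end.
        
        Formula: beta_t = exp(log(beta_start) + i/(T-1) * log(beta_end/beta_start)), with i = t-1 = 0, ..., T-1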
        
        Steps:
        1. Create timestep indices
        2. Compute log-space interpolation
        3. Exponentiate to get betas
        
        Args:
            beta_start: Starting noise level
            beta_end: Ending noise level
            
        Returns:
            betas: Tensor of shape (num_timesteps,) with beta values
        """
        # TODO: Implement exponential schedule
        # Hint: Use torch.log, torch.exp, and log-space interpolation
        pass
    
    def precompute_schedule(self, schedule_type: str = "linear", **kwargs):
        """
        TODO: Precompute all derived quantities for efficient forward jumps
        
        Once we have betas, we need to compute all the derived quantities:
        - alphas = 1 - betas
        - alpha_cumprod = cumulative product of alphas
        - sqrt_alpha_cumprod = square root for signal coefficient
        - sqrt_one_minus_alpha_cumprod = square root for noise coefficient
        
        These precomputed values enable O(1) forward jumps.
        
        Steps:
        1. Get betas using specified schedule
        2. Compute alphas = 1 - betas
        3. Compute cumulative product: alpha_cumprod[t] = product of alphas[0:t+1]
        4. Compute square roots for forward jump formula
        5. Store all quantities for later use
        
        Args:
            schedule_type: "linear", "cosine", or "exponential"
            **kwargs: Additional parameters for specific schedules
        """
        # Get betas using specified schedule
        if schedule_type == "linear":
            self.betas = self.linear_schedule(**kwargs)
        elif schedule_type == "cosine":
            self.betas = self.cosine_schedule(**kwargs)
        elif schedule_type == "exponential":
            self.betas = self.exponential_schedule(**kwargs)
        else:
            raise ValueError(f"Unknown schedule type: {schedule_type}")
        
        # TODO: Compute derived quantities
        # Hint: alphas = 1 - betas, then use torch.cumprod for cumulative product
        pass
        
        if self.betas is not None:
            print(f"Precomputed {schedule_type} schedule:")
            print(f"  Beta range: [{self.betas.min():.6f}, {self.betas.max():.6f}]")
            if self.alpha_cumprod is not None:
                print(f"  Alpha cumprod range: [{self.alpha_cumprod.min():.6f}, {self.alpha_cumprod.max():.6f}]")

# Visualization function (provided)
def plot_noise_schedules(schedulers):
    """Complete plotting function for schedule comparison"""
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    # Plot 1: Beta schedules
    for stype, scheduler in schedulers.items():
        if scheduler.betas is not None:
            axes[0, 0].plot(scheduler.betas.cpu(), label=stype, linewidth=2)
    axes[0, 0].set_title('Beta Schedules')
    axes[0, 0].set_xlabel('Timestep')
    axes[0, 0].set_ylabel('Beta_t')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Plot 2: Alpha cumulative product
    for stype, scheduler in schedulers.items():
        if scheduler.alpha_cumprod is not None:
            axes[0, 1].plot(scheduler.alpha_cumprod.cpu(), label=stype, linewidth=2)
    axes[0, 1].set_title('Alpha Cumulative Product')
    axes[0, 1].set_xlabel('Timestep')
    axes[0, 1].set_ylabel('Alpha_cumprod')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Plot 3: Signal coefficient
    for stype, scheduler in schedulers.items():
        if scheduler.sqrt_alpha_cumprod is not None:
            axes[0, 2].plot(scheduler.sqrt_alpha_cumprod.cpu(), label=stype, linewidth=2)
    axes[0, 2].set_title('Signal Coefficient: √(alpha_cumprod)')
    axes[0, 2].set_xlabel('Timestep')
    axes[0, 2].set_ylabel('Signal Coefficient')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)
    
    # Plot 4: Noise coefficient
    for stype, scheduler in schedulers.items():
        if scheduler.sqrt_one_minus_alpha_cumprod is not None:
            axes[1, 0].plot(scheduler.sqrt_one_minus_alpha_cumprod.cpu(), label=stype, linewidth=2)
    axes[1, 0].set_title('Noise Coefficient: √(1-alpha_cumprod)')
    axes[1, 0].set_xlabel('Timestep')
    axes[1, 0].set_ylabel('Noise Coefficient')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Plot 5: Signal-to-noise ratio
    for stype, scheduler in schedulers.items():
        if scheduler.alpha_cumprod is not None:
            snr = scheduler.alpha_cumprod / (1 - scheduler.alpha_cumprod + 1e-8)
            axes[1, 1].plot(snr.cpu(), label=stype, linewidth=2)
    axes[1, 1].set_title('Signal-to-Noise Ratio')
    axes[1, 1].set_xlabel('Timestep')
    axes[1, 1].set_ylabel('SNR')
    axes[1, 1].set_yscale('log')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # Plot 6: Share of x_t's variance contributed by the signal
    for stype, scheduler in schedulers.items():
        if scheduler.alpha_cumprod is not None:
            # alpha_cumprod is the fraction of variance coming from x0 (for unit-variance data)
            signal_remaining = scheduler.alpha_cumprod.cpu() * 100
            axes[1, 2].plot(signal_remaining, label=f'{stype} (signal %)', linewidth=2)
    axes[1, 2].set_title('Signal Preservation Over Time')
    axes[1, 2].set_xlabel('Timestep')
    axes[1, 2].set_ylabel('% Original Signal')
    axes[1, 2].legend()
    axes[1, 2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Test all three schedules
scheduler = NoiseScheduler(num_timesteps=1000, device=device)

print("Testing different noise schedules...\n")
schedulers = {}

for schedule_type in ["linear", "cosine", "exponential"]:
    test_scheduler = NoiseScheduler(num_timesteps=1000, device=device)
    test_scheduler.precompute_schedule(schedule_type)
    schedulers[schedule_type] = test_scheduler
    print()

# Plot comparison (uncomment after implementing schedules)
# plot_noise_schedules(schedulers)
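
A sketch of the cosine construction from the `cosine_schedule` docstring, plus the derived quantities that `precompute_schedule` should store. It assumes the 0.999 clip used in the Improved DDPM paper; `cosine_betas` and `precompute` are hypothetical standalone helpers returning CPU tensors, so inside the lab class you would also move the results to `self.device`.

```python
import math
import torch

def cosine_betas(T: int, s: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
    # f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2), evaluated at t = 0, 1, ..., T
    t = torch.arange(T + 1, dtype=torch.float64)
    f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
    alpha_cumprod = f / f[0]  # normalize so that alpha_cumprod at t = 0 equals 1
    # beta_t = 1 - alpha_cumprod_t / alpha_cumprod_{t-1} for t = 1, ..., T
    betas = 1 - alpha_cumprod[1:] / alpha_cumprod[:-1]
    # Clip: at t = T the ratio is ~0, which would otherwise give beta_T ~ 1
    return betas.clamp(max=max_beta).float()

def precompute(betas: torch.Tensor) -> dict:
    # Derived quantities used by the forward jump formula
    alphas = 1 - betas
    alpha_cumprod = torch.cumprod(alphas, dim=0)  # alpha_cumprod[i] = alphas[0] * ... * alphas[i]
    return {
        "alphas": alphas,
        "alpha_cumprod": alpha_cumprod,
        "sqrt_alpha_cumprod": alpha_cumprod.sqrt(),
        "sqrt_one_minus_alpha_cumprod": (1 - alpha_cumprod).sqrt(),
    }

betas = cosine_betas(1000)
q = precompute(betas)
print(f"beta range: [{betas.min().item():.6f}, {betas.max().item():.6f}]")
print(f"alpha_cumprod at t=500: {q['alpha_cumprod'][499].item():.4f}")
```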

### Task 3.2: Analyze Schedule Differences

In [None]:
def analyze_noise_schedules(schedulers):
    """
    Comprehensive analysis of different noise schedules
    """
    print("=== Schedule Analysis ===")
    
    for stype, scheduler in schedulers.items():
        if scheduler.alpha_cumprod is not None:
            print(f"\n{stype.upper()} Schedule:")
            
            # Find key milestones
            alpha_cumprod = scheduler.alpha_cumprod.cpu()
            
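            # At what timestep does the signal amplitude sqrt(alpha_cumprod) drop below 50%?
            # (alpha_cumprod < 0.5 would duplicate the SNR < 1 milestone below.)
            signal_50_idx = torch.where(alpha_cumprod.sqrt() < 0.5)[0]
            if len(signal_50_idx) > 0:
                print(f"  Signal amplitude drops below 50% at timestep: {signal_50_idx[0].item() + 1}")
            
            # At what timestep is SNR = 1 (equal signal and noise variance, i.e. alpha_cumprod = 0.5)?
            snr = alpha_cumprod / (1 - alpha_cumprod)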
            snr_1_idx = torch.where(snr < 1.0)[0]
            if len(snr_1_idx) > 0:
                print(f"  SNR drops below 1.0 at timestep: {snr_1_idx[0].item() + 1}")
            
            # Rate of corruption (how quickly signal drops)
            signal_drop_rate = -torch.diff(alpha_cumprod).mean()
            print(f"  Average signal drop rate: {signal_drop_rate:.6f} per timestep")
            
            # Final corruption level
            final_signal = alpha_cumprod[-1]
            print(f"  Final signal level: {final_signal:.6f} ({final_signal*100:.3f}%)")

# Run analysis (uncomment after implementing schedules)
# analyze_noise_schedules(schedulers)
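
Because the SNR used here is alpha_cumprod_t / (1 - alpha_cumprod_t), the condition SNR_t < 1 is the same as alpha_cumprod_t < 0.5: the SNR = 1 crossing is exactly where signal and noise contribute equal variance to x_t. A quick check, assuming a linear schedule built with `torch.linspace` (the variables are illustrative). The SNR spans several orders of magnitude over the trajectory, which is why `plot_noise_schedules` puts it on a log scale.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)  # linear schedule, assumed for illustration
alpha_cumprod = torch.cumprod(1 - betas, dim=0)
snr = alpha_cumprod / (1 - alpha_cumprod)

# SNR < 1  <=>  alpha_cumprod < 1 - alpha_cumprod  <=>  alpha_cumprod < 0.5
half_variance = alpha_cumprod < 0.5
snr_below_one = snr < 1.0
first_t = int(torch.nonzero(snr_below_one)[0]) + 1  # +1 converts the 0-based index to timestep t
print(f"SNR first drops below 1 at timestep {first_t}")
print(f"log10 SNR: {snr[0].log10().item():.2f} at t=1, {snr[-1].log10().item():.2f} at t={T}")
```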

---

## Part 4: Implement and Derive Forward Jumps (25 minutes)

### Task 4.1: Sequential vs Direct Implementation

**Your Mission**: Implement both approaches and measure the dramatic computational difference.
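
A standalone sketch of both approaches, assuming a linear schedule and a random stand-in batch (`sequential_forward` and `direct_jump` are hypothetical names, not the methods you are asked to fill in). It uses the same indexing convention as `visualize_corruption_process`: timestep t (1-based) reads `alpha_cumprod[t - 1]`. Because each call draws its own noise, the two results are compared through their statistics rather than pixel by pixel.

```python
import math
import torch

def sequential_forward(x0: torch.Tensor, betas: torch.Tensor, t: int) -> torch.Tensor:
    # Apply t reparameterized steps; betas[i] holds beta for timestep i + 1
    x = x0.clone()
    for i in range(t):
        beta = betas[i].item()
        x = math.sqrt(1 - beta) * x + math.sqrt(beta) * torch.randn_like(x)
    return x

def direct_jump(x0: torch.Tensor, alpha_cumprod: torch.Tensor, t: int):
    # Single jump using alpha_cumprod[t - 1] (timestep t is 1-based, the array is 0-based)
    a_bar = alpha_cumprod[t - 1]
    eps = torch.randn_like(x0)
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps, eps

torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_cumprod = torch.cumprod(1 - betas, dim=0)
x0 = torch.rand(256, 1, 28, 28) * 2 - 1  # random stand-in batch in [-1, 1]
t = 300

x_seq = sequential_forward(x0, betas, t)      # t noise steps
x_dir, _ = direct_jump(x0, alpha_cumprod, t)  # a single jump

# Both residuals around sqrt(alpha_cumprod_t) * x0 should have variance 1 - alpha_cumprod_t
mean = alpha_cumprod[t - 1].sqrt() * x0
target = (1 - alpha_cumprod[t - 1]).item()
print(f"sequential: {(x_seq - mean).var().item():.4f}")
print(f"direct:     {(x_dir - mean).var().item():.4f}")
print(f"theory:     {target:.4f}")
```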

In [None]:
class ForwardJumpImplementation:
    """
    Implement both sequential and direct forward jump approaches.
    This class demonstrates the closed-form shortcut that makes it cheap
    to generate training examples at arbitrary timesteps.
    """
    
    def __init__(self, scheduler: NoiseScheduler):
        self.scheduler = scheduler
    
    def sequential_forward_process(self, x0: torch.Tensor, target_timestep: int) -> Tuple[torch.Tensor, list]:
        """
        TODO: Implement the SLOW way - sequential application
        
        This simulates the actual Markov chain by applying each step sequentially.
        This is how you would have to do it without the forward jump property.
        
        Steps:
        1. Start with x0
        2. For each timestep from 1 to target_timestep:
           a. Get beta_t for current timestep (stored at self.scheduler.betas[t - 1])
           b. Apply reparameterized step: x_t = sqrt(1-beta_t) * x_{t-1} + sqrt(beta_t) * eps
           c. Store intermediate result
        3. Return final result and all intermediate states
        
        Args:
            x0: Clean starting image
            target_timestep: How many steps to apply
            
        Returns:
            x_final: Final corrupted image after target_timestep steps
            intermediates: List of all intermediate states (for visualization)
        """
        # TODO: Implement sequential process
        # Hint: Use a loop from 1 to target_timestep, apply reparameterized steps
        current = x0.clone()
        intermediates = [current.clone()]
        
        # TODO: Implement the sequential loop here
        
        return current, intermediates
    
    def direct_forward_jump(self, x0: torch.Tensor, target_timestep: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        TODO: Implement the FAST way - direct jump using precomputed coefficients
        
        This uses the mathematical property that multiple Gaussian steps
        can be collapsed into a single equivalent step.
        
        Formula: x_t = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * epsilon
        
        Steps:
        1. Get precomputed coefficients for target timestep
        2. Sample fresh noise
        3. Apply single transformation
        4. Return result and noise used
        
        Args:
            x0: Clean starting image  
            target_timestep: Target corruption level
            
        Returns:
            x_t: Corrupted image at target timestep
            epsilon: The noise that was added
        """
        # TODO: Implement direct forward jump
        if target_timestep == 0:
            return x0, torch.zeros_like(x0)
        
        # TODO: Get precomputed coefficients and apply the forward jump formula
        # Hint: timestep t (1-based) reads index t - 1 of self.scheduler.sqrt_alpha_cumprod
        # and self.scheduler.sqrt_one_minus_alpha_cumprod
        pass

# Benchmarking function (provided)
def benchmark_forward_approaches(forward_impl, x0, timesteps_to_test):
    """Complete benchmarking with timing and visualization"""
    print("=== Performance Benchmark ===\n")
    
    results = {'sequential': [], 'direct': [], 'speedups': [], 'timesteps': []}
    
    for target_t in timesteps_to_test:
        if target_t > forward_impl.scheduler.num_timesteps:
            continue
            
        print(f"Testing timestep {target_t}:")
        
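        # Time sequential approach (perf_counter plus CUDA sync so GPU work is counted)
        if x0.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        try:
            x_seq, _ = forward_impl.sequential_forward_process(x0, target_t)
            if x0.is_cuda:
                torch.cuda.synchronize()
            seq_time = time.perf_counter() - start_time
        except Exception:
            print("  Sequential approach not implemented")
            continue
        
        # Time direct approach
        start_time = time.perf_counter()
        try:
            x_direct, _ = forward_impl.direct_forward_jump(x0, target_t)
            if x0.is_cuda:
                torch.cuda.synchronize()
            direct_time = time.perf_counter() - start_time
        except Exception:
            print("  Direct approach not implemented")
            continue
        
        # Compare results. Each call draws independent noise, so even correct
        # implementations give an MSE near 2 * (1 - alpha_cumprod_t), not 0.
        mse = torch.mean((x_seq - x_direct) ** 2).item()
        speedup = seq_time / direct_time if direct_time > 0 else float('inf')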
        
        results['sequential'].append(seq_time)
        results['direct'].append(direct_time)
        results['speedups'].append(speedup)
        results['timesteps'].append(target_t)
        
        print(f"  Sequential time: {seq_time:.4f}s")
        print(f"  Direct time: {direct_time:.4f}s") 
        print(f"  Speedup: {speedup:.1f}x")
        print(f"  MSE difference: {mse:.6f}")
        print()
    
    if results['timesteps']:
        # Plot results
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        ax1.plot(results['timesteps'], results['sequential'], 'r-o', label='Sequential', linewidth=2)
        ax1.plot(results['timesteps'], results['direct'], 'b-o', label='Direct', linewidth=2)
        ax1.set_xlabel('Timestep')
        ax1.set_ylabel('Time (seconds)')
        ax1.set_title('Computation Time Comparison')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        ax1.set_yscale('log')
        
        ax2.plot(results['timesteps'], results['speedups'], 'g-o', linewidth=2, markersize=8)
        ax2.set_xlabel('Timestep')
        ax2.set_ylabel('Speedup Factor')
        ax2.set_title('Direct vs Sequential Speedup')
        ax2.grid(True, alpha=0.3)
        ax2.axhline(y=1, color='r', linestyle='--', alpha=0.5, label='No speedup')
        ax2.legend()
        
        plt.tight_layout()
        plt.show()
        
        print(f"Average speedup: {np.mean(results['speedups']):.1f}x")
        print(f"Maximum speedup: {max(results['speedups']):.1f}x")
    
    return results

# Visualization function for corruption process (provided)
def visualize_corruption_process(forward_impl, x0, timesteps_to_show):
    """Visualize corruption at different timesteps"""
    fig, axes = plt.subplots(2, len(timesteps_to_show), figsize=(3*len(timesteps_to_show), 6))
    
    if len(timesteps_to_show) == 1:
        axes = axes.reshape(-1, 1)
    
    for i, t in enumerate(timesteps_to_show):
        if t == 0:
            corrupted = x0
            noise_level = 0.0
        else:
            try:
                corrupted, epsilon = forward_impl.direct_forward_jump(x0, t)
                if forward_impl.scheduler.alpha_cumprod is not None:
                    signal_level = forward_impl.scheduler.alpha_cumprod[t-1].item()
                    noise_level = 1 - signal_level
                else:
                    noise_level = t / 1000  # Rough estimate
            except:
                corrupted = x0
                noise_level = 0.0
        
        # Display original and corrupted
        axes[0, i].imshow(x0[0, 0].cpu(), cmap='gray')
        axes[0, i].set_title('Original')
        axes[0, i].axis('off')
        
        axes[1, i].imshow(corrupted[0, 0].cpu(), cmap='gray')
        axes[1, i].set_title(f't={t}\nNoise: {noise_level:.2f}')
        axes[1, i].axis('off')
    
    plt.tight_layout()
    plt.show()

# Test forward jump implementation (use scheduler from previous section)
# First ensure a scheduler is set up
if 'scheduler' not in locals() or scheduler.betas is None:
    scheduler = NoiseScheduler(num_timesteps=1000, device=device)
    scheduler.precompute_schedule("linear")

forward_jump = ForwardJumpImplementation(scheduler)

# Visualize corruption process (uncomment after implementing direct_forward_jump)
timesteps_to_show = [0, 100, 300, 500, 700, 999]
# visualize_corruption_process(forward_jump, test_image, timesteps_to_show)

# Benchmark performance (uncomment after implementing both approaches)
timesteps_to_test = [10, 50, 100, 500, 1000]
# results = benchmark_forward_approaches(forward_jump, test_image, timesteps_to_test)

### Task 4.2: Mathematical Derivation Implementation
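The inductive step that `general_derivation_proof` asks for can be written out in the same notation as the docstrings, with $\alpha_t = 1 - \beta_t$ and $\bar\alpha_n = \prod_{s=1}^{n} \alpha_s$ (`alpha_cumprod`). Assume $x_n = \sqrt{\bar\alpha_n}\,x_0 + \sqrt{1-\bar\alpha_n}\,\epsilon$ and apply one more single step:

$$
\begin{aligned}
x_{n+1} &= \sqrt{\alpha_{n+1}}\,x_n + \sqrt{1-\alpha_{n+1}}\,\epsilon_{n+1} \\
&= \sqrt{\alpha_{n+1}\bar\alpha_n}\,x_0 + \sqrt{\alpha_{n+1}(1-\bar\alpha_n)}\,\epsilon + \sqrt{1-\alpha_{n+1}}\,\epsilon_{n+1}.
\end{aligned}
$$

Because $\epsilon$ and $\epsilon_{n+1}$ are independent standard normals, the two noise terms combine into one Gaussian whose variance is

$$
\alpha_{n+1}(1-\bar\alpha_n) + (1-\alpha_{n+1}) = 1 - \alpha_{n+1}\bar\alpha_n = 1 - \bar\alpha_{n+1},
$$

so, in distribution, $x_{n+1} = \sqrt{\bar\alpha_{n+1}}\,x_0 + \sqrt{1-\bar\alpha_{n+1}}\,\bar\epsilon$ with $\bar\epsilon \sim \mathcal{N}(0, I)$. For $n = 1$ the noise variance is $\alpha_2\beta_1 + \beta_2$, the same quantity as `combined_noise_var` in `two_step_expansion`.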

In [None]:
class ForwardJumpDerivation:
    """
    Implement step-by-step mathematical derivation.
    This class helps you understand exactly how the forward jump formula
    emerges from Gaussian arithmetic.
    """
    
    def two_step_expansion(self, x0: torch.Tensor, beta1: float, beta2: float) -> dict:
        """
        TODO: Manually expand the first two steps to see the pattern
        
        Starting from:
        x1 = sqrt(1-beta1) * x0 + sqrt(beta1) * eps0
        x2 = sqrt(1-beta2) * x1 + sqrt(beta2) * eps1
        
        Expand x2 by substituting x1, then use Gaussian arithmetic to
        combine the noise terms.
        
        Steps:
        1. Substitute x1 into x2 equation
        2. Expand and collect terms
        3. Identify signal coefficient for x0
        4. Identify combined noise coefficient
        5. Use Gaussian arithmetic: for independent eps1, eps2 ~ N(0, I), sqrt(a)*eps1 + sqrt(b)*eps2 has the same distribution as sqrt(a+b)*eps
        
        Args:
            x0: Starting clean image
            beta1: First timestep noise level
            beta2: Second timestep noise level
            
        Returns:
            dict: Contains analytical and computed results for comparison
        """
        print(f"=== Two-Step Mathematical Derivation ===")
        print(f"beta1 = {beta1:.4f}, beta2 = {beta2:.4f}\n")
        
        # TODO: Step 1 - Compute x1 using the single-step formula
        # Hint: x1 = sqrt(1-beta1) * x0 + sqrt(beta1) * eps0
        
        # TODO: Step 2 - Substitute x1 into the x2 equation and expand
        # Hint: Replace x1 in the x2 formula and collect terms
        
        # TODO: Step 3 - Apply Gaussian arithmetic to combine noise terms
        # Hint: Independent Gaussians add their variances
        
        # TODO: Step 4 - Verify your result matches the theoretical formula
        # Hint: Should equal sqrt(alpha1*alpha2) * x0 + sqrt(1-alpha1*alpha2) * eps
        
        # Provided: Show the mathematical steps for reference
        alpha1 = 1 - beta1
        alpha2 = 1 - beta2
        
        print(f"Mathematical structure:")
        print(f"  x1 = sqrt({alpha1:.4f}) * x0 + sqrt({beta1:.4f}) * eps0")
        print(f"  x2 = sqrt({alpha2:.4f}) * x1 + sqrt({beta2:.4f}) * eps1")
        print(f"  After substitution, identify the coefficients...")
        
        combined_noise_var = alpha2 * beta1 + beta2
        theoretical_noise_var = 1 - alpha1 * alpha2
        
        print(f"\nGaussian arithmetic:")
        print(f"  Combined noise variance = {alpha2*beta1:.4f} + {beta2:.4f} = {combined_noise_var:.4f}")
        print(f"  Theoretical variance = 1 - {alpha1*alpha2:.4f} = {theoretical_noise_var:.4f}")
        print(f"  Match: {abs(combined_noise_var - theoretical_noise_var) < 1e-10}")
        
        # TODO: Implement your derivation and return results
        results = {
            'signal_coeff_expanded': None,    # TODO: Compute from your expansion
            'signal_coeff_direct': None,      # TODO: Compute from direct formula
            'noise_var_expanded': None,       # TODO: Compute combined noise variance
            'noise_var_theoretical': None,    # TODO: Compute theoretical variance
            'coefficients_match': None,       # TODO: Check if they match
        }
        
        return results
    
    def general_derivation_proof(self, num_steps: int = 5):
        """
        TODO: Prove the general case by induction
        
        Show that if the formula holds for n steps, it also holds for n+1 steps.
        This proves the forward jump formula for any number of timesteps.
        
        Base case: x1 = sqrt(alpha1) * x0 + sqrt(1-alpha1) * eps
        Inductive step: If x_n = sqrt(alpha_cumprod_n) * x0 + sqrt(1-alpha_cumprod_n) * eps
                       Then x_{n+1} = sqrt(alpha_cumprod_{n+1}) * x0 + sqrt(1-alpha_cumprod_{n+1}) * eps
        
        Args:
            num_steps: Number of steps to verify the pattern (at most 5, the length of the example beta list)
        """
        print(f"=== General Derivation Proof ===\n")
        
        betas = torch.tensor([0.01, 0.02, 0.03, 0.04, 0.05])[:num_steps]
        
        for n in range(1, num_steps + 1):
            print(f"Step {n}:")
            
            # TODO: Compute alpha_cumprod for n steps and verify properties
            # Hint: Use torch.prod for cumulative product
            
            # Show what the pattern should be
            alphas = 1 - betas[:n]
            alpha_cumprod_n = torch.prod(alphas)
            
            signal_coeff_sq = alpha_cumprod_n
            noise_coeff_sq = 1 - alpha_cumprod_n
            total_variance = signal_coeff_sq + noise_coeff_sq
            
            print(f"  Alpha cumprod: {alpha_cumprod_n:.6f}")
            print(f"  Signal coeff²: {signal_coeff_sq:.6f}")
            print(f"  Noise coeff²: {noise_coeff_sq:.6f}")
            print(f"  Total variance: {total_variance:.6f} (should be 1.0)")
            print()

# Test mathematical derivation
derivation = ForwardJumpDerivation()

# Test two-step case
results = derivation.two_step_expansion(test_image, beta1=0.01, beta2=0.02)
print(f"\nResults summary:")
for key, value in results.items():
    print(f"  {key}: {value}")

# Test general case
derivation.general_derivation_proof(num_steps=5)
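
`general_derivation_proof` prints signal² + noise², which equals 1 by construction, so it cannot catch a mistake in the derivation. A stronger check, sketched below with numpy, tracks the noise variance through the one-step recursion (old noise scaled by alpha_n, plus fresh noise of variance beta_n) and compares it with the forward-jump prediction 1 - alpha_cumprod, then runs the five sequential steps by Monte Carlo on a fixed scalar x0. The helper names are hypothetical and the betas are the same toy values used above.

```python
import numpy as np


def noise_variance_recursive(betas):
    """Noise variance of x_n via v_n = alpha_n * v_{n-1} + beta_n, starting from v_0 = 0."""
    v = 0.0
    history = []
    for beta in betas:
        v = (1.0 - beta) * v + beta  # scale old noise by alpha_n, add fresh noise of variance beta_n
        history.append(v)
    return np.array(history)


def noise_variance_closed_form(betas):
    """Forward-jump prediction 1 - alpha_cumprod_n for every n."""
    return 1.0 - np.cumprod(1.0 - np.asarray(betas))


betas = [0.01, 0.02, 0.03, 0.04, 0.05]  # same toy values as general_derivation_proof
recursive = noise_variance_recursive(betas)
closed = noise_variance_closed_form(betas)
print("max |recursive - closed form|:", np.max(np.abs(recursive - closed)))

# Monte Carlo: apply the five single steps to a fixed scalar x0 and compare moments
rng = np.random.default_rng(0)
x0 = 1.5
n_samples = 200_000
x = np.full(n_samples, x0)
for beta in betas:
    x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * rng.standard_normal(n_samples)

alpha_cumprod = np.prod(1.0 - np.asarray(betas))
print("mean:", x.mean(), "vs", np.sqrt(alpha_cumprod) * x0)
print("var: ", x.var(), "vs", 1.0 - alpha_cumprod)
```

Agreement of both moments supports the claim that the sequential chain and the single jump produce the same distribution, not merely the same coefficients.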

---

## Part 5: Efficient Training Data Generation (15 minutes)

### Task 5.1: Build Production-Ready Training Pipeline

In [None]:
class DiffusionTrainingDataGenerator:
    """
    Implement efficient training data generation pipeline.
    This is the core component that enables practical diffusion model training.
    It must be fast, memory-efficient, and mathematically correct.
    """
    
    def __init__(self, scheduler: NoiseScheduler):
        self.scheduler = scheduler
        
    def generate_training_sample(self, x0: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        TODO: Generate a single training sample using random timestep sampling
        
        This is the core training data generation function: every call draws a
        fresh timestep and noise sample, so each clean image yields an
        effectively unlimited supply of (x_t, t, epsilon) training triples.
        
        Steps:
        1. Sample random timestep t from [1, T]
        2. Sample random noise epsilon ~ N(0, I)
        3. Apply forward jump: x_t = sqrt(alpha_cumprod_t) * x0 + sqrt(1-alpha_cumprod_t) * epsilon
        4. Return (x_t, t, epsilon) as training triple
        
        The neural network will learn to predict epsilon given (x_t, t).
        
        Args:
            x0: Clean image(s) - shape (batch_size, channels, height, width)
            
        Returns:
            x_t: Noisy image at random timestep
            t: The timestep that was sampled
            epsilon: The noise that was added (training target)
        """
        batch_size = x0.shape[0]
        
        # TODO: Sample random timesteps for each image in batch
        # Hint: Use torch.randint(1, self.scheduler.num_timesteps + 1, (batch_size,), device=x0.device)
        
        # TODO: Sample noise
        # Hint: Use torch.randn_like(x0)
        
        # TODO: Get coefficients for sampled timesteps (convert to 0-indexed)
        # Hint: Use self.scheduler.sqrt_alpha_cumprod and self.scheduler.sqrt_one_minus_alpha_cumprod
        
        # TODO: Reshape coefficients for broadcasting with image dimensions
        # Hint: Use .view(-1, 1, 1, 1) for 4D tensors
        
        # TODO: Apply forward jump formula
        # Hint: x_t = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * epsilon
        
        # return x_t, t, epsilon
        pass

# Analysis functions (provided)
def analyze_training_distribution(generator, test_image, num_samples=5000):
    """Analyze the distribution of training data"""
    print(f"=== Training Data Distribution Analysis ===\n")
    
    all_timesteps = []
    all_snrs = []
    
    with torch.no_grad():
        for _ in range(num_samples // 100):  # Generate in batches
            try:
                batch_x0 = test_image.repeat(100, 1, 1, 1)
                x_t, t, epsilon = generator.generate_training_sample(batch_x0)
                
                all_timesteps.extend(t.cpu().tolist())
                
                # Compute SNR for each sample
                for i in range(len(t)):
                    t_idx = t[i] - 1
                    if generator.scheduler.alpha_cumprod is not None:
                        alpha_cumprod = generator.scheduler.alpha_cumprod[t_idx]
                        snr = alpha_cumprod / (1 - alpha_cumprod)
                        all_snrs.append(snr.item())
                    
            except Exception as e:
                print(f"Training sample generation not implemented: {e}")
                return
    
    if all_timesteps:
        # Plot timestep distribution
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        axes[0].hist(all_timesteps, bins=50, alpha=0.7, density=True, color='blue')
        axes[0].set_title('Timestep Distribution\n(Should be uniform)')
        axes[0].set_xlabel('Timestep')
        axes[0].set_ylabel('Density')
        axes[0].grid(True, alpha=0.3)
        
        if all_snrs:
            axes[1].hist(all_snrs, bins=50, alpha=0.7, density=True, color='green')
            axes[1].set_title('Signal-to-Noise Ratio Distribution')
            axes[1].set_xlabel('SNR')
            axes[1].set_ylabel('Density')
            axes[1].set_yscale('log')
            axes[1].grid(True, alpha=0.3)
            
            # SNR vs timestep scatter plot
            timestep_snr_pairs = list(zip(all_timesteps[:1000], all_snrs[:1000]))
            timesteps_plot, snrs_plot = zip(*timestep_snr_pairs)
            axes[2].scatter(timesteps_plot, snrs_plot, alpha=0.5, s=1, color='red')
            axes[2].set_title('SNR vs Timestep')
            axes[2].set_xlabel('Timestep')
            axes[2].set_ylabel('SNR')
            axes[2].set_yscale('log')
            axes[2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print statistics
        print(f"Timestep statistics:")
        print(f"  Mean: {np.mean(all_timesteps):.2f} (should be ~{generator.scheduler.num_timesteps/2})")
        print(f"  Std: {np.std(all_timesteps):.2f}")
        print(f"  Range: [{min(all_timesteps)}, {max(all_timesteps)}]")
        
        if all_snrs:
            print(f"\nSNR statistics:")
            print(f"  Mean: {np.mean(all_snrs):.4f}")
            print(f"  Range: [{min(all_snrs):.6f}, {max(all_snrs):.2f}]")

def benchmark_generation_speed(generator, test_image, batch_sizes=[1, 32, 128, 512]):
    """Benchmark training data generation speed"""
    print("=== Training Data Generation Benchmark ===\n")
    
    for batch_size in batch_sizes:
        # Create test batch
        x0_batch = test_image.repeat(batch_size, 1, 1, 1)
        
        try:
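            # Warmup (runs before timing so one-time CUDA setup is not measured)
            for _ in range(5):
                generator.generate_training_sample(x0_batch)
            
            # Benchmark. CUDA kernels launch asynchronously, so synchronize
            # before reading the clock; otherwise only launch time is measured.
            num_iterations = 100
            if x0_batch.is_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            
            with torch.no_grad():
                for _ in range(num_iterations):
                    x_t, t, epsilon = generator.generate_training_sample(x0_batch)
            
            if x0_batch.is_cuda:
                torch.cuda.synchronize()
            end_time = time.perf_counter()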
            
            total_samples = batch_size * num_iterations
            samples_per_second = total_samples / (end_time - start_time)
            
            print(f"Batch size {batch_size:3d}: {samples_per_second:8.1f} samples/sec")
            
        except Exception as e:
            print(f"Batch size {batch_size:3d}: Not implemented ({e})")

# Test training data generation
if 'scheduler' not in locals() or scheduler.alpha_cumprod is None:
    scheduler = NoiseScheduler(num_timesteps=1000, device=device)
    scheduler.precompute_schedule("linear")

generator = DiffusionTrainingDataGenerator(scheduler)

# Test single sample generation (prints a notice until generate_training_sample is implemented)
try:
    x_t, t, epsilon = generator.generate_training_sample(test_image)
    print(f"Generated training sample:")
    print(f"  Input shape: {test_image.shape}")
    print(f"  Output x_t shape: {x_t.shape}")
    print(f"  Timestep: {t.item()}")
    print(f"  Epsilon shape: {epsilon.shape}")
except Exception:
    print("Training sample generation not yet implemented")

# Analyze training distribution (uncomment after implementing)
# analyze_training_distribution(generator, test_image, num_samples=5000)

# Benchmark speed (uncomment after implementing)
# benchmark_generation_speed(generator, test_image)
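
The docstring steps for `generate_training_sample` reduce to a gather followed by a broadcast. The numpy sketch below walks through those shapes on a toy batch. It is an illustration under assumptions, not the lab's `NoiseScheduler`: numpy stands in for torch (`rng.integers`, `rng.standard_normal`, and `reshape` play the roles of `torch.randint`, `torch.randn_like`, and `.view(-1, 1, 1, 1)`), and the helper names and the linear schedule from 1e-4 to 0.02 are hypothetical choices.

```python
import numpy as np


def make_linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: linear betas plus the two square-root lookup tables."""
    betas = np.linspace(beta_start, beta_end, T)
    alpha_cumprod = np.cumprod(1.0 - betas)
    return np.sqrt(alpha_cumprod), np.sqrt(1.0 - alpha_cumprod)


def sample_training_triple(x0, sqrt_ac, sqrt_1m_ac, rng):
    """Return (x_t, t, epsilon) for a batch x0 of shape (B, C, H, W); t is 1-indexed."""
    T = len(sqrt_ac)
    batch_size = x0.shape[0]
    t = rng.integers(1, T + 1, size=batch_size)  # one timestep per image, uniform on [1, T]
    epsilon = rng.standard_normal(x0.shape)       # the training target
    idx = t - 1                                   # lookup tables are 0-indexed
    a = sqrt_ac[idx].reshape(-1, 1, 1, 1)         # (B,) -> (B, 1, 1, 1) to broadcast over C, H, W
    b = sqrt_1m_ac[idx].reshape(-1, 1, 1, 1)
    x_t = a * x0 + b * epsilon
    return x_t, t, epsilon


rng = np.random.default_rng(0)
sqrt_ac, sqrt_1m_ac = make_linear_schedule()
x0 = rng.standard_normal((8, 1, 28, 28))
x_t, t, eps = sample_training_triple(x0, sqrt_ac, sqrt_1m_ac, rng)
print(x_t.shape, t)
```

Each image in the batch gets its own timestep, which is why the coefficients must be reshaped per sample rather than applied as a single scalar.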

---

## Part 6: Comprehensive Testing and Validation (10 minutes)

### Task 6.1: Validate Mathematical Correctness

In [None]:
def comprehensive_validation_suite():
    """
    Comprehensive validation of all components
    
    This function tests that all mathematical implementations are correct
    and match theoretical expectations.
    """
    print("=== Comprehensive Validation Suite ===\n")
    
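    # Setup for testing: reuse the notebook's scheduler if it is ready.
    # Inside a function, locals() cannot see notebook variables, so look in globals().
    existing_scheduler = globals().get('scheduler')
    if existing_scheduler is not None and existing_scheduler.alpha_cumprod is not None:
        test_scheduler = existing_scheduler
    else:
        test_scheduler = NoiseScheduler(num_timesteps=1000, device=device)
        test_scheduler.precompute_schedule("linear")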
    
    # Test 1: Variance preservation
    print("Test 1: Variance Preservation")
    x0 = torch.randn(100, 1, 28, 28, device=device)  # Batch of random images
    
    forward_impl = ForwardJumpImplementation(test_scheduler)
    
    for t in [1, 100, 500, 999]:
        try:
            x_t, epsilon = forward_impl.direct_forward_jump(x0, t)
            
            # x0 ~ N(0, I) has unit variance and epsilon is independent of x0,
            # so Var(x_t) = alpha_cumprod + (1 - alpha_cumprod) = 1 as well
            original_var = torch.var(x0)
            corrupted_var = torch.var(x_t)
            
            print(f"  Timestep {t}: Original var = {original_var:.4f}, Corrupted var = {corrupted_var:.4f}")
            if abs(original_var - corrupted_var) > 0.1:
                print(f"    ⚠️ Warning: Variance not well preserved!")
            else:
                print(f"    ✓ Variance preserved")
        except Exception:
            print(f"  Timestep {t}: Forward jump not implemented")
    
    print()
    
    # Test 2: Boundary conditions
    print("Test 2: Boundary Conditions")
    
    try:
        # Test t=0 case
        x_0, eps_0 = forward_impl.direct_forward_jump(x0, 0)
        if torch.allclose(x_0, x0):
            print("  ✓ t=0 returns original image")
        else:
            print("  ❌ t=0 should return original image")
    except Exception:
        print("  t=0 test: Not implemented")
    
    try:
        # Test t=999 (near T=1000): x_t should be almost pure noise.
        # For unit-variance x0, corr(x0, x_t) = sqrt(alpha_cumprod_t), which is small near T.
        x_T, eps_T = forward_impl.direct_forward_jump(x0, 999)
        correlation = torch.corrcoef(torch.stack([x0.flatten(), x_T.flatten()]))[0, 1]
        if abs(correlation) < 0.1:
            print(f"  ✓ t≈T produces nearly uncorrelated noise (correlation: {correlation:.4f})")
        else:
            print(f"  ⚠️ t≈T correlation with original: {correlation:.4f} (should be ~0)")
    except Exception:
        print("  t=T test: Not implemented")
    
    print()
    
    # Test 3: Training data generation consistency
    print("Test 3: Training Data Generation")
    
    try:
        generator = DiffusionTrainingDataGenerator(test_scheduler)
        timesteps = []
        
        for _ in range(1000):
            x_t, t, eps = generator.generate_training_sample(test_image)
            timesteps.append(t.item())
        
        # Check uniform distribution of timesteps
        timestep_std = np.std(timesteps)
        expected_std = math.sqrt((test_scheduler.num_timesteps**2 - 1) / 12)  # Uniform distribution std
        
        print(f"  Timestep std: {timestep_std:.2f}, Expected: {expected_std:.2f}")
        if abs(timestep_std - expected_std) < 50:
            print("  ✓ Timestep distribution appears uniform")
        else:
            print("  ⚠️ Timestep distribution may not be uniform")
            
    except Exception as e:
        print(f"  Training data generation test: Not implemented ({e})")
    
    print()
    
    # Test 4: Schedule properties
    print("Test 4: Schedule Properties")
    
    if test_scheduler.betas is not None:
        # Test that betas are in valid range
        if torch.all(test_scheduler.betas > 0) and torch.all(test_scheduler.betas < 1):
            print("  ✓ All betas in valid range (0, 1)")
        else:
            print("  ❌ Some betas outside valid range")
        
        # Test that alpha_cumprod is decreasing
        if test_scheduler.alpha_cumprod is not None:
            alpha_diffs = test_scheduler.alpha_cumprod[1:] - test_scheduler.alpha_cumprod[:-1]
            if torch.all(alpha_diffs <= 0):
                print("  ✓ Alpha cumprod is non-increasing")
            else:
                print("  ❌ Alpha cumprod should be non-increasing")
        else:
            print("  Alpha cumprod not computed")
    else:
        print("  Schedule not implemented")
    
    print()
    
    # Test 5: Reparameterization equivalence
    print("Test 5: Reparameterization Equivalence")
    
    try:
        reparam_demo = ReparameterizationDemo()
        mu = torch.zeros(1, 1, device=device)
        sigma = torch.ones(1, 1, device=device) * 0.5
        
        # Test multiple samples for statistical equivalence
        samples_direct = []
        samples_reparam = []
        
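        # 1000 samples per approach: with sigma = 0.5 the standard error of the
        # mean difference is about 0.02, so the 0.1 tolerance below is not flaky
        for _ in range(1000):
            try:
                sample_d = reparam_demo.direct_sampling_approach(mu, sigma)
                samples_direct.append(sample_d.item())
            except Exception:
                pass
                
            try:
                sample_r = reparam_demo.reparameterized_approach(mu, sigma)
                samples_reparam.append(sample_r.item())
            except Exception:
                pass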
        
        if samples_direct and samples_reparam:
            mean_diff = abs(np.mean(samples_direct) - np.mean(samples_reparam))
            std_diff = abs(np.std(samples_direct) - np.std(samples_reparam))
            
            print(f"  Mean difference: {mean_diff:.4f}")
            print(f"  Std difference: {std_diff:.4f}")
            
            if mean_diff < 0.1 and std_diff < 0.1:
                print("  ✓ Both approaches produce similar statistics")
            else:
                print("  ⚠️ Approaches may not be equivalent")
        else:
            print("  Reparameterization approaches not implemented")
            
    except Exception as e:
        print(f"  Reparameterization test failed: {e}")
    
    print("\n🎉 Validation suite completed!")

# Run comprehensive validation
comprehensive_validation_suite()
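
Tests 1 and 2 rest on two closed-form facts for a unit-variance x0 that is independent of epsilon: Var(x_t) = alpha_cumprod + (1 - alpha_cumprod) = 1, and Corr(x0, x_t) = sqrt(alpha_cumprod). This is also why Test 1 uses `torch.randn` rather than MNIST pixels: the identity only yields 1 when Var(x0) = 1. The numpy sketch below checks both facts empirically; the helper name is hypothetical, and 4e-5 is used as a rough stand-in for the final alpha_cumprod of a linear schedule from 1e-4 to 0.02 over 1000 steps (an assumed setting).

```python
import numpy as np


def corrupt_and_measure(alpha_cumprod, n=400_000, seed=0):
    """Return (corr(x0, x_t), var(x_t)) for x0, eps ~ N(0, 1) and the forward jump."""
    rng = np.random.default_rng(seed)
    x0 = rng.standard_normal(n)
    eps = rng.standard_normal(n)  # independent of x0
    x_t = np.sqrt(alpha_cumprod) * x0 + np.sqrt(1.0 - alpha_cumprod) * eps
    return np.corrcoef(x0, x_t)[0, 1], x_t.var()


for a in [0.99, 0.5, 0.01, 4e-5]:
    corr, var = corrupt_and_measure(a)
    print(f"alpha_cumprod={a:g}: corr={corr:.4f} (theory {np.sqrt(a):.4f}), "
          f"var(x_t)={var:.4f} (theory 1)")
```

Under this assumed schedule, the theoretical correlation near T is well below the 0.1 threshold in Test 2, so that check has plenty of margin.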

---

## Part 7: Numerical Stability Analysis (5 minutes)

### Task 7.1: Handle Edge Cases and Numerical Issues

In [None]:
class NumericalStabilityAnalysis:
    """
    Implement numerical stability checks and fixes.
    Real diffusion implementations must handle numerical edge cases
    that can break training or inference.
    """
    
    def __init__(self, scheduler: NoiseScheduler):
        self.scheduler = scheduler
    
    def analyze_extreme_timesteps(self):
        """
        Analyze behavior at extreme timesteps
        
        Check what happens at the boundaries:
        - t = 0: Should preserve original image exactly
        - t = T: Should produce (nearly) pure noise; alpha_cumprod_T is small but usually nonzero
        - Very early timesteps: Barely perceptible corruption
        - Very late timesteps: Nearly pure noise
        """
        print("=== Extreme Timestep Analysis ===\n")
        
        extreme_timesteps = [0, 1, 2, 998, 999, 1000]
        
        for t in extreme_timesteps:
            if t == 0:
                # Special case: no corruption
                signal_coeff = 1.0
                noise_coeff = 0.0
                alpha_cumprod = 1.0
            elif t > self.scheduler.num_timesteps:
                print(f"Timestep {t}: Out of range")
                continue
            else:
                if self.scheduler.alpha_cumprod is not None:
                    t_idx = t - 1
                    alpha_cumprod = self.scheduler.alpha_cumprod[t_idx].item()
                    signal_coeff = math.sqrt(alpha_cumprod)
                    noise_coeff = math.sqrt(1 - alpha_cumprod)
                else:
                    print(f"Timestep {t}: Schedule not computed")
                    continue
            
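            # Compute SNR = alpha_cumprod / (1 - alpha_cumprod), the ratio of signal
            # variance to noise variance (same definition as analyze_training_distribution)
            if noise_coeff > 0:
                snr = alpha_cumprod / (1 - alpha_cumprod)
            else:
                snr = float('inf')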
            
            print(f"Timestep {t:4d}:")
            print(f"  Alpha cumprod: {alpha_cumprod:.8f}")
            print(f"  Signal coeff:  {signal_coeff:.8f}")
            print(f"  Noise coeff:   {noise_coeff:.8f}")
            print(f"  SNR:          {snr:.4f}")
            print()
    
    def test_precision_limits(self):
        """
        Test numerical precision limits
        
        Check for potential issues:
        - Underflow in alpha_cumprod for large t
        - Overflow in SNR calculations
        - Loss of precision in sqrt operations
        """
        print("=== Numerical Precision Analysis ===\n")
        
        if self.scheduler.alpha_cumprod is None:
            print("Schedule not computed - cannot analyze precision")
            return
        
        # Find where alpha_cumprod becomes dangerously small
        dangerous_threshold = 1e-7
        underflow_timesteps = torch.where(self.scheduler.alpha_cumprod < dangerous_threshold)[0]
        
        if len(underflow_timesteps) > 0:
            first_underflow = underflow_timesteps[0].item() + 1  # Convert to 1-indexed
            print(f"Alpha cumprod drops below {dangerous_threshold} at timestep {first_underflow}")
            print("Quantities that divide by sqrt(alpha_cumprod) (e.g. recovering x0 from predicted noise) amplify errors sharply here!")
        else:
            print(f"Alpha cumprod stays above {dangerous_threshold} for all timesteps")
        
        # Check for NaN or Inf values
        nan_check = torch.isnan(self.scheduler.alpha_cumprod).any()
        inf_check = torch.isinf(self.scheduler.alpha_cumprod).any()
        
        print(f"Contains NaN values: {nan_check}")
        print(f"Contains Inf values: {inf_check}")
        
        # Test sqrt operations near zero
        min_alpha_cumprod = self.scheduler.alpha_cumprod.min().item()
        print(f"Minimum alpha_cumprod: {min_alpha_cumprod:.2e}")
        print(f"sqrt(min_alpha_cumprod): {math.sqrt(min_alpha_cumprod):.2e}")
    
    def propose_stability_fixes(self):
        """
        Implement common stability fixes
        
        Suggest and implement fixes for numerical issues:
        1. Clamping alpha_cumprod to prevent underflow
        2. Using log-space arithmetic for very small numbers
        3. Adding epsilon to prevent division by zero
        """
        print("=== Stability Fixes ===\n")
        
        if self.scheduler.alpha_cumprod is None:
            print("Schedule not computed - cannot propose fixes")
            return
        
        # Implement clamping fix
        min_clip = 1e-8
        alpha_cumprod_clipped = torch.clamp(self.scheduler.alpha_cumprod, min=min_clip)
        
        print(f"Original min alpha_cumprod: {self.scheduler.alpha_cumprod.min():.2e}")
        print(f"Clipped min alpha_cumprod: {alpha_cumprod_clipped.min():.2e}")
        
        # Implement epsilon addition for division
        eps = 1e-8
        safe_snr = self.scheduler.alpha_cumprod / (1 - self.scheduler.alpha_cumprod + eps)
        
        print(f"Max safe SNR: {safe_snr.max():.2e}")
        
        # Show log-space alternative
        log_alpha_cumprod = torch.log(torch.clamp(self.scheduler.alpha_cumprod, min=1e-10))
        print(f"Log-space range: [{log_alpha_cumprod.min():.2f}, {log_alpha_cumprod.max():.2f}]")

# Test numerical stability (use existing scheduler)
if 'scheduler' in locals() and scheduler.alpha_cumprod is not None:
    stability = NumericalStabilityAnalysis(scheduler)
    stability.analyze_extreme_timesteps()
    stability.test_precision_limits()
    stability.propose_stability_fixes()
else:
    print("No valid scheduler available for numerical stability analysis")
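
The log-space fix listed in `propose_stability_fixes` can be made concrete. Accumulating log(alpha_cumprod) as a cumulative sum of log1p(-beta) and recovering 1 - alpha_cumprod with -expm1 avoids the cancellation in 1 - alpha_cumprod when alpha_cumprod is close to 1 (small t). The numpy sketch below assumes a float32 linear schedule from 1e-4 to 0.02 with T = 1000 (not necessarily the lab's schedule), compares both float32 routes against a float64 reference at t = 1, and prints the SNR range to show why the SNR plots use a log axis.

```python
import numpy as np

T = 1000
betas32 = np.linspace(1e-4, 0.02, T, dtype=np.float32)  # assumed linear schedule, float32
betas64 = betas32.astype(np.float64)                     # identical betas, float64 reference

# Reference value of 1 - alpha_cumprod, computed in float64
ref = 1.0 - np.cumprod(1.0 - betas64)

# Direct float32 route: cumprod, then subtract from 1 (cancels badly when alpha_cumprod ~ 1)
direct32 = 1.0 - np.cumprod(1.0 - betas32)

# Log-space float32 route: log(alpha_cumprod) = cumsum(log1p(-beta)),
# then 1 - alpha_cumprod = -expm1(log(alpha_cumprod))
log_ac32 = np.cumsum(np.log1p(-betas32))
logspace32 = -np.expm1(log_ac32)

rel_err_direct = abs(float(direct32[0]) - ref[0]) / ref[0]
rel_err_log = abs(float(logspace32[0]) - ref[0]) / ref[0]
print(f"t=1 relative error: direct {rel_err_direct:.1e}, log-space {rel_err_log:.1e}")

# SNR = alpha_cumprod / (1 - alpha_cumprod) spans many orders of magnitude
alpha_cumprod = np.exp(np.cumsum(np.log1p(-betas64)))
snr = alpha_cumprod / (1.0 - alpha_cumprod)
print(f"SNR at t=1: {snr[0]:.3g}, at t=T: {snr[-1]:.3g}")
```

At small t the direct float32 route loses several significant digits of 1 - alpha_cumprod, while the log-space route stays close to float32 precision; clamping alone (as in `propose_stability_fixes`) does not recover those digits.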

---

## Part 8: Reflection and Integration (10 minutes)

### Task 8.1: Connect Lab to Lecture Concepts

**Discussion Questions** (Work with your partner):

1. **Reparameterization Impact**: 
   - How did implementing both sampling approaches help you understand why the reparameterization trick is essential?
   - What would happen to training if we couldn't use this trick?

2. **Forward Jump Power**:
   - Quantify the computational advantage: How much faster is direct jumping vs sequential?
   - Why does this property rely on the noise being Gaussian, and what would break for an arbitrary noise distribution?

3. **Schedule Design**:
   - Which noise schedule worked best for your test cases and why?
   - How do different schedules affect the signal-to-noise ratio progression?

4. **Mathematical Elegance**:
   - What surprised you most about the mathematical structure?
   - How does the forward process set up the reverse process for success?

In [None]:
def summarize_mathematical_achievements():
    """
    Reflect on the mathematical concepts implemented today
    """
    print("=== Your Mathematical Achievements Today ===\n")
    
    concepts_implemented = [
        "🔧 Reparameterization trick (enabling gradient flow)",
        "📊 Noise schedules (linear, cosine, exponential)",
        "⚡ Forward jump formula (O(1) instead of O(t) steps per sample)",
        "🎯 Training data generation (unlimited samples)",
        "📐 Mathematical derivations (Gaussian arithmetic)",
        "🔬 Numerical stability analysis (production-ready)"
    ]
    
    print("Core mathematical concepts you implemented:")
    for concept in concepts_implemented:
        print(f"  {concept}")
    
    print(f"\n🎓 Mathematical foundation completed:")
    print(f"   • Forward diffusion process (corruption)")
    print(f"   • Efficient training data generation") 
    print(f"   • Mathematical rigor and stability")
    print(f"   • Ready for reverse process implementation!")
    
    print(f"\n🔬 Key insights gained:")
    print(f"   • Why reparameterization enables neural network training")
    print(f"   • How Gaussian arithmetic leads to O(1) forward jumps")
    print(f"   • Why different schedules affect training dynamics")
    print(f"   • How mathematical elegance enables practical algorithms")
    
    # Create a visual summary
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    
    # Mathematical progression flowchart
    stages = {
        'Problem': ['Sequential\nSteps', 'No Gradients'],
        'Solution': ['Reparameterization', 'Forward Jumps'],
        'Implementation': ['Noise Schedules', 'Training Pipeline'],
        'Validation': ['Numerical Stability', 'Performance Tests'],
        'Ready': ['Efficient Training', 'Reverse Process']
    }
    
    y_positions = [0.8, 0.65, 0.5, 0.35, 0.2]
    colors = ['red', 'orange', 'blue', 'green', 'purple']
    
    for i, (stage, concepts) in enumerate(stages.items()):
        y = y_positions[i]
        color = colors[i]
        
        # Stage label
        ax.text(0.05, y, stage, fontsize=14, weight='bold', color=color)
        
        # Concepts
        x_positions = [0.3 + j * 0.25 for j in range(len(concepts))]
        for j, concept in enumerate(concepts):
            ax.text(x_positions[j], y, concept, fontsize=11, ha='center',
                   bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))
    
    # Add arrows showing progression
    for i in range(len(y_positions) - 1):
        ax.arrow(0.5, y_positions[i] - 0.05, 0, -0.05, head_width=0.02, head_length=0.01,
                fc='gray', ec='gray', alpha=0.7)
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Forward Diffusion Implementation Journey', fontsize=16, weight='bold')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Performance summary function (provided)
def performance_summary():
    """Summarize the performance gains achieved"""
    print("\n=== Performance Gains Summary ===\n")
    
    print("🚀 Computational Revolution:")
    print("   • Sequential approach: O(T) operations")
    print("   • Forward jump approach: O(1) operations")  
    print("   • Speedup grows roughly linearly with t: up to ~T-fold (about 1000x) at t = T = 1000")
    print("   • Memory: the jump never materializes intermediate states x_1, ..., x_{t-1}")
    
    print("\n⚡ Training Efficiency:")
    print("   • Unlimited training data from any image")
    print("   • Random timestep sampling")
    print("   • Batch processing support")
    print("   • GPU acceleration ready")
    
    print("\n🎯 Mathematical Elegance:")
    print("   • Gaussian arithmetic enables direct jumps")
    print("   • Reparameterization preserves gradients")
    print("   • Schedule design controls corruption dynamics")
    print("   • Numerical stability ensures robust training")

# Create comprehensive summary
summarize_mathematical_achievements()
performance_summary()

---

## Implementation Checklist

### Core Mathematical Functions (Students Implement):

**✅ Essential TODOs:**
- [ ] `direct_sampling_approach()` - Direct Gaussian sampling (no gradient path to mu, sigma)
- [ ] `reparameterized_approach()` - Reparameterization trick (differentiable w.r.t. mu, sigma)
- [ ] `forward_step_reparameterized()` - Diffusion-specific reparameterization
- [ ] `linear_schedule()` - Linear noise schedule
- [ ] `cosine_schedule()` - Cosine noise schedule  
- [ ] `exponential_schedule()` - Exponential noise schedule
- [ ] `precompute_schedule()` - Derived quantities computation
- [ ] `sequential_forward_process()` - Sequential corruption (slow way)
- [ ] `direct_forward_jump()` - Direct jump (fast way)
- [ ] `generate_training_sample()` - Core training data generation

**✅ Provided Starter Code:**
- [ ] All visualization functions with complete plotting
- [ ] Benchmarking and performance analysis
- [ ] Statistical testing and validation
- [ ] Numerical stability analysis
- [ ] Comprehensive test suites

---

## Submission Requirements

### What to Submit

Submit your completed Jupyter notebook (.ipynb file) with:

**✅ Mathematical Implementations:**
- All TODO functions implemented with correct formulas
- Clear comments explaining mathematical steps
- Proper handling of tensor operations and device placement

**✅ Validation Results:**
- Screenshots or outputs showing successful validation tests
- Performance benchmarks comparing sequential vs direct approaches
- Analysis of different noise schedules

**✅ Understanding Demonstration:**
- Answers to discussion questions with your partner
- Explanation of why reparameterization is essential
- Analysis of computational speedups achieved

**✅ Code Quality:**
- Clean, well-commented implementations
- Proper error handling for edge cases
- Professional coding standards

---

## Quick Reference: Key Formulas Implemented

**Forward Jump Formula:**

In [None]:
x_t = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * epsilon

**Reparameterization Trick:**

In [None]:
z = mu + sigma * torch.randn_like(mu)  # Same distribution as Normal(mu, sigma).sample(), but differentiable w.r.t. mu and sigma (this is what Normal(mu, sigma).rsample() does)
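
Why the trick matters for gradients can be checked without autograd: once z = mu + sigma * eps, the derivative of a loss with respect to mu and sigma flows through z by the chain rule, with dz/dmu = 1 and dz/dsigma = eps. The numpy sketch below is an illustration, not part of the lab's API; it uses a toy loss f(z) = z², chosen because E[z²] = mu² + sigma² has known gradients 2·mu and 2·sigma. In PyTorch, autograd applies the same chain rule to `rsample()` outputs.

```python
import numpy as np


def pathwise_gradients(mu, sigma, n=200_000, seed=0):
    """Monte Carlo estimates of d/dmu and d/dsigma of E[z**2] with z = mu + sigma * eps."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(n)
    z = mu + sigma * eps              # reparameterized sample: all randomness lives in eps
    df_dz = 2.0 * z                   # derivative of f(z) = z**2
    grad_mu = np.mean(df_dz)          # chain rule with dz/dmu = 1
    grad_sigma = np.mean(df_dz * eps) # chain rule with dz/dsigma = eps
    return grad_mu, grad_sigma


g_mu, g_sigma = pathwise_gradients(mu=0.7, sigma=0.5)
# Exact gradients of mu**2 + sigma**2 here are 2*mu = 1.4 and 2*sigma = 1.0
print(g_mu, g_sigma)
```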

**Alpha Cumulative Product:**

In [None]:
alpha_cumprod = torch.cumprod(1 - betas, dim=0)

**Training Sample Generation:**

In [None]:
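t = torch.randint(1, T + 1, (batch_size,), device=x0.device)  # 1-indexed timesteps in [1, T]
epsilon = torch.randn_like(x0)
idx = t - 1  # schedule tensors are 0-indexed
x_t = (sqrt_alpha_cumprod[idx].view(-1, 1, 1, 1) * x0
       + sqrt_one_minus_alpha_cumprod[idx].view(-1, 1, 1, 1) * epsilon)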

---

