# Lab 5: Sampling from Trained Diffusion Models - From Noise to Data
**Course: Diffusion Models: Theory and Applications**  
**Duration: 90 minutes**  
**Team Size: 2 students (same teams from Labs 1-4)**

---

## Learning Objectives
By the end of this lab, students will be able to:
1. **Implement** the complete DDPM stochastic sampling algorithm
2. **Build** the DDIM deterministic sampling method with step skipping
3. **Create** the noise schedule reconstruction approach used in DDIM
4. **Construct** controllable stochasticity with the η parameter
5. **Analyze** speed vs quality trade-offs in practical sampling
6. **Optimize** sampling algorithms for real-world deployment scenarios

---

## Lab Setup and Sampling Framework

### Part 1: Team Setup & Sampling Mission (10 minutes)

In [None]:
# Sampling implementation setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.distributions import Normal
from typing import Tuple, Dict, List, Optional
import time
from dataclasses import dataclass

# Set seeds for reproducible sampling
torch.manual_seed(42)
np.random.seed(42)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Sampling experiments on: {device}")

@dataclass
class SamplingConfig:
    """Configuration for sampling experiments"""
    T: int = 100
    beta_start: float = 1e-4
    beta_end: float = 2e-2
    img_size: int = 32
    channels: int = 3
    
    def __post_init__(self):
        # Compute noise schedule
        self.betas = torch.linspace(self.beta_start, self.beta_end, self.T).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]])
        
        # Variance schedule for sampling
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

# Create sampling configuration
config = SamplingConfig(T=50, img_size=8, channels=1)  # Small for faster experimentation
print(f"Sampling configuration: T={config.T}, image_size={config.img_size}x{config.img_size}")

# Simple 2D visualization data for understanding
def create_sampling_test_data(n_samples: int = 200) -> torch.Tensor:
    """Create 2D test data for sampling visualization"""
    # Create a simple flower pattern
    t = torch.linspace(0, 4*math.pi, n_samples)
    r = 2 + 0.5 * torch.sin(3*t)
    x = r * torch.cos(t) + 0.2 * torch.randn(n_samples)
    y = r * torch.sin(t) + 0.2 * torch.randn(n_samples)
    data = torch.stack([x, y], dim=1)
    return data.to(device)

test_data_2d = create_sampling_test_data(200)
print(f"Test data shape: {test_data_2d.shape}")

# Visualize test data
plt.figure(figsize=(8, 6))
plt.scatter(test_data_2d[:, 0].cpu(), test_data_2d[:, 1].cpu(), alpha=0.7, s=30, c='blue')
plt.title('Test Data: Flower Pattern for Sampling Experiments')
plt.xlabel('X')
plt.ylabel('Y')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.show()

# Pre-trained model simulation (we'll simulate with a simple network)
class SimpleNoisePredictor(nn.Module):
    """
    Simple noise prediction network for 2D data
    Simulates a trained diffusion model
    """
    def __init__(self, data_dim: int = 2):
        super().__init__()
        self.data_dim = data_dim
        
        # Simple MLP with timestep embedding
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, 64),  # +1 for timestep
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, data_dim)
        )
    
    def timestep_embedding(self, t: torch.Tensor, max_period: int = 10000) -> torch.Tensor:
        """Normalized scalar timestep embedding t / T (a stand-in for a sinusoidal embedding; max_period is unused)"""
        if t.dim() == 0:
            t = t.unsqueeze(0)
        
        # Simple normalized embedding for this lab
        return (t.float() / config.T).unsqueeze(-1)
    
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict noise given noisy input and timestep"""
        batch_size = x.shape[0]
        
        if t.dim() == 0:
            t = t.repeat(batch_size)
        
        t_embed = self.timestep_embedding(t)
        if len(t_embed.shape) == 1:
            t_embed = t_embed.unsqueeze(0).repeat(batch_size, 1)
        
        # Concatenate input and timestep embedding
        input_with_time = torch.cat([x, t_embed], dim=-1)
        return self.net(input_with_time)

# Create the model. NOTE: it is randomly initialized, not trained. It only stands in
# for a pre-trained network in this lab, so its noise predictions will not be accurate.
pretrained_model = SimpleNoisePredictor(data_dim=2).to(device)
print("✓ Pre-trained model loaded (simulated)")

---

## Part 2: Understanding the Trained Model (15 minutes)

### Task 2.1: Explore the Pre-trained Noise Predictor

**Your Mission**: Understand how the trained model predicts noise and verify its behavior.
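
One possible shape for this analysis is sketched below. It assumes the model is called as `model(x_t, t)` and returns a tensor shaped like `x_t`; the helper name `noise_mse_by_timestep` and the zero-output stand-in predictor are illustrative and not part of the lab code. With a predictor that has learned nothing, the MSE should sit near 1, the variance of the injected noise.

```python
import torch
import torch.nn as nn

def noise_mse_by_timestep(model, x_clean, timesteps, alphas_cumprod):
    """Return {t: MSE between true and predicted noise} for each timestep."""
    results = {}
    with torch.no_grad():
        for t in timesteps:
            a_bar = alphas_cumprod[t]
            noise = torch.randn_like(x_clean)
            # Forward process: x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * noise
            x_noisy = a_bar.sqrt() * x_clean + (1 - a_bar).sqrt() * noise
            t_batch = torch.full((x_clean.shape[0],), t, device=x_clean.device)
            pred = model(x_noisy, t_batch)
            results[t] = torch.mean((pred - noise) ** 2).item()
    return results

class ZeroPredictor(nn.Module):
    """Hypothetical stand-in: always predicts zero noise."""
    def forward(self, x, t):
        return torch.zeros_like(x)

betas = torch.linspace(1e-4, 2e-2, 50)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x_clean = torch.randn(256, 2)
mse = noise_mse_by_timestep(ZeroPredictor(), x_clean, [5, 25, 45], alphas_cumprod)
for t, value in mse.items():
    print(f"t={t}: MSE={value:.3f}")
```

A trained predictor should push these values well below 1, and comparing them across timesteps shows where the model struggles most.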

In [None]:
class TrainedModelAnalyzer:
    """
    Analyze the behavior of our trained noise prediction model.
    Understanding this is crucial for implementing sampling algorithms.
    """
    
    def __init__(self, model: nn.Module, config: SamplingConfig):
        self.model = model
        self.config = config
        
    def analyze_noise_prediction_quality(self, x_clean: torch.Tensor, timesteps: List[int]):
        """
        TODO: Implement noise prediction analysis
        
        For given clean data and timesteps:
        1. Add known noise to create noisy samples
        2. Use model to predict the noise
        3. Compare predicted vs actual noise
        4. Analyze how prediction quality varies with timestep
        
        Args:
            x_clean: Clean data samples
            timesteps: List of timesteps to test
            
        Returns:
            Dictionary with analysis results
        """
        # TODO: Your implementation here
        # Step 1: For each timestep, add noise using forward process
        # Step 2: Use model to predict the noise
        # Step 3: Compute MSE between predicted and actual noise
        # Step 4: Return analysis of prediction quality vs timestep
        pass
    
    def visualize_noise_predictions(self, x_clean: torch.Tensor, timesteps: List[int] = [5, 15, 25, 35, 45]):
        """
        Visualize how well the model predicts noise at different timesteps
        """
        fig, axes = plt.subplots(2, len(timesteps), figsize=(15, 8))
        
        with torch.no_grad():
            for i, t in enumerate(timesteps):
                try:
                    # Add noise to clean sample
                    alpha_cumprod_t = self.config.alphas_cumprod[t]
                    noise = torch.randn_like(x_clean)
                    x_noisy = torch.sqrt(alpha_cumprod_t) * x_clean + torch.sqrt(1 - alpha_cumprod_t) * noise
                    
                    # Predict noise
                    predicted_noise = self.model(x_noisy, torch.tensor(t).to(device))
                    
                    # Plot true vs predicted noise
                    axes[0, i].scatter(noise[:, 0].cpu(), noise[:, 1].cpu(), alpha=0.6, s=20, label='True noise')
                    axes[0, i].scatter(predicted_noise[:, 0].cpu(), predicted_noise[:, 1].cpu(), 
                                     alpha=0.6, s=20, label='Predicted noise')
                    axes[0, i].set_title(f't={t}')
                    axes[0, i].legend(fontsize=8)
                    axes[0, i].grid(True, alpha=0.3)
                    
                    # Plot error magnitude
                    error = torch.norm(noise - predicted_noise, dim=1)
                    axes[1, i].hist(error.cpu().numpy(), bins=20, alpha=0.7, color='red')
                    axes[1, i].set_title(f'Error: mean={error.mean():.3f}')
                    axes[1, i].grid(True, alpha=0.3)
                    
                except Exception as e:
                    axes[0, i].text(0.5, 0.5, 'Model\nanalysis\nneeded', ha='center', va='center')
                    axes[1, i].text(0.5, 0.5, 'Implement\nTODOs', ha='center', va='center')
        
        axes[0, 0].set_ylabel('Noise Predictions')
        axes[1, 0].set_ylabel('Error Distribution')
        plt.tight_layout()
        plt.show()
    
    def test_model_consistency(self, x_clean: torch.Tensor):
        """
        Test that the model gives consistent predictions
        """
        print("=== Model Consistency Analysis ===\n")
        
        # Test 1: Same input should give same output
        t = 20
        with torch.no_grad():
            alpha_cumprod_t = self.config.alphas_cumprod[t]
            noise = torch.randn_like(x_clean)
            x_noisy = torch.sqrt(alpha_cumprod_t) * x_clean + torch.sqrt(1 - alpha_cumprod_t) * noise
            
            pred1 = self.model(x_noisy, torch.tensor(t).to(device))
            pred2 = self.model(x_noisy, torch.tensor(t).to(device))
            
            consistency_error = torch.norm(pred1 - pred2).item()
            print(f"Deterministic consistency error: {consistency_error:.8f}")
            print("✓ Model is deterministic" if consistency_error < 1e-6 else "❌ Model is non-deterministic")
        
        # Test 2: Different timesteps should give different predictions
        timesteps = [10, 20, 30, 40]
        predictions = []
        
        with torch.no_grad():
            for t in timesteps:
                alpha_cumprod_t = self.config.alphas_cumprod[t]
                x_noisy = torch.sqrt(alpha_cumprod_t) * x_clean + torch.sqrt(1 - alpha_cumprod_t) * noise
                pred = self.model(x_noisy, torch.tensor(t).to(device))
                predictions.append(pred)
        
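        # Compare predictions across timesteps
        diffs = []
        for i in range(len(timesteps)-1):
            diff = torch.norm(predictions[i] - predictions[i+1]).item()
            diffs.append(diff)
            print(f"Difference t={timesteps[i]} vs t={timesteps[i+1]}: {diff:.3f}")
        
        # Report success only if the predictions actually change between timesteps
        print("✓ Model is timestep-aware\n" if min(diffs) > 1e-6
              else "❌ Predictions do not change across timesteps\n")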

# Test the trained model analyzer (uncomment after implementing TODOs)
# analyzer = TrainedModelAnalyzer(pretrained_model, config)
# analyzer.test_model_consistency(test_data_2d[:5])
# analyzer.visualize_noise_predictions(test_data_2d[:5])

# Test noise prediction quality
# quality_results = analyzer.analyze_noise_prediction_quality(test_data_2d[:10], [5, 15, 25, 35, 45])

### Task 2.2: Forward Process Implementation
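
The closed form q(x_t | x_0) lets a sample at any timestep be drawn in a single step. The sketch below is a minimal illustration, assuming a 1-D `alphas_cumprod` tensor like the one in `SamplingConfig`. The function name `q_sample` is illustrative, and the reshape is just one way to broadcast per-sample timesteps over the data dimensions.

```python
import torch

def q_sample(x_start, t, alphas_cumprod, noise=None):
    """Draw x_t ~ N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I); returns (x_t, noise)."""
    t = torch.as_tensor(t, dtype=torch.long, device=x_start.device)
    if t.dim() == 0:
        t = t.expand(x_start.shape[0])  # one timestep shared by the whole batch
    if noise is None:
        noise = torch.randn_like(x_start)
    # Reshape a_bar_t to (batch, 1, ..., 1) so it broadcasts over data dims
    a_bar = alphas_cumprod[t].reshape(-1, *([1] * (x_start.dim() - 1)))
    x_t = a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise
    return x_t, noise

betas = torch.linspace(1e-4, 2e-2, 50)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x0 = torch.ones(4, 2)
steps = torch.tensor([0, 10, 30, 49])
x_t, eps = q_sample(x0, steps, alphas_cumprod)

# Inverting the formula with the true noise recovers x_0
a_bar = alphas_cumprod[steps].unsqueeze(-1)
x0_rec = (x_t - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()
print(torch.allclose(x0_rec, x0, atol=1e-5))
```

Returning the noise alongside `x_t` makes it easy to compare against a model's prediction later.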

In [None]:
class ForwardProcessImpl:
    """
    Implement the forward process for creating training data and understanding sampling.
    This helps us understand what the reverse process needs to undo.
    """
    
    def __init__(self, config: SamplingConfig):
        self.config = config
    
    def add_noise(self, x_start: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        TODO: Implement the forward noising process
        
        Apply q(x_t | x_0) = N(√ᾱ_t x_0, (1-ᾱ_t) I) to add noise to clean data.
        
        Args:
            x_start: Clean data
            t: Timestep (can be tensor for batch processing)
            noise: Optional pre-sampled noise (if None, sample new)
            
        Returns:
            Tuple of (noisy_data, noise_used)
        """
        # TODO: Your implementation here
        # Step 1: Handle timestep indexing (ensure t is properly shaped)
        # Step 2: Extract ᾱ_t from config.alphas_cumprod
        # Step 3: Sample noise if not provided
        # Step 4: Apply the forward process formula
        # Step 5: Return both noisy data and noise used
        pass
    
    def demonstrate_forward_trajectory(self, x_start: torch.Tensor, timesteps: List[int]):
        """
        Show how data gets progressively noisier through the forward process
        """
        print("=== Forward Process Trajectory ===\n")
        
        fig, axes = plt.subplots(2, len(timesteps), figsize=(15, 8))
        
        noise = torch.randn_like(x_start)  # Use same noise for consistency
        
        for i, t in enumerate(timesteps):
            try:
                x_noisy, _ = self.add_noise(x_start, torch.tensor(t), noise)
                
                # Plot noisy data
                axes[0, i].scatter(x_noisy[:, 0].cpu(), x_noisy[:, 1].cpu(), alpha=0.6, s=20, c='red')
                axes[0, i].scatter(x_start[:, 0].cpu(), x_start[:, 1].cpu(), alpha=0.3, s=10, c='blue')
                axes[0, i].set_title(f't={t}')
                axes[0, i].grid(True, alpha=0.3)
                axes[0, i].set_xlim(-6, 6)
                axes[0, i].set_ylim(-6, 6)
                
                # Plot noise level
                alpha_cumprod = self.config.alphas_cumprod[t]
                signal_strength = torch.sqrt(alpha_cumprod).item()
                noise_strength = torch.sqrt(1 - alpha_cumprod).item()
                
                axes[1, i].bar(['Signal', 'Noise'], [signal_strength, noise_strength], 
                              color=['blue', 'red'], alpha=0.7)
                axes[1, i].set_title(f'Signal: {signal_strength:.2f}\nNoise: {noise_strength:.2f}')
                axes[1, i].set_ylim(0, 1)
                
            except Exception as e:
                axes[0, i].text(0.5, 0.5, 'Implement\nadd_noise\nfirst', ha='center', va='center')
                axes[1, i].text(0.5, 0.5, 'Implement\nTODOs', ha='center', va='center')
        
        axes[0, 0].set_ylabel('Data Trajectory')
        axes[1, 0].set_ylabel('Signal / Noise Coefficients')
        plt.tight_layout()
        plt.show()

# Test forward process implementation (uncomment after implementing TODOs)
# forward_process = ForwardProcessImpl(config)
# forward_process.demonstrate_forward_trajectory(test_data_2d[:20], [0, 10, 20, 30, 40, 49])

---

## Part 3: DDPM Stochastic Sampling Implementation (25 minutes)

### Task 3.1: Implement the Complete DDPM Sampler

**Your Mission**: Build the original DDPM sampling algorithm with proper stochastic sampling.
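
A single reverse step combines the posterior mean and the fixed posterior variance. The sketch below works on standalone tensors and assumes a noise predictor called as `model(x_t, t)`. The name `ddpm_reverse_step` and the zero-output stand-in predictor are illustrative, and this is one possible arrangement rather than the lab's reference solution.

```python
import torch

def ddpm_reverse_step(model, x_t, t, betas, alphas_cumprod):
    """One DDPM step x_t -> x_{t-1} with the fixed posterior variance."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    a_bar_t = alphas_cumprod[t]
    # a_bar_{t-1} is defined as 1 at t = 0
    a_bar_prev = alphas_cumprod[t - 1] if t > 0 else torch.ones_like(a_bar_t)

    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device)
    eps = model(x_t, t_batch)

    # mu = (x_t - beta_t / sqrt(1 - a_bar_t) * eps) / sqrt(alpha_t)
    mean = (x_t - beta_t / (1.0 - a_bar_t).sqrt() * eps) / alpha_t.sqrt()
    if t == 0:
        return mean  # no noise is added on the final step
    var = beta_t * (1.0 - a_bar_prev) / (1.0 - a_bar_t)
    return mean + var.sqrt() * torch.randn_like(x_t)

betas = torch.linspace(1e-4, 2e-2, 50)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
zero_model = lambda x, t: torch.zeros_like(x)  # hypothetical stand-in predictor

x = torch.randn(8, 2)
with torch.no_grad():
    for t in reversed(range(50)):
        x = ddpm_reverse_step(zero_model, x, t, betas, alphas_cumprod)
print(x.shape)
```

Returning the mean without noise at `t == 0` is what makes the final step deterministic, which the validation in Task 3.2 checks.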

In [None]:
class DDPMSampler:
    """
    Implement the complete DDPM sampling algorithm.
    This is the original stochastic approach to diffusion sampling.
    """
    
    def __init__(self, model: nn.Module, config: SamplingConfig):
        self.model = model
        self.config = config
        
    def predict_noise(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Use the trained model to predict noise"""
        return self.model(x_t, t)
    
    def compute_posterior_mean(self, x_t: torch.Tensor, t: int, predicted_noise: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement posterior mean computation for DDPM
        
        Compute μ_θ(x_t, t) = (1/√α_t) * (x_t - (1-α_t)/√(1-ᾱ_t) * ε_θ(x_t, t))
        
        This is the mean of the reverse process distribution.
        
        Args:
            x_t: Current noisy state
            t: Current timestep
            predicted_noise: Noise predicted by the model
            
        Returns:
            Posterior mean μ_θ(x_t, t)
        """
        # TODO: Your implementation here
        # Step 1: Extract α_t and ᾱ_t from noise schedule
        # Step 2: Compute coefficient for x_t term: 1/√α_t
        # Step 3: Compute coefficient for noise term: (1-α_t)/√(1-ᾱ_t)
        # Step 4: Apply the posterior mean formula
        # Step 5: Return computed mean
        pass
    
    def compute_posterior_variance(self, t: int) -> torch.Tensor:
        """
        TODO: Implement posterior variance computation
        
        Compute σ̃²_t = β_t * (1-ᾱ_{t-1})/(1-ᾱ_t)
        
        This is the fixed variance of the reverse process.
        
        Args:
            t: Current timestep
            
        Returns:
            Posterior variance σ̃²_t
        """
        # TODO: Your implementation here
        # Step 1: Handle edge case for t=0
        # Step 2: Extract β_t, ᾱ_t, and ᾱ_{t-1} from config
        # Step 3: Apply the posterior variance formula
        # Step 4: Return computed variance
        pass
    
    def ddpm_step(self, x_t: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement single DDPM sampling step
        
        Execute one step of the DDPM reverse process:
        1. Predict noise using the model
        2. Compute posterior mean
        3. Compute posterior variance  
        4. Sample from the posterior distribution
        
        Args:
            x_t: Current state at timestep t
            t: Current timestep
            
        Returns:
            x_{t-1}: Next state in the reverse process
        """
        # TODO: Your implementation here
        # Step 1: Predict noise using the model
        # Step 2: Compute posterior mean
        # Step 3: Compute posterior variance
        # Step 4: Sample from N(mean, variance) if t > 0, else return mean
        # Step 5: Return the next state
        pass
    
    def sample(self, shape: Tuple[int, ...], return_trajectory: bool = False) -> torch.Tensor:
        """
        TODO: Implement complete DDPM sampling
        
        Generate samples by running the reverse process from pure noise to data.
        
        Args:
            shape: Shape of samples to generate
            return_trajectory: Whether to return intermediate states
            
        Returns:
            Generated samples (and trajectory if requested)
        """
        # TODO: Your implementation here
        # Step 1: Initialize x_T from pure noise N(0, I)
        # Step 2: For t = T-1, T-2, ..., 0: apply ddpm_step
        # Step 3: Optionally store trajectory
        # Step 4: Return final samples (and trajectory if requested)
        pass
    
    def sample_with_progress(self, shape: Tuple[int, ...], show_every: int = 10) -> torch.Tensor:
        """
        Sample with visualization of the denoising process
        """
        print("=== DDPM Sampling with Progress ===\n")
        
        # Initialize
        x = torch.randn(shape).to(device)
        trajectory = [x.clone()]
        
        # Reverse process
        for t in reversed(range(self.config.T)):
            try:
                x = self.ddpm_step(x, t)
                if t % show_every == 0 or t == 0:
                    trajectory.append(x.clone())
                    print(f"Step {self.config.T - t}: t={t}, sample mean={x.mean().item():.3f}, std={x.std().item():.3f}")
            except Exception:
                print(f"Implement ddpm_step to see progress at t={t}")
                break
        
        # Visualize trajectory
        if len(trajectory) > 1:
            self.visualize_sampling_trajectory(trajectory, "DDPM Sampling Progress")
        
        return x
    
    def visualize_sampling_trajectory(self, trajectory: List[torch.Tensor], title: str):
        """Visualize the sampling trajectory"""
        n_steps = len(trajectory)
        cols = min(6, n_steps)
        rows = (n_steps + cols - 1) // cols
        
        # squeeze=False keeps axes 2-D even when there is a single row or column
        fig, axes = plt.subplots(rows, cols, figsize=(15, 3*rows), squeeze=False)
        
        for i, x in enumerate(trajectory):
            row, col = i // cols, i % cols
            if i < len(trajectory):
                axes[row, col].scatter(x[:, 0].cpu(), x[:, 1].cpu(), alpha=0.6, s=20)
                step_num = i * 10 if i < len(trajectory)-1 else "Final"
                axes[row, col].set_title(f'Step {step_num}')
                axes[row, col].grid(True, alpha=0.3)
                axes[row, col].set_xlim(-6, 6)
                axes[row, col].set_ylim(-6, 6)
        
        # Hide empty subplots
        for i in range(len(trajectory), rows * cols):
            row, col = i // cols, i % cols
            axes[row, col].axis('off')
        
        plt.suptitle(title)
        plt.tight_layout()
        plt.show()

# Test DDPM sampler (uncomment after implementing TODOs)
# ddpm_sampler = DDPMSampler(pretrained_model, config)

# # Test individual components
# test_x = torch.randn(5, 2).to(device)
# test_t = 20
# predicted_noise = ddpm_sampler.predict_noise(test_x, torch.tensor(test_t).to(device))
# print(f"Predicted noise shape: {predicted_noise.shape}")

# # Test posterior computations
# posterior_mean = ddpm_sampler.compute_posterior_mean(test_x, test_t, predicted_noise)
# posterior_var = ddpm_sampler.compute_posterior_variance(test_t)
# print(f"Posterior mean shape: {posterior_mean.shape if posterior_mean is not None else 'Not implemented'}")
# print(f"Posterior variance: {posterior_var.item() if posterior_var is not None else 'Not implemented'}")

# # Test complete sampling
# samples = ddpm_sampler.sample_with_progress((20, 2), show_every=10)

### Task 3.2: DDPM Analysis and Validation
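
Two properties of the posterior variance are useful as sanity checks before validating a full sampler: it is exactly zero at t = 0 (because ᾱ at the step before t = 0 is defined as 1), and it never exceeds β_t. The sketch below recomputes the schedule on its own, with the same formulas as `SamplingConfig`, and prints both quantities.

```python
import torch

betas = torch.linspace(1e-4, 2e-2, 50)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# a_bar_{-1} = 1 by convention, so the variance vanishes at t = 0
print(f"t=0 variance: {posterior_variance[0].item():.2e}")
# (1 - a_bar_{t-1}) <= (1 - a_bar_t), so the ratio to beta_t is at most 1
print(f"max ratio to beta_t: {(posterior_variance / betas).max().item():.4f}")
```

If an implementation of `compute_posterior_variance` disagrees with either property, the indexing of ᾱ_{t-1} is a likely place to look.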

In [None]:
class DDPMAnalyzer:
    """
    Analyze DDPM sampling behavior and validate implementation correctness.
    """
    
    def __init__(self, sampler: DDPMSampler):
        self.sampler = sampler
        
    def validate_sampling_implementation(self):
        """
        Validate that DDPM implementation is working correctly
        """
        print("=== DDPM Implementation Validation ===\n")
        
        # Test 1: Check that sampling produces reasonable outputs
        try:
            samples = self.sampler.sample((10, 2))
            if samples is not None:
                print(f"✓ Sample generation successful")
                print(f"  Sample shape: {samples.shape}")
                print(f"  Sample mean: {samples.mean().item():.3f}")
                print(f"  Sample std: {samples.std().item():.3f}")
            else:
                print("❌ Implement DDPM sample() method")
        except Exception as e:
            print(f"❌ Error in sampling: {e}")
        
        # Test 2: Check posterior variance behavior
        variances = []
        for t in [0, 10, 20, 30, 40]:
            try:
                var = self.sampler.compute_posterior_variance(t)
                if var is not None:
                    variances.append((t, var.item()))
            except Exception:
                pass
        
        if variances:
            print(f"\n✓ Posterior variance computed successfully")
            for t, var in variances:
                print(f"  t={t}: σ̃² = {var:.6f}")
        else:
            print(f"\n❌ Implement compute_posterior_variance method")
        
        # Test 3: Check that final step is deterministic
        try:
            test_x = torch.randn(3, 2).to(device)
            step1 = self.sampler.ddpm_step(test_x, 0)  # Final step
            step2 = self.sampler.ddpm_step(test_x, 0)  # Should be identical
            
            if step1 is not None and step2 is not None:
                diff = torch.norm(step1 - step2).item()
                print(f"\n✓ Final step determinism check: diff = {diff:.8f}")
                print("✓ Final step is deterministic" if diff < 1e-6 else "❌ Final step should be deterministic")
        except Exception:
            print(f"\n❌ Implement ddpm_step method")
    
    def analyze_stochasticity(self, n_runs: int = 5):
        """
        Analyze the stochastic behavior of DDPM sampling
        """
        print("=== DDPM Stochasticity Analysis ===\n")
        
        samples_list = []
        
        for run in range(n_runs):
            try:
                samples = self.sampler.sample((20, 2))
                if samples is not None:
                    samples_list.append(samples)
                else:
                    print("❌ Implement sample() method first")
                    return
            except Exception:
                print("❌ Error in sampling - implement missing methods")
                return
        
        # Analyze diversity across runs
        print(f"Generated {n_runs} independent sample sets")
        
        # Compute pairwise differences between runs
        for i in range(n_runs-1):
            diff = torch.norm(samples_list[i] - samples_list[i+1]).item()
            print(f"Difference run {i+1} vs {i+2}: {diff:.3f}")
        
        # Visualize different runs
        fig, axes = plt.subplots(1, min(n_runs, 5), figsize=(15, 3))
        if n_runs == 1:
            axes = [axes]
        
        for i, samples in enumerate(samples_list[:5]):
            axes[i].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), alpha=0.6, s=20)
            axes[i].set_title(f'Run {i+1}')
            axes[i].grid(True, alpha=0.3)
            axes[i].set_xlim(-6, 6)
            axes[i].set_ylim(-6, 6)
        
        plt.suptitle('DDPM Sample Diversity Across Runs')
        plt.tight_layout()
        plt.show()
        
        print("✓ DDPM produces diverse samples due to stochastic sampling")

# Test DDPM analyzer (uncomment after implementing DDPM methods)
# ddpm_analyzer = DDPMAnalyzer(ddpm_sampler)
# ddpm_analyzer.validate_sampling_implementation()
# ddpm_analyzer.analyze_stochasticity(n_runs=3)

---

## Part 4: DDIM Deterministic Sampling Implementation (25 minutes)

### Task 4.1: Implement the DDIM Sampler

**Your Mission**: Build the DDIM sampling algorithm with step skipping and controllable stochasticity.
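
The DDIM update first estimates the clean sample from the predicted noise, then re-noises that estimate to the target timestep `s`, which may lie many steps below `t`. The sketch below is a minimal illustration on standalone tensors. The name `ddim_step_sketch` and the "oracle" predictor that returns the true noise are assumptions made for the demonstration, not part of the lab code.

```python
import torch

def ddim_step_sketch(model, x_t, t, s, alphas_cumprod, eta=0.0):
    """One DDIM step x_t -> x_s for s < t (s may skip several timesteps)."""
    a_t, a_s = alphas_cumprod[t], alphas_cumprod[s]
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device)
    eps = model(x_t, t_batch)

    # Estimate of the clean sample implied by the predicted noise
    x0_hat = (x_t - (1.0 - a_t).sqrt() * eps) / a_t.sqrt()

    # sigma is 0 for eta = 0 and matches the DDPM posterior std for eta = 1, s = t-1
    sigma = eta * ((1.0 - a_s) / (1.0 - a_t)).sqrt() * (1.0 - a_t / a_s).sqrt()
    # Shrink the noise-direction coefficient so the total variance stays 1 - a_s
    dir_coeff = (1.0 - a_s - sigma ** 2).clamp(min=0.0).sqrt()
    return a_s.sqrt() * x0_hat + dir_coeff * eps + sigma * torch.randn_like(x_t)

betas = torch.linspace(1e-4, 2e-2, 50)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x0, eps_true = torch.ones(4, 2), torch.randn(4, 2)
oracle = lambda x, t: eps_true  # hypothetical predictor that knows the true noise

t, s = 40, 20  # skip 20 timesteps in one update
x_t = alphas_cumprod[t].sqrt() * x0 + (1.0 - alphas_cumprod[t]).sqrt() * eps_true
x_s = ddim_step_sketch(oracle, x_t, t, s, alphas_cumprod, eta=0.0)
expected = alphas_cumprod[s].sqrt() * x0 + (1.0 - alphas_cumprod[s]).sqrt() * eps_true
print(torch.allclose(x_s, expected, atol=1e-5))
```

With a perfect noise prediction and `eta=0`, the step lands exactly on the forward-process sample at timestep `s`, which is why step skipping is possible at all. With a real model, the quality of the skip depends on how accurate the noise estimate is.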

In [None]:
class DDIMSampler:
    """
    Implement the DDIM sampling algorithm.
    The key innovation: deterministic sampling with step skipping capability.
    """
    
    def __init__(self, model: nn.Module, config: SamplingConfig):
        self.model = model
        self.config = config
        
    def predict_x0_from_eps(self, x_t: torch.Tensor, eps: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement x0 prediction from noise
        
        Given x_t and predicted noise, recover the estimated clean image:
        x̂_0 = (x_t - √(1-ᾱ_t) * ε) / √ᾱ_t
        
        Args:
            x_t: Current noisy state
            eps: Predicted noise
            t: Current timestep
            
        Returns:
            Predicted clean data x̂_0
        """
        # TODO: Your implementation here
        # Step 1: Extract ᾱ_t from config
        # Step 2: Apply the inversion formula
        # Step 3: Return predicted x_0
        pass
    
    def predict_eps_from_x0(self, x_t: torch.Tensor, x0: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement noise prediction from x0
        
        Given x_t and x_0, recover the noise:
        ε = (x_t - √ᾱ_t * x_0) / √(1-ᾱ_t)
        
        Args:
            x_t: Current noisy state
            x0: Clean data estimate
            t: Current timestep
            
        Returns:
            Predicted noise ε
        """
        # TODO: Your implementation here
        # Step 1: Extract ᾱ_t from config
        # Step 2: Apply the noise extraction formula
        # Step 3: Return predicted noise
        pass
    
    def ddim_step(self, x_t: torch.Tensor, t: int, s: int, eta: float = 0.0) -> torch.Tensor:
        """
        TODO: Implement single DDIM sampling step
        
        The core DDIM update: deterministic reconstruction of the trajectory.
        
        When eta=0 (deterministic):
        x_s = √ᾱ_s * x̂_0 + √(1-ᾱ_s) * ε̂
        
        When eta>0 (stochastic):
        x_s = √ᾱ_s * x̂_0 + √(1-ᾱ_s-σ²) * ε̂ + σ * z,  z ~ N(0, I)
        with σ² = eta² * (1-ᾱ_s)/(1-ᾱ_t) * (1 - ᾱ_t/ᾱ_s),
        which equals eta² * β̃_t when s = t-1
        
        Args:
            x_t: Current state at timestep t
            t: Current timestep  
            s: Target timestep (s < t)
            eta: Stochasticity parameter (0=deterministic, 1=like DDPM)
            
        Returns:
            x_s: Next state in the reverse process
        """
        # TODO: Your implementation here
        # Step 1: Predict noise using the model
        # Step 2: Predict x_0 from the noise
        # Step 3: Compute σ from eta (σ = 0 when eta = 0)
        # Step 4: Compute the deterministic part: √ᾱ_s * x̂_0 + √(1-ᾱ_s-σ²) * ε̂
        # Step 5: Add σ * z with z ~ N(0, I) if eta > 0
        # Step 6: Return the next state
        pass
    
    def sample(self, shape: Tuple[int, ...], timesteps: Optional[List[int]] = None, 
               eta: float = 0.0, return_trajectory: bool = False) -> torch.Tensor:
        """
        TODO: Implement complete DDIM sampling with step skipping
        
        Generate samples using DDIM with arbitrary timestep scheduling.
        
        Args:
            shape: Shape of samples to generate
            timesteps: Custom timestep schedule (if None, use all timesteps)
            eta: Stochasticity parameter
            return_trajectory: Whether to return intermediate states
            
        Returns:
            Generated samples (and trajectory if requested)
        """
        # TODO: Your implementation here
        # Step 1: Set up timestep schedule (default to all timesteps T-1, ..., 0 if not provided)
        # Step 2: Initialize x_T from pure noise
        # Step 3: For each consecutive pair (t, s) in timesteps: apply ddim_step
        # Step 4: Optionally store trajectory
        # Step 5: Return final samples (and trajectory if requested)
        pass
    
    def create_timestep_schedule(self, num_steps: int, schedule_type: str = "uniform") -> List[int]:
        """
        Create different timestep schedules for step skipping
        
        Args:
            num_steps: Number of sampling steps to use
            schedule_type: Type of schedule ("uniform" or "quadratic"; other values fall back to "uniform")
            
        Returns:
            List of timesteps to use for sampling
        """
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1")
        last = self.config.T - 1  # Largest valid timestep index
        
        if schedule_type == "quadratic":
            # More steps at high noise levels: spacing is dense near t = T-1
            # and the schedule ends exactly at t = 0
            raw = [int(last * (1 - (i / num_steps) ** 2)) for i in range(num_steps + 1)]
        else:
            # Uniform spacing from T-1 down to 0 (also the fallback for unknown types)
            raw = np.linspace(last, 0, num_steps + 1).round().astype(int).tolist()
        
        # Remove duplicates that rounding can create and sort from high to low noise
        return sorted(set(raw), reverse=True)
    
    def sample_with_schedule_analysis(self, shape: Tuple[int, ...], step_counts: List[int], eta: float = 0.0):
        """
        Analyze sampling with different step counts
        """
        print("=== DDIM Schedule Analysis ===\n")
        
        results = {}
        
        for num_steps in step_counts:
            print(f"Testing with {num_steps} steps...")
            timesteps = self.create_timestep_schedule(num_steps)
            
            start_time = time.time()
            try:
                samples = self.sample(shape, timesteps, eta)
                sampling_time = time.time() - start_time
                
                if samples is not None:
                    results[num_steps] = {
                        'samples': samples,
                        'time': sampling_time,
                        'timesteps': timesteps
                    }
                    print(f"  ✓ {num_steps} steps completed in {sampling_time:.3f}s")
                else:
                    print(f"  ❌ Implement DDIM sample() method")
                    break
            except Exception as e:
                print(f"  ❌ Error with {num_steps} steps: {e}")
                break
        
        if results:
            self.visualize_schedule_comparison(results)
            self.analyze_speed_quality_tradeoff(results)
    
    def visualize_schedule_comparison(self, results: Dict):
        """Visualize samples from different step schedules"""
        n_schedules = len(results)
        fig, axes = plt.subplots(1, n_schedules, figsize=(4*n_schedules, 4))
        
        if n_schedules == 1:
            axes = [axes]
        
        for i, (num_steps, result) in enumerate(results.items()):
            samples = result['samples']
            axes[i].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), alpha=0.6, s=20)
            axes[i].set_title(f'{num_steps} steps\n{result["time"]:.2f}s')
            axes[i].grid(True, alpha=0.3)
            axes[i].set_xlim(-6, 6)
            axes[i].set_ylim(-6, 6)
        
        plt.suptitle('DDIM Sampling: Different Step Counts')
        plt.tight_layout()
        plt.show()
    
    def analyze_speed_quality_tradeoff(self, results: Dict):
        """Analyze the speed vs quality tradeoff"""
        step_counts = list(results.keys())
        times = [results[k]['time'] for k in step_counts]
        
        # Simple quality metric: how close to a circle (for our flower pattern)
        qualities = []
        for num_steps in step_counts:
            samples = results[num_steps]['samples']
            # Measure distance from origin (crude quality metric)
            distances = torch.norm(samples, dim=1)
            quality = -distances.std().item()  # Lower std = more circular = higher quality
            qualities.append(quality)
        
        # Plot tradeoff
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        # Speed vs steps
        ax1.plot(step_counts, times, 'bo-', linewidth=2, markersize=8)
        ax1.set_xlabel('Number of Steps')
        ax1.set_ylabel('Sampling Time (s)')
        ax1.set_title('Sampling Speed vs Steps')
        ax1.grid(True, alpha=0.3)
        
        # Quality vs steps  
        ax2.plot(step_counts, qualities, 'ro-', linewidth=2, markersize=8)
        ax2.set_xlabel('Number of Steps')
        ax2.set_ylabel('Quality Metric')
        ax2.set_title('Sample Quality vs Steps')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
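        # Compare the runs with the most and fewest steps directly, rather than max/min times
        slowest, fastest = max(step_counts), min(step_counts)
        print(f"Speed improvement from {slowest} to {fastest} steps: "
              f"{results[slowest]['time'] / results[fastest]['time']:.1f}x faster")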

# Test DDIM sampler (uncomment after implementing TODOs)
# ddim_sampler = DDIMSampler(pretrained_model, config)

# # Test timestep scheduling
# uniform_schedule = ddim_sampler.create_timestep_schedule(20, "uniform")
# quadratic_schedule = ddim_sampler.create_timestep_schedule(20, "quadratic") 
# print(f"Uniform schedule (20 steps): {uniform_schedule[:10]}...")
# print(f"Quadratic schedule (20 steps): {quadratic_schedule[:10]}...")

# # Test schedule analysis
# ddim_sampler.sample_with_schedule_analysis((30, 2), step_counts=[50, 25, 10], eta=0.0)
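
For reference, the two schedule types tested above can be sketched without the sampler class. This is a minimal illustration, assuming `T = 100` training steps as in `SamplingConfig`; the helper name `make_schedule` is hypothetical, and the lab's own `create_timestep_schedule` may differ in rounding details.

```python
def make_schedule(T: int, num_steps: int, kind: str = "uniform") -> list:
    """Return a descending list of distinct timesteps in [0, T-1] (hypothetical helper)."""
    if not 1 <= num_steps <= T:
        raise ValueError("num_steps must be between 1 and T")
    if num_steps == 1:
        return [T - 1]
    if kind == "uniform":
        # Evenly spaced points from 0 to T-1
        raw = [round(i * (T - 1) / (num_steps - 1)) for i in range(num_steps)]
    elif kind == "quadratic":
        # Squared spacing concentrates the steps at small t
        raw = [round((i / (num_steps - 1)) ** 2 * (T - 1)) for i in range(num_steps)]
    else:
        raise ValueError(f"unknown schedule kind: {kind}")
    # Rounding can create duplicates, so deduplicate before sorting in reverse
    return sorted(set(raw), reverse=True)


uniform = make_schedule(100, 10, "uniform")
quadratic = make_schedule(100, 10, "quadratic")
print("uniform:  ", uniform)
print("quadratic:", quadratic)
```

Both lists start at `T - 1` and end at `0`; the quadratic variant places more of its steps in the low-noise half of the range.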

### Task 4.2: Controllable Stochasticity with η Parameter

**Your Mission**: Implement the η parameter to control the deterministic vs stochastic behavior.

In [None]:
class StochasticityController:
    """
    Implement and analyze the η parameter that controls DDIM stochasticity.
    This is a key innovation that allows trading determinism against sample diversity.
    """
    
    def __init__(self, ddim_sampler: DDIMSampler):
        self.ddim_sampler = ddim_sampler
        
    def compute_stochastic_variance(self, t: int, s: int, eta: float) -> torch.Tensor:
        """
        TODO: Implement stochastic variance computation for DDIM
        
        Compute σ_t^2 = η^2 * β̃_{t→s} where β̃_{t→s} is the "effective" beta
        for the jump from t to s.
        
        Formula: β̃_{t→s} = (1-ᾱ_s)/(1-ᾱ_t) * (1 - ᾱ_t/ᾱ_s)
        
        Args:
            t: Current timestep
            s: Target timestep  
            eta: Stochasticity parameter
            
        Returns:
            Variance for stochastic sampling
        """
        # TODO: Your implementation here
        # Step 1: Extract ᾱ_t and ᾱ_s from config
        # Step 2: Compute effective beta β̃_{t→s}
        # Step 3: Apply eta scaling: σ^2 = η^2 * β̃
        # Step 4: Return computed variance
        pass
    
    def analyze_eta_effects(self, shape: Tuple[int, ...], eta_values: List[float], num_steps: int = 20):
        """
        Analyze how different η values affect sampling behavior
        """
        print("=== η Parameter Analysis ===\n")
        
        timesteps = self.ddim_sampler.create_timestep_schedule(num_steps)
        results = {}
        
        for eta in eta_values:
            print(f"Testing η = {eta}")
            
            try:
                # Generate multiple samples to assess diversity
                samples_list = []
                for run in range(3):
                    samples = self.ddim_sampler.sample(shape, timesteps, eta)
                    if samples is not None:
                        samples_list.append(samples)
                    else:
                        print(f"  ❌ Implement DDIM sample() method")
                        return
                
                results[eta] = samples_list
                print(f"  ✓ Generated {len(samples_list)} sample sets")
                
            except Exception as e:
                print(f"  ❌ Error with η={eta}: {e}")
        
        if results:
            self.visualize_eta_comparison(results)
            self.analyze_diversity_vs_eta(results)
    
    def visualize_eta_comparison(self, results: Dict):
        """Visualize how η affects sample appearance"""
        eta_values = list(results.keys())
        n_etas = len(eta_values)
        
        fig, axes = plt.subplots(n_etas, 3, figsize=(12, 3*n_etas))
        if n_etas == 1:
            axes = axes.reshape(1, -1)
        
        for i, eta in enumerate(eta_values):
            samples_list = results[eta]
            
            for j, samples in enumerate(samples_list):
                axes[i, j].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), alpha=0.6, s=20)
                if j == 0:
                    axes[i, j].set_ylabel(f'η = {eta}')
                axes[i, j].set_title(f'Run {j+1}')
                axes[i, j].grid(True, alpha=0.3)
                axes[i, j].set_xlim(-6, 6)
                axes[i, j].set_ylim(-6, 6)
        
        plt.suptitle('DDIM Stochasticity: Effect of η Parameter')
        plt.tight_layout()
        plt.show()
    
    def analyze_diversity_vs_eta(self, results: Dict):
        """Quantify sample diversity for different η values"""
        eta_values = list(results.keys())
        diversities = []
        
        for eta in eta_values:
            samples_list = results[eta]
            
            # Compute pairwise differences between runs
            total_diff = 0
            num_pairs = 0
            
            for i in range(len(samples_list)):
                for j in range(i+1, len(samples_list)):
                    diff = torch.norm(samples_list[i] - samples_list[j]).item()
                    total_diff += diff
                    num_pairs += 1
            
            avg_diversity = total_diff / num_pairs if num_pairs > 0 else 0
            diversities.append(avg_diversity)
            print(f"η = {eta}: Average diversity = {avg_diversity:.3f}")
        
        # Plot diversity vs eta
        plt.figure(figsize=(8, 6))
        plt.plot(eta_values, diversities, 'go-', linewidth=2, markersize=8)
        plt.xlabel('η (Stochasticity Parameter)')
        plt.ylabel('Sample Diversity')
        plt.title('Sample Diversity vs η Parameter')
        plt.grid(True, alpha=0.3)
        plt.show()
        
        print("\nKey insights:")
        print("• η = 0: Deterministic sampling (identical samples for identical initial noise x_T)")
        print("• η > 0: Stochastic sampling (diverse samples)")
        print("• η = 1: Similar to DDPM stochasticity")

# Test stochasticity controller (uncomment after implementing TODOs)
# stochasticity_controller = StochasticityController(ddim_sampler)
# stochasticity_controller.analyze_eta_effects((25, 2), eta_values=[0.0, 0.1, 0.5, 1.0], num_steps=20)
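
As a worked check of the variance formula from `compute_stochastic_variance`, the sketch below evaluates σ² = η² · β̃ for one jump on a linear schedule. It is dependency-free and uses plain floats rather than tensors; the helper names `linear_alpha_bars` and `ddim_sigma_sq` are hypothetical, and the schedule defaults are assumed to match `SamplingConfig`.

```python
def linear_alpha_bars(T: int = 100, beta_start: float = 1e-4, beta_end: float = 2e-2) -> list:
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    alpha_bars, running = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def ddim_sigma_sq(alpha_bar_t: float, alpha_bar_s: float, eta: float) -> float:
    """sigma^2 = eta^2 * (1 - abar_s) / (1 - abar_t) * (1 - abar_t / abar_s), for s < t."""
    beta_tilde = (1.0 - alpha_bar_s) / (1.0 - alpha_bar_t) * (1.0 - alpha_bar_t / alpha_bar_s)
    return eta ** 2 * beta_tilde


abar = linear_alpha_bars()
t, s = 60, 55  # a 5-step jump from t down to s
for eta in (0.0, 0.5, 1.0):
    print(f"eta={eta}: sigma^2 = {ddim_sigma_sq(abar[t], abar[s], eta):.5f}")
```

For adjacent timesteps (`s = t - 1`) and η = 1, the factor `1 - abar_t / abar_s` reduces to β_t, so the expression coincides with the DDPM posterior variance β_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t).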

---

## Part 5: DDPM vs DDIM Comparison (15 minutes)

### Task 5.1: Comprehensive Comparison

**Your Mission**: Compare DDPM and DDIM across multiple dimensions: speed, quality, and diversity.
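
One caveat for the speed measurements in this part: CUDA kernels run asynchronously, so wrapping a GPU call in `time.time()` can under-report the elapsed time unless the device is synchronized first. A minimal sketch of a timing helper follows; the name `timed` and its `sync` hook are hypothetical and not part of the lab code.

```python
import time
from typing import Any, Callable, Optional, Tuple


def timed(fn: Callable[[], Any], sync: Optional[Callable[[], None]] = None) -> Tuple[Any, float]:
    """Run fn() and return (result, elapsed_seconds).

    sync is an optional hook called before starting and before stopping the clock.
    On a GPU one would pass torch.cuda.synchronize so queued kernels are counted.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn()
    if sync is not None:
        sync()
    return result, time.perf_counter() - start


# Stand-in workload; in the lab this would wrap a call such as
# ddim_sampler.sample(shape, timesteps, eta=0.0)
result, seconds = timed(lambda: sum(i * i for i in range(100_000)))
print(f"elapsed: {seconds:.4f}s")
```

On CPU the `sync` argument can simply be left as `None`.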

In [None]:
class SamplerComparison:
    """
    Comprehensive comparison between DDPM and DDIM sampling approaches.
    """
    
    def __init__(self, ddpm_sampler: DDPMSampler, ddim_sampler: DDIMSampler):
        self.ddpm_sampler = ddpm_sampler
        self.ddim_sampler = ddim_sampler
        
    def speed_comparison(self, shape: Tuple[int, ...], ddim_steps: List[int]):
        """
        Compare sampling speeds between DDPM and DDIM
        """
        print("=== Speed Comparison: DDPM vs DDIM ===\n")
        
        results = {}
        
        # Test DDPM (full steps)
        print("Testing DDPM (full schedule)...")
        start_time = time.time()
        try:
            ddpm_samples = self.ddpm_sampler.sample(shape)
            ddpm_time = time.time() - start_time
            if ddpm_samples is not None:
                results['DDPM'] = {'samples': ddpm_samples, 'time': ddpm_time, 'steps': self.ddpm_sampler.config.T}
                print(f"  ✓ DDPM: {ddpm_time:.3f}s with {self.ddpm_sampler.config.T} steps")
            else:
                print("  ❌ Implement DDPM sample() method")
        except Exception as e:
            print(f"  ❌ DDPM error: {e}")
        
        # Test DDIM with different step counts
        for num_steps in ddim_steps:
            print(f"Testing DDIM ({num_steps} steps)...")
            timesteps = self.ddim_sampler.create_timestep_schedule(num_steps)
            
            start_time = time.time()
            try:
                ddim_samples = self.ddim_sampler.sample(shape, timesteps, eta=0.0)
                ddim_time = time.time() - start_time
                if ddim_samples is not None:
                    results[f'DDIM_{num_steps}'] = {'samples': ddim_samples, 'time': ddim_time, 'steps': num_steps}
                    speedup = ddpm_time / ddim_time if 'DDPM' in results else 1.0
                    print(f"  ✓ DDIM: {ddim_time:.3f}s with {num_steps} steps (Speedup: {speedup:.1f}x)")
                else:
                    print("  ❌ Implement DDIM sample() method")
            except Exception as e:
                print(f"  ❌ DDIM error: {e}")
        
        if results:
            self.visualize_speed_comparison(results)
        
        return results
    
    def quality_comparison(self, shape: Tuple[int, ...], reference_data: torch.Tensor):
        """
        Compare sample quality between different methods
        """
        print("=== Quality Comparison ===\n")
        
        methods = {
            'DDPM': lambda: self.ddpm_sampler.sample(shape),
            'DDIM_50': lambda: self.ddim_sampler.sample(shape, 
                                 self.ddim_sampler.create_timestep_schedule(50), eta=0.0),
            'DDIM_20': lambda: self.ddim_sampler.sample(shape,
                                 self.ddim_sampler.create_timestep_schedule(20), eta=0.0),
            'DDIM_10': lambda: self.ddim_sampler.sample(shape,
                                 self.ddim_sampler.create_timestep_schedule(10), eta=0.0)
        }
        
        results = {}
        
        for method_name, sample_fn in methods.items():
            try:
                samples = sample_fn()
                if samples is not None:
                    # Simple quality metric: how close to reference distribution
                    quality_score = self.compute_simple_quality_metric(samples, reference_data)
                    results[method_name] = {'samples': samples, 'quality': quality_score}
                    print(f"{method_name}: Quality score = {quality_score:.3f}")
                else:
                    print(f"{method_name}: ❌ Implementation needed")
            except Exception as e:
                print(f"{method_name}: ❌ Error: {e}")
        
        if results:
            self.visualize_quality_comparison(results, reference_data)
        
        return results
    
    def compute_simple_quality_metric(self, samples: torch.Tensor, reference: torch.Tensor) -> float:
        """
        Simple quality metric based on distribution moments
        """
        # Compare means and standard deviations
        sample_mean = samples.mean(dim=0)
        sample_std = samples.std(dim=0)
        ref_mean = reference.mean(dim=0)
        ref_std = reference.std(dim=0)
        
        mean_diff = torch.norm(sample_mean - ref_mean).item()
        std_diff = torch.norm(sample_std - ref_std).item()
        
        # Negate the total difference so that a higher score means closer to the reference
        quality_score = -(mean_diff + std_diff)
        return quality_score
    
    def diversity_analysis(self, shape: Tuple[int, ...], n_runs: int = 5):
        """
        Analyze sample diversity for DDPM vs DDIM
        """
        print("=== Diversity Analysis ===\n")
        
        methods = {
            'DDPM': lambda: self.ddpm_sampler.sample(shape),
            'DDIM_det': lambda: self.ddim_sampler.sample(shape, 
                                  self.ddim_sampler.create_timestep_schedule(20), eta=0.0),
            'DDIM_stoch': lambda: self.ddim_sampler.sample(shape,
                                    self.ddim_sampler.create_timestep_schedule(20), eta=0.5)
        }
        
        for method_name, sample_fn in methods.items():
            print(f"Testing {method_name} diversity...")
            
            samples_list = []
            for run in range(n_runs):
                try:
                    samples = sample_fn()
                    if samples is not None:
                        samples_list.append(samples)
                    else:
                        print(f"  ❌ Implementation needed")
                        break
                except Exception as e:
                    print(f"  ❌ Error: {e}")
                    break
            
            if len(samples_list) == n_runs:
                # Compute diversity
                total_diff = 0
                num_pairs = 0
                for i in range(n_runs):
                    for j in range(i+1, n_runs):
                        diff = torch.norm(samples_list[i] - samples_list[j]).item()
                        total_diff += diff
                        num_pairs += 1
                
                avg_diversity = total_diff / num_pairs
                print(f"  Average diversity: {avg_diversity:.3f}")
                
                # Visualize first few runs
                self.visualize_diversity_runs(samples_list[:3], method_name)
    
    def visualize_speed_comparison(self, results: Dict):
        """Visualize speed comparison results"""
        methods = list(results.keys())
        times = [results[m]['time'] for m in methods]
        steps = [results[m]['steps'] for m in methods]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Time comparison
        bars1 = ax1.bar(methods, times, color=['blue', 'red', 'orange', 'green'][:len(methods)])
        ax1.set_ylabel('Sampling Time (s)')
        ax1.set_title('Sampling Speed Comparison')
        ax1.tick_params(axis='x', rotation=45)
        
        # Add time labels on bars
        for bar, elapsed in zip(bars1, times):  # avoid shadowing the `time` module
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.001,
                    f'{elapsed:.3f}s', ha='center', va='bottom')
        
        # Steps comparison
        bars2 = ax2.bar(methods, steps, color=['blue', 'red', 'orange', 'green'][:len(methods)])
        ax2.set_ylabel('Number of Steps')
        ax2.set_title('Sampling Steps Comparison')
        ax2.tick_params(axis='x', rotation=45)
        
        # Add steps labels on bars
        for bar, step in zip(bars2, steps):
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height + 1,
                    f'{step}', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.show()
    
    def visualize_quality_comparison(self, results: Dict, reference_data: torch.Tensor):
        """Visualize quality comparison results"""
        n_methods = len(results)
        fig, axes = plt.subplots(1, n_methods + 1, figsize=(4*(n_methods+1), 4))
        
        # Plot reference data
        axes[0].scatter(reference_data[:, 0].cpu(), reference_data[:, 1].cpu(), 
                       alpha=0.6, s=20, color='black')
        axes[0].set_title('Reference Data')
        axes[0].grid(True, alpha=0.3)
        axes[0].set_xlim(-6, 6)
        axes[0].set_ylim(-6, 6)
        
        # Plot generated samples
        for i, (method, result) in enumerate(results.items()):
            samples = result['samples']
            quality = result['quality']
            
            axes[i+1].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), alpha=0.6, s=20)
            axes[i+1].set_title(f'{method}\nQuality: {quality:.3f}')
            axes[i+1].grid(True, alpha=0.3)
            axes[i+1].set_xlim(-6, 6)
            axes[i+1].set_ylim(-6, 6)
        
        plt.suptitle('Sample Quality Comparison')
        plt.tight_layout()
        plt.show()
    
    def visualize_diversity_runs(self, samples_list: List[torch.Tensor], method_name: str):
        """Visualize multiple runs for diversity analysis"""
        n_runs = len(samples_list)
        fig, axes = plt.subplots(1, n_runs, figsize=(4*n_runs, 4))
        
        if n_runs == 1:
            axes = [axes]
        
        for i, samples in enumerate(samples_list):
            axes[i].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), alpha=0.6, s=20)
            axes[i].set_title(f'Run {i+1}')
            axes[i].grid(True, alpha=0.3)
            axes[i].set_xlim(-6, 6)
            axes[i].set_ylim(-6, 6)
        
        plt.suptitle(f'{method_name}: Sample Diversity')
        plt.tight_layout()
        plt.show()

# Test comprehensive comparison (uncomment after implementing all samplers)
# comparison = SamplerComparison(ddpm_sampler, ddim_sampler)

# # Speed comparison
# speed_results = comparison.speed_comparison((30, 2), ddim_steps=[50, 25, 10])

# # Quality comparison  
# quality_results = comparison.quality_comparison((30, 2), test_data_2d)

# # Diversity analysis
# comparison.diversity_analysis((25, 2), n_runs=3)
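
The moment-based score in `compute_simple_quality_metric` is easy to probe on toy data. The sketch below re-implements the same idea with plain Python lists (the function name `moment_score` and the toy point sets are hypothetical) and illustrates one limitation: two differently shaped distributions with matching means and standard deviations receive nearly the same score.

```python
import math
import random
import statistics


def moment_score(samples, reference):
    """Negative distance between per-dimension means and standard deviations.

    Both arguments are lists of equal-length coordinate tuples.
    Higher (closer to 0) means closer to the reference.
    """
    dims = range(len(reference[0]))

    def moments(points):
        cols = [[p[d] for p in points] for d in dims]
        return [statistics.mean(c) for c in cols], [statistics.stdev(c) for c in cols]

    s_mean, s_std = moments(samples)
    r_mean, r_std = moments(reference)
    mean_diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(s_mean, r_mean)))
    std_diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(s_std, r_std)))
    return -(mean_diff + std_diff)


rng = random.Random(0)
reference = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(500)]
good = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(500)]
shifted = [(x + 2.0, y) for x, y in good]
# A ring of radius sqrt(2) has mean 0 and unit variance per axis, like the reference
angles = [rng.uniform(0, 2 * math.pi) for _ in range(500)]
ring = [(math.sqrt(2) * math.cos(a), math.sqrt(2) * math.sin(a)) for a in angles]

for name, pts in [("good", good), ("shifted", shifted), ("ring", ring)]:
    print(f"{name}: {moment_score(pts, reference):.3f}")
```

The shifted set is penalized clearly, while the ring scores close to the well-matched Gaussian set, so the scatter plots remain necessary alongside the number.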

---

## Part 6: Advanced Sampling Optimizations (10 minutes)

### Task 6.1: Implement Practical Optimizations

**Your Mission**: Implement real-world optimizations for faster and more efficient sampling.

In [None]:
class SamplingOptimizer:
    """
    Implement practical optimizations for diffusion sampling.
    These techniques are essential for real-world deployment.
    """
    
    def __init__(self, model: nn.Module, config: SamplingConfig):
        self.model = model
        self.config = config
        
    def cached_noise_schedule_sampling(self, ddim_sampler: DDIMSampler, 
                                     shape: Tuple[int, ...], num_steps: int) -> torch.Tensor:
        """
        TODO: Implement optimized sampling with pre-computed coefficients
        
        Pre-compute all noise schedule coefficients to avoid repeated calculations.
        This can provide a modest speedup (on the order of 10-20%), depending on how
        much of the runtime is spent in the model's forward pass.
        
        Args:
            ddim_sampler: DDIM sampler to optimize
            shape: Shape of samples to generate
            num_steps: Number of sampling steps
            
        Returns:
            Generated samples with optimized computation
        """
        # TODO: Your implementation here
        # Step 1: Pre-compute all α, ᾱ coefficients for the timestep schedule
        # Step 2: Create optimized sampling loop using cached values
        # Step 3: Avoid repeated tensor operations and indexing
        # Step 4: Return samples with improved efficiency
        pass
    
    def mixed_precision_sampling(self, sampler, shape: Tuple[int, ...]) -> torch.Tensor:
        """
        Implement mixed precision sampling for speed/memory optimization
        """
        print("=== Mixed Precision Sampling ===\n")
        
        # Use autocast for automatic mixed precision
        with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', enabled=True):
            start_time = time.time()
            try:
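                # Accept either a sampler object with .sample() or a plain callable
                sample_fn = sampler.sample if hasattr(sampler, 'sample') else sampler
                samples = sample_fn(shape)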
                mixed_precision_time = time.time() - start_time
                print(f"Mixed precision sampling time: {mixed_precision_time:.3f}s")
                return samples
            except Exception as e:
                print(f"Mixed precision error: {e}")
                return None
    
    def benchmark_optimizations(self, ddim_sampler: DDIMSampler, shape: Tuple[int, ...]):
        """
        Benchmark different optimization techniques
        """
        print("=== Optimization Benchmarks ===\n")
        
        num_steps = 20
        results = {}
        
        # Baseline DDIM
        print("1. Baseline DDIM...")
        timesteps = ddim_sampler.create_timestep_schedule(num_steps)
        start_time = time.time()
        try:
            baseline_samples = ddim_sampler.sample(shape, timesteps, eta=0.0)
            baseline_time = time.time() - start_time
            if baseline_samples is None:
                print("   ❌ Implement DDIM sample() first")
                return
            results['Baseline'] = {'time': baseline_time, 'samples': baseline_samples}
            print(f"   Time: {baseline_time:.3f}s")
        except Exception as e:
            print(f"   ❌ Baseline DDIM error: {e}")
            return
        
        # Cached coefficients
        print("2. Cached coefficients...")
        start_time = time.time()
        try:
            cached_samples = self.cached_noise_schedule_sampling(ddim_sampler, shape, num_steps)
            cached_time = time.time() - start_time
            if cached_samples is not None:
                results['Cached'] = {'time': cached_time, 'samples': cached_samples}
                speedup = baseline_time / cached_time
                print(f"   Time: {cached_time:.3f}s (Speedup: {speedup:.2f}x)")
            else:
                print("   ❌ Implement cached sampling")
        except Exception as e:
            print(f"   ❌ Error: {e}")
        
        # Mixed precision
        print("3. Mixed precision...")
        mixed_samples = self.mixed_precision_sampling(
            lambda shape: ddim_sampler.sample(shape, timesteps, eta=0.0), shape)
        
        # Memory usage analysis
        print("4. Memory analysis...")
        self.analyze_memory_usage(ddim_sampler, shape, num_steps)
        
        if len(results) > 1:
            self.visualize_optimization_results(results)
    
    def analyze_memory_usage(self, ddim_sampler: DDIMSampler, shape: Tuple[int, ...], num_steps: int):
        """
        Analyze memory usage during sampling
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()  # so the peak reflects this run only
            initial_memory = torch.cuda.memory_allocated()
            
            timesteps = ddim_sampler.create_timestep_schedule(num_steps)
            
            try:
                samples = ddim_sampler.sample(shape, timesteps, eta=0.0)
                peak_memory = torch.cuda.max_memory_allocated()
                
                memory_usage = (peak_memory - initial_memory) / 1024**2  # MB
                print(f"   Peak memory usage: {memory_usage:.1f} MB")
                
                torch.cuda.empty_cache()
            except Exception as e:
                print(f"   ❌ Cannot analyze memory - implement DDIM first ({e})")
        else:
            print("   CUDA not available for memory analysis")
    
    def visualize_optimization_results(self, results: Dict):
        """Visualize optimization benchmark results"""
        methods = list(results.keys())
        times = [results[m]['time'] for m in methods]
        
        plt.figure(figsize=(10, 6))
        bars = plt.bar(methods, times, color=['blue', 'orange', 'green'][:len(methods)])
        plt.ylabel('Sampling Time (s)')
        plt.title('Sampling Optimization Benchmarks')
        
        # Add speedup labels (relative to the baseline run when present)
        baseline_time = results['Baseline']['time'] if 'Baseline' in results else times[0]
        for bar, elapsed in zip(bars, times):  # avoid shadowing the `time` module
            speedup = baseline_time / elapsed
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,
                    f'{elapsed:.3f}s\n({speedup:.2f}x)', ha='center', va='bottom')
        
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

# Test sampling optimizer (uncomment after implementing samplers)
# optimizer = SamplingOptimizer(pretrained_model, config)
# optimizer.benchmark_optimizations(ddim_sampler, (25, 2))
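
The idea behind cached coefficients can be sketched on scalars, without tensors or the lab classes. The example assumes a deterministic (η = 0) DDIM update and the indexing convention of `SamplingConfig`, where the step after timestep 0 uses ᾱ = 1. All names here (`precompute_coeffs`, `ddim_sample_cached`, `toy_eps_model`) are hypothetical, and the noise predictor is an exact one for a single-point data distribution, not a trained network.

```python
import math


def linear_alpha_bars(T=100, beta_start=1e-4, beta_end=2e-2):
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    alpha_bars, running = [], 1.0
    for i in range(T):
        running *= 1.0 - (beta_start + (beta_end - beta_start) * i / (T - 1))
        alpha_bars.append(running)
    return alpha_bars


def precompute_coeffs(alpha_bars, timesteps):
    """One tuple per step: (t, sqrt(abar_t), sqrt(1-abar_t), sqrt(abar_prev), sqrt(1-abar_prev))."""
    coeffs = []
    for i, t in enumerate(timesteps):
        a_t = alpha_bars[t]
        # After the last listed timestep we jump to clean data, where abar = 1
        a_prev = alpha_bars[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        coeffs.append((t, math.sqrt(a_t), math.sqrt(1 - a_t),
                       math.sqrt(a_prev), math.sqrt(1 - a_prev)))
    return coeffs


def ddim_sample_cached(eps_model, x_T, coeffs):
    """Deterministic (eta = 0) DDIM loop that only reads cached coefficients."""
    x = x_T
    for t, sqrt_a, sqrt_1ma, sqrt_a_prev, sqrt_1ma_prev in coeffs:
        eps = eps_model(x, t)
        x0_pred = (x - sqrt_1ma * eps) / sqrt_a          # predicted clean sample
        x = sqrt_a_prev * x0_pred + sqrt_1ma_prev * eps  # move to the previous timestep
    return x


alpha_bars = linear_alpha_bars()
TARGET = 1.5  # toy data distribution: a single point


def toy_eps_model(x, t):
    # Exact noise for point-mass data at TARGET
    return (x - math.sqrt(alpha_bars[t]) * TARGET) / math.sqrt(1 - alpha_bars[t])


timesteps = [99, 79, 59, 39, 19, 0]
coeffs = precompute_coeffs(alpha_bars, timesteps)  # computed once, reusable across calls
sample = ddim_sample_cached(toy_eps_model, x_T=-0.7, coeffs=coeffs)
print(sample)
```

Because the square roots are computed once per schedule, the loop body contains only the model call and a few multiply-adds; the same coefficient list can be reused for every batch drawn with that schedule.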

---

## Part 7: Real-World Deployment Considerations (10 minutes)

### Task 7.1: Production-Ready Sampling

**Your Mission**: Implement considerations for deploying diffusion models in production environments.
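
As orientation for the fallback task, the general retry pattern can be sketched independently of any sampler. This is an illustration under stated assumptions, not the expected solution: `sample_with_fallback`, `flaky_sampler`, and the `base_delay` parameter are hypothetical names, and the sampler is modelled as any callable that takes a step count.

```python
import time
from typing import Any, Callable, Sequence


def sample_with_fallback(sample_fn: Callable[[int], Any],
                         step_options: Sequence[int] = (20, 50),
                         max_retries: int = 3,
                         base_delay: float = 0.1) -> Any:
    """Try each step count in order; retry each up to max_retries times with exponential backoff."""
    errors = []
    for num_steps in step_options:
        for attempt in range(max_retries):
            try:
                samples = sample_fn(num_steps)
                if samples is None:
                    raise RuntimeError("sampler returned None")
                return samples
            except Exception as exc:  # broad on purpose: any failure triggers a retry
                errors.append(f"{num_steps} steps, attempt {attempt + 1}: {exc}")
                time.sleep(base_delay * 2 ** attempt)  # delay doubles after each failure
    raise RuntimeError("all sampling attempts failed:\n" + "\n".join(errors))


# Stand-in sampler that fails for small step counts (purely illustrative)
def flaky_sampler(num_steps: int):
    if num_steps < 50:
        raise ValueError("numerical instability")
    return [0.0] * 4


print(sample_with_fallback(flaky_sampler, base_delay=0.0))
```

Collecting the per-attempt messages makes the final error informative, which matters more in production than the retry itself.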

In [None]:
class ProductionSampler:
    """
    Production-ready sampling implementation with practical considerations.
    """
    
    def __init__(self, model: nn.Module, config: SamplingConfig):
        self.model = model
        self.config = config
        self.model.eval()  # Ensure model is in eval mode
        
    def robust_sampling_with_fallback(self, shape: Tuple[int, ...], 
                                    preferred_steps: int = 20, 
                                    fallback_steps: int = 50,
                                    max_retries: int = 3) -> torch.Tensor:
        """
        TODO: Implement robust sampling with error handling and fallback
        
        Production systems need to handle failures gracefully.
        
        Args:
            shape: Shape of samples to generate
            preferred_steps: Preferred number of steps (fast)
            fallback_steps: Fallback number of steps (slower but more reliable)
            max_retries: Maximum number of retry attempts
            
        Returns:
            Generated samples with fallback handling
        """
        # TODO: Your implementation here
        # Step 1: Try preferred (fast) sampling first
        # Step 2: If failed, try fallback (slower) sampling
        # Step 3: Implement retry logic with exponential backoff
        # Step 4: Return samples or raise informative error after max retries
        pass
    
    def batch_sampling_with_memory_management(self, total_samples: int, 
                                            batch_size: int = 16,
                                            num_steps: int = 20) -> torch.Tensor:
        """
        TODO: Implement memory-efficient batch sampling
        
        Generate large numbers of samples without running out of memory.
        
        Args:
            total_samples: Total number of samples to generate
            batch_size: Number of samples per batch
            num_steps: Number of sampling steps
            
        Returns:
            All generated samples concatenated
        """
        # TODO: Your implementation here
        # Step 1: Calculate number of batches needed
        # Step 2: For each batch: generate samples and clear memory
        # Step 3: Concatenate results efficiently
        # Step 4: Monitor memory usage and adjust batch size if needed
        pass
    
    def adaptive_quality_sampling(self, shape: Tuple[int, ...], 
                                 target_quality: float = 0.8,
                                 max_steps: int = 50) -> Dict:
        """
        Implement adaptive sampling that adjusts steps based on quality
        """
        print("=== Adaptive Quality Sampling ===\n")
        
        # Try different step counts and measure quality
        step_counts = [10, 15, 20, 30, 40, 50]
        results = []
        
        for steps in step_counts:
            if steps > max_steps:
                break
                
            print(f"Testing {steps} steps...")
            try:
                # This would use your DDIM implementation
                # timesteps = create_timestep_schedule(steps)
                # samples = ddim_sampler.sample(shape, timesteps, eta=0.0)
                
                # Placeholder for quality assessment
                # In practice, you'd use metrics like FID, IS, or perceptual similarity
                estimated_quality = min(0.9, 0.3 + steps * 0.015)  # Mock quality function
                
                results.append({
                    'steps': steps,
                    'quality': estimated_quality,
                    'time': steps * 0.02  # Mock timing
                })
                
                print(f"  Quality: {estimated_quality:.3f}")
                
                if estimated_quality >= target_quality:
                    print(f"✓ Target quality {target_quality} reached with {steps} steps")
                    break
                    
            except Exception as e:
                print(f"  ❌ Error with {steps} steps: {e}")
        
        self.visualize_adaptive_results(results, target_quality)
        return results
    
    def sampling_health_check(self) -> Dict[str, bool]:
        """
        Perform health checks on the sampling system
        """
        print("=== Sampling System Health Check ===\n")
        
        health_status = {}
        
        # Check 1: Model is in eval mode
        health_status['model_eval_mode'] = not self.model.training
        print(f"Model in eval mode: {'✓' if health_status['model_eval_mode'] else '❌'}")
        
        # Check 2: Device availability
        health_status['device_available'] = torch.cuda.is_available() if 'cuda' in str(device) else True
        print(f"Device available: {'✓' if health_status['device_available'] else '❌'}")
        
        # Check 3: Memory availability
        if torch.cuda.is_available():
            available_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
            health_status['sufficient_memory'] = available_memory > 100 * 1024**2  # 100MB threshold
            print(f"Sufficient memory: {'✓' if health_status['sufficient_memory'] else '❌'} ({available_memory/1024**2:.0f}MB available)")
        else:
            health_status['sufficient_memory'] = True
            print("Memory check: ✓ (CPU mode)")
        
        # Check 4: Model prediction test
        try:
            test_input = torch.randn(1, 2).to(device)
            test_t = torch.tensor(10).to(device)
            with torch.no_grad():
                test_output = self.model(test_input, test_t)
            health_status['model_functional'] = test_output.shape == test_input.shape
            print(f"Model functional: {'✓' if health_status['model_functional'] else '❌'}")
        except Exception as e:
            health_status['model_functional'] = False
            print(f"Model functional: ❌ ({e})")
        
        # Check 5: Noise schedule validity
        # Wrap in bool() so the stored value is a Python bool, not a 0-dim tensor
        health_status['noise_schedule_valid'] = bool(
            len(self.config.betas) == self.config.T and
            torch.all(self.config.betas > 0) and
            torch.all(self.config.betas < 1)
        )
        print(f"Noise schedule valid: {'✓' if health_status['noise_schedule_valid'] else '❌'}")
        
        overall_health = all(health_status.values())
        print(f"\nOverall system health: {'✓ HEALTHY' if overall_health else '❌ ISSUES DETECTED'}")
        
        return health_status
    
    def visualize_adaptive_results(self, results: List[Dict], target_quality: float):
        """Visualize adaptive quality results"""
        if not results:
            return
            
        steps = [r['steps'] for r in results]
        qualities = [r['quality'] for r in results]
        times = [r['time'] for r in results]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Quality vs steps
        ax1.plot(steps, qualities, 'bo-', linewidth=2, markersize=8)
        ax1.axhline(y=target_quality, color='red', linestyle='--', alpha=0.7, label=f'Target: {target_quality}')
        ax1.set_xlabel('Number of Steps')
        ax1.set_ylabel('Quality Score')
        ax1.set_title('Adaptive Quality Control')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Time vs quality
        ax2.plot(times, qualities, 'ro-', linewidth=2, markersize=8)
        ax2.axhline(y=target_quality, color='red', linestyle='--', alpha=0.7, label=f'Target: {target_quality}')
        ax2.set_xlabel('Sampling Time (s)')
        ax2.set_ylabel('Quality Score')
        ax2.set_title('Time vs Quality Trade-off')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Test production sampler
production_sampler = ProductionSampler(pretrained_model, config)

# Health check
health_status = production_sampler.sampling_health_check()

# Adaptive quality demonstration
adaptive_results = production_sampler.adaptive_quality_sampling((20, 2), target_quality=0.7)

---

## Part 8: Integration and Final Validation (5 minutes)

### Task 8.1: Complete System Integration

**Your Mission**: Integrate all components and validate the complete sampling system.

In [None]:
def comprehensive_sampling_validation():
    """
    Final validation of the complete sampling implementation
    """
    print("=== Comprehensive Sampling System Validation ===\n")
    
    validation_results = {
        'ddpm_implemented': False,
        'ddim_implemented': False,
        'optimization_working': False,
        'production_ready': False
    }
    
    # Test 1: DDPM Implementation
    print("1. Validating DDPM Implementation...")
    try:
        ddpm_sampler = DDPMSampler(pretrained_model, config)
        ddpm_samples = ddpm_sampler.sample((10, 2))
        if ddpm_samples is not None:
            validation_results['ddpm_implemented'] = True
            print("   ✓ DDPM sampling functional")
        else:
            print("   ❌ DDPM sample() returns None")
    except Exception as e:
        print(f"   ❌ DDPM error: {e}")
    
    # Test 2: DDIM Implementation
    print("2. Validating DDIM Implementation...")
    try:
        ddim_sampler = DDIMSampler(pretrained_model, config)
        timesteps = ddim_sampler.create_timestep_schedule(20)
        ddim_samples = ddim_sampler.sample((10, 2), timesteps, eta=0.0)
        if ddim_samples is not None:
            validation_results['ddim_implemented'] = True
            print("   ✓ DDIM sampling functional")
        else:
            print("   ❌ DDIM sample() returns None")
    except Exception as e:
        print(f"   ❌ DDIM error: {e}")
    
    # Test 3: Optimization Features
    print("3. Validating Optimization Features...")
    try:
        optimizer = SamplingOptimizer(pretrained_model, config)
        # Test cached sampling if DDIM works
        if validation_results['ddim_implemented']:
            cached_samples = optimizer.cached_noise_schedule_sampling(ddim_sampler, (5, 2), 10)
            if cached_samples is not None:
                validation_results['optimization_working'] = True
                print("   ✓ Optimizations functional")
            else:
                print("   ❌ Cached sampling not implemented")
        else:
            print("   ❌ Cannot test optimizations without DDIM")
    except Exception as e:
        print(f"   ❌ Optimization error: {e}")
    
    # Test 4: Production Readiness
    print("4. Validating Production Features...")
    try:
        production_sampler = ProductionSampler(pretrained_model, config)
        health_status = production_sampler.sampling_health_check()
        if all(health_status.values()):
            validation_results['production_ready'] = True
            print("   ✓ Production features functional")
        else:
            print("   ❌ Production health check failed")
    except Exception as e:
        print(f"   ❌ Production error: {e}")
    
    # Overall assessment
    print("\n" + "="*50)
    print("FINAL VALIDATION RESULTS:")
    print("="*50)
    
    for component, status in validation_results.items():
        status_str = "✓ PASS" if status else "❌ FAIL"
        print(f"{component.replace('_', ' ').title()}: {status_str}")
    
    overall_pass = sum(validation_results.values()) >= 2  # At least 2 components working
    print(f"\nOverall System Status: {'✓ FUNCTIONAL' if overall_pass else '❌ NEEDS WORK'}")
    
    if overall_pass:
        print("\n🎉 Congratulations! Your sampling system is working!")
        print("You've successfully implemented the core of modern diffusion sampling.")
    else:
        print("\n🔧 Keep working on the TODO implementations.")
        print("Focus on DDPM and DDIM sample() methods first.")
    
    return validation_results

def demonstrate_complete_sampling_pipeline():
    """
    Demonstrate the complete sampling pipeline from theory to practice
    """
    print("=== Complete Sampling Pipeline Demonstration ===\n")
    
    pipeline_stages = [
        "1. 🧠 Trained Model: Noise predictor ε_θ(x_t, t)",
        "2. 📐 Mathematical Framework: ELBO and reverse process theory", 
        "3. 🎲 DDPM Sampling: Stochastic reverse process (high quality)",
        "4. ⚡ DDIM Sampling: Deterministic acceleration (fast generation)",
        "5. 🎛️  Controllable Stochasticity: η parameter for determinism/diversity trade-off",
        "6. 🚀 Optimizations: Caching, mixed precision, memory management",
        "7. 🏭 Production Deployment: Error handling, health checks, batching"
    ]
    
    print("Your journey through diffusion sampling:")
    for stage in pipeline_stages:
        print(f"  {stage}")
    
    print("\n💡 Key insights achieved:")
    print("   • DDPM: Faithful to theory, slow but high quality")
    print("   • DDIM: Clever mathematical insight enables 10-50x speedup")
    print("   • η parameter: Smooth interpolation between deterministic and stochastic")
    print("   • Production: Real systems need robustness and optimization")
    
    print("\n🌟 What this enables:")
    print("   • Real-time creative applications")
    print("   • High-resolution image generation")
    print("   • Interactive AI art tools")
    print("   • Large-scale content creation")
    
    # Create pipeline visualization
    fig, ax = plt.subplots(1, 1, figsize=(14, 8))
    
    stages = ["Theory", "DDPM", "DDIM", "Optimization", "Production"]
    stage_colors = ['lightblue', 'blue', 'red', 'orange', 'green']
    
    # Draw pipeline flow
    y_pos = 0.5
    stage_width = 0.15
    
    for i, (stage, color) in enumerate(zip(stages, stage_colors)):
        x_pos = 0.1 + i * 0.2
        
        # Stage box
        ax.add_patch(plt.Rectangle((x_pos - stage_width/2, y_pos - 0.1), 
                                  stage_width, 0.2, 
                                  facecolor=color, alpha=0.7, edgecolor='black'))
        ax.text(x_pos, y_pos, stage, ha='center', va='center', 
               fontsize=10, weight='bold')
        
        # Arrow to next stage
        if i < len(stages) - 1:
            ax.arrow(x_pos + stage_width/2, y_pos, 
                    0.2 - stage_width, 0, 
                    head_width=0.03, head_length=0.02, 
                    fc='gray', ec='gray')
    
    # Add benefits annotations
    benefits = [
        (0.1, 0.3, "Mathematical\nFoundation"),
        (0.3, 0.7, "High Quality\nGeneration"),
        (0.5, 0.3, "Fast Sampling\n10-50x speedup"),
        (0.7, 0.7, "Memory Efficient\nCached Operations"),
        (0.9, 0.3, "Robust\nDeployment")
    ]
    
    for x, y, text in benefits:
        ax.text(x, y, text, ha='center', va='center', fontsize=9,
               bbox=dict(boxstyle="round,pad=0.3", facecolor='lightyellow', alpha=0.8))
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Diffusion Sampling: From Theory to Production', fontsize=16, weight='bold')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Run final validation and demonstration
validation_results = comprehensive_sampling_validation()
demonstrate_complete_sampling_pipeline()

---

## Implementation Checklist

### Core Sampling Functions (Students Implement):

**✅ Essential TODOs:**
- [ ] `analyze_noise_prediction_quality()` - Validate model behavior
- [ ] `add_noise()` - Forward process implementation
- [ ] `compute_posterior_mean()` - DDPM posterior mean calculation
- [ ] `compute_posterior_variance()` - DDPM posterior variance calculation
- [ ] `ddpm_step()` - Single DDPM sampling step
- [ ] `sample()` (DDPM) - Complete DDPM sampling algorithm
- [ ] `predict_x0_from_eps()` - Clean image prediction from noise
- [ ] `predict_eps_from_x0()` - Noise prediction from clean image
- [ ] `ddim_step()` - Single DDIM sampling step with η control
- [ ] `sample()` (DDIM) - Complete DDIM sampling with step skipping
- [ ] `compute_stochastic_variance()` - η parameter variance computation
- [ ] `cached_noise_schedule_sampling()` - Performance optimization
- [ ] `robust_sampling_with_fallback()` - Production error handling
- [ ] `batch_sampling_with_memory_management()` - Memory-efficient sampling

**✅ Provided Starter Code:**
- [ ] All visualization and analysis functions
- [ ] Model architecture and configuration setup
- [ ] Comparison and benchmarking frameworks
- [ ] Health checking and validation systems
- [ ] Complete testing and demonstration pipeline

---

## Submission Requirements

### What to Submit

Submit your completed Jupyter notebook (.ipynb file) with:

**✅ DDPM Implementation:**
- Complete stochastic sampling algorithm with proper noise injection
- Posterior mean and variance calculations
- Validation of sampling consistency and quality

**✅ DDIM Implementation:**
- Deterministic sampling with step skipping capability
- Clean image prediction and noise schedule reconstruction
- η parameter implementation for controllable stochasticity

**✅ Performance Analysis:**
- Speed comparisons between DDPM and DDIM
- Quality vs speed trade-off analysis
- Step reduction studies and optimal scheduling

**✅ Optimization Techniques:**
- Cached coefficient computation for efficiency
- Memory management for large-scale sampling
- Production-ready error handling and fallback mechanisms

**✅ Comprehensive Evaluation:**
- Validation of all sampling components
- Comparison of different sampling strategies
- Analysis of practical deployment considerations

**✅ Documentation and Insights:**
- Clear explanations of implementation choices
- Discussion of trade-offs and practical considerations
- Connection between mathematical theory and implementation details

---

## Quick Reference: Key Mathematical Formulas

### For Implementation Reference:

**DDPM Posterior Mean:**

In [None]:
# μ_θ(x_t, t) = (1/√α_t) * (x_t - (1-α_t)/√(1-ᾱ_t) * ε_θ(x_t, t))
alpha_t = config.alphas[t]
alpha_cumprod_t = config.alphas_cumprod[t]
coeff_1 = 1.0 / torch.sqrt(alpha_t)
coeff_2 = (1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)
posterior_mean = coeff_1 * (x_t - coeff_2 * predicted_noise)

**DDPM Posterior Variance:**

In [None]:
# σ̃²_t = β_t * (1-ᾱ_{t-1})/(1-ᾱ_t)
if t == 0:
    posterior_variance = 0
else:
    posterior_variance = (config.betas[t] * (1 - config.alphas_cumprod[t-1]) / 
                         (1 - config.alphas_cumprod[t]))
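
The `t == 0` branch matters because `alphas_cumprod[t-1]` would otherwise index `-1` and silently read the last schedule entry. The following is a minimal sketch with plain Python floats, assuming the linear schedule from the `SamplingConfig` defaults; `linear_schedule` and `posterior_variance` are illustrative names, not part of the lab code. It shows that treating ᾱ_{-1} as 1 (as `alphas_cumprod_prev` does) gives zero variance at t = 0, while the unguarded index does not:

```python
def linear_schedule(T=100, beta_start=1e-4, beta_end=2e-2):
    """Linear beta schedule and cumulative alpha products as plain floats."""
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alphas_cumprod, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return betas, alphas_cumprod

def posterior_variance(t, betas, alphas_cumprod):
    """sigma~^2_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t), with abar_{-1} = 1."""
    abar_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    return betas[t] * (1.0 - abar_prev) / (1.0 - alphas_cumprod[t])

betas, alphas_cumprod = linear_schedule()
variances = [posterior_variance(t, betas, alphas_cumprod) for t in range(len(betas))]

# What the unguarded index would compute at t = 0 (wraps around to abar_{T-1}).
wrapped = betas[0] * (1.0 - alphas_cumprod[-1]) / (1.0 - alphas_cumprod[0])

print(f"variance at t=0: {variances[0]}, unguarded value: {wrapped:.3f}")
```

Every guarded variance stays between 0 and β_t, which is a cheap sanity check to run against your tensor implementation.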

**DDIM Clean Image Prediction:**

In [None]:
# x̂_0 = (x_t - √(1-ᾱ_t) * ε) / √ᾱ_t
alpha_cumprod_t = config.alphas_cumprod[t]
predicted_x0 = (x_t - torch.sqrt(1 - alpha_cumprod_t) * eps) / torch.sqrt(alpha_cumprod_t)
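
A quick way to validate this formula is a round trip: noise a known value, then recover it with the same ε. The sketch below uses scalar floats and assumed signatures for `add_noise`, `predict_x0_from_eps`, and `predict_eps_from_x0` (the lab's versions take tensors and a timestep, so treat these as stand-ins rather than the reference solution):

```python
import math

def add_noise(x0, eps, abar_t):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    return math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps

def predict_x0_from_eps(x_t, eps, abar_t):
    """Invert the forward process for x0, given the noise."""
    return (x_t - math.sqrt(1.0 - abar_t) * eps) / math.sqrt(abar_t)

def predict_eps_from_x0(x_t, x0, abar_t):
    """Invert the forward process for the noise, given x0."""
    return (x_t - math.sqrt(abar_t) * x0) / math.sqrt(1.0 - abar_t)

x0, eps, abar_t = 0.8, -1.3, 0.4
x_t = add_noise(x0, eps, abar_t)
x0_hat = predict_x0_from_eps(x_t, eps, abar_t)
eps_hat = predict_eps_from_x0(x_t, x0_hat, abar_t)

print(x0_hat, eps_hat)  # should reproduce x0 and eps up to round-off
```

Note that the division by √ᾱ_t amplifies any error in the predicted noise by a factor of 1/√ᾱ_t, so the x̂_0 estimate is least reliable at large t, where ᾱ_t is small.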

**DDIM Deterministic Update:**

In [None]:
# x_s = √ᾱ_s * x̂_0 + √(1-ᾱ_s) * ε̂
alpha_cumprod_s = config.alphas_cumprod[s]
x_s = (torch.sqrt(alpha_cumprod_s) * predicted_x0 + 
       torch.sqrt(1 - alpha_cumprod_s) * predicted_eps)

**η-Controlled Stochastic Variance:**

In [None]:
# σ_t^2 = η^2 * β̃_{t→s} where β̃_{t→s} = (1-ᾱ_s)/(1-ᾱ_t) * (1 - ᾱ_t/ᾱ_s)
alpha_cumprod_t = config.alphas_cumprod[t]
alpha_cumprod_s = config.alphas_cumprod[s]
beta_tilde = ((1 - alpha_cumprod_s) / (1 - alpha_cumprod_t) * 
              (1 - alpha_cumprod_t / alpha_cumprod_s))
stochastic_variance = eta**2 * beta_tilde
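
For η > 0 the variance has to be paired with a smaller coefficient on ε̂ so that the total noise level at step s is preserved. The sketch below shows one common way to combine the pieces, x_s = √ᾱ_s·x̂_0 + √(1-ᾱ_s-σ²)·ε̂ + σ·z, on scalar floats. `ddim_step_scalar` is a hypothetical helper for illustration, not the lab's `ddim_step()` signature:

```python
import math
import random

def ddim_step_scalar(x_t, eps_hat, abar_t, abar_s, eta=0.0, rng=None):
    """One DDIM update t -> s on a scalar, with eta-scaled noise."""
    rng = rng or random
    x0_hat = (x_t - math.sqrt(1.0 - abar_t) * eps_hat) / math.sqrt(abar_t)
    beta_tilde = (1.0 - abar_s) / (1.0 - abar_t) * (1.0 - abar_t / abar_s)
    sigma = eta * math.sqrt(beta_tilde)
    # Clamp at zero: round-off can push the argument slightly negative.
    direction = math.sqrt(max(1.0 - abar_s - sigma ** 2, 0.0))
    noise = rng.gauss(0.0, 1.0) if sigma > 0.0 else 0.0
    return math.sqrt(abar_s) * x0_hat + direction * eps_hat + sigma * noise

x0, eps, abar_t, abar_s = 0.8, -1.3, 0.4, 0.7
x_t = math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps

# eta = 0: no fresh noise, so the step is fully determined by x_t and eps_hat.
deterministic = ddim_step_scalar(x_t, eps, abar_t, abar_s, eta=0.0)
# eta = 1: fresh noise is injected; a seeded generator keeps it reproducible.
stochastic = ddim_step_scalar(x_t, eps, abar_t, abar_s, eta=1.0, rng=random.Random(0))
```

With η = 0 and the true noise supplied as `eps_hat`, the result lands exactly on √ᾱ_s·x_0 + √(1-ᾱ_s)·ε, which makes a convenient unit test for a tensor implementation.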

---

## Common Implementation Issues & Solutions

### Debugging Tips:

**Sampling Divergence:**
- Check that model is in `.eval()` mode during sampling
- Ensure noise schedule coefficients are computed correctly
- Verify timestep indexing (0-based vs 1-based)
- Wrap sampling in `torch.no_grad()` and reduce the batch size if running out of memory (gradient checkpointing only helps during training)
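
Indexing mistakes are easiest to catch by printing the (t, s) pairs before running any sampling. The sketch below builds an evenly spaced, 0-based, descending schedule. `make_timestep_pairs` is a hypothetical helper (the lab's `create_timestep_schedule()` may use a different convention), and `s = -1` is an assumed sentinel meaning "clean data, ᾱ = 1":

```python
def make_timestep_pairs(T, num_steps):
    """Evenly spaced descending (t, s) pairs over 0-based timesteps.

    The final pair is (0, -1): s = -1 marks the clean-data endpoint,
    where the cumulative alpha product is taken to be 1.
    """
    if not 1 <= num_steps <= T:
        raise ValueError("num_steps must be between 1 and T")
    if num_steps == 1:
        timesteps = [T - 1]
    else:
        # Integer arithmetic avoids rounding ties; includes both 0 and T-1.
        ascending = [i * (T - 1) // (num_steps - 1) for i in range(num_steps)]
        timesteps = ascending[::-1]
    next_steps = timesteps[1:] + [-1]
    return list(zip(timesteps, next_steps))

pairs = make_timestep_pairs(T=100, num_steps=5)
print(pairs)  # [(99, 74), (74, 49), (49, 24), (24, 0), (0, -1)]
```

Three things to confirm in your own schedule: the first t is T-1 (not T), every s is strictly smaller than its t, and the last step ends at clean data rather than stopping one step early.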

**DDIM vs DDPM Consistency:**
- When η=1 and every timestep is used (no step skipping), DDIM should behave like DDPM
- Test with same random seeds for reproducibility
- Verify that timestep schedules are correctly implemented
- Check that final step (t=0) doesn't add noise
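
The η = 1 claim can be checked numerically: for adjacent steps (s = t-1), 1 - ᾱ_t/ᾱ_{t-1} equals β_t, so β̃_{t→s} reduces to the DDPM posterior variance. The sketch below verifies this with plain floats, assuming the linear schedule from the `SamplingConfig` defaults; the function names are illustrative only:

```python
def linear_schedule(T=100, beta_start=1e-4, beta_end=2e-2):
    """Linear beta schedule and cumulative alpha products as plain floats."""
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alphas_cumprod, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return betas, alphas_cumprod

betas, abar = linear_schedule()

def ddim_beta_tilde(t, s):
    """beta~_{t->s} = (1 - abar_s) / (1 - abar_t) * (1 - abar_t / abar_s)."""
    return (1.0 - abar[s]) / (1.0 - abar[t]) * (1.0 - abar[t] / abar[s])

def ddpm_posterior_variance(t):
    """sigma~^2_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t), for t >= 1."""
    return betas[t] * (1.0 - abar[t - 1]) / (1.0 - abar[t])

# Largest disagreement over all adjacent steps; should be round-off only.
max_gap = max(
    abs(ddim_beta_tilde(t, t - 1) - ddpm_posterior_variance(t))
    for t in range(1, len(betas))
)
print(f"max |beta_tilde - posterior_variance| = {max_gap:.2e}")
```

Once steps are skipped, the equality no longer holds: `ddim_beta_tilde(50, 40)` is noticeably larger than `ddpm_posterior_variance(50)`, so a skipped-step η = 1 sampler is not expected to match DDPM sample for sample.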

**Performance Issues:**
- Pre-compute noise schedule coefficients outside sampling loop
- Use in-place operations where possible
- Run sampling under `torch.no_grad()` and generate in smaller batches for memory efficiency
- Consider mixed precision for speed, and verify that sample quality is unaffected
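
As an illustration of the pre-computation tip, the DDIM update can be folded into per-step scalars so the loop body is only multiply-adds: x_s = c_x·x_t + c_eps·ε̂ + σ·z. The sketch below is one possible arrangement on scalar floats; `precompute_step_coefficients` and the three-step `schedule` are hypothetical, and the loop covers only the deterministic (η = 0) case:

```python
import math

def precompute_step_coefficients(abar_t, abar_s, eta=0.0):
    """Return (c_x, c_eps, sigma) with x_s = c_x*x_t + c_eps*eps_hat + sigma*z."""
    sigma = eta * math.sqrt((1.0 - abar_s) / (1.0 - abar_t) * (1.0 - abar_t / abar_s))
    direction = math.sqrt(max(1.0 - abar_s - sigma ** 2, 0.0))
    ratio = math.sqrt(abar_s / abar_t)
    # Substituting x0_hat = (x_t - sqrt(1 - abar_t) * eps_hat) / sqrt(abar_t)
    # into the update collapses it to two coefficients.
    c_x = ratio
    c_eps = direction - ratio * math.sqrt(1.0 - abar_t)
    return c_x, c_eps, sigma

# Hypothetical 3-step schedule of (abar_t, abar_s) pairs, noisiest first;
# abar_s = 1.0 in the last pair means the step ends at clean data.
schedule = [(0.2, 0.5), (0.5, 0.8), (0.8, 1.0)]
coefficients = [precompute_step_coefficients(a_t, a_s) for a_t, a_s in schedule]

def run_deterministic_chain(x, eps_predictions):
    """Apply the cached coefficients; no square roots inside the loop."""
    for (c_x, c_eps, _sigma), eps_hat in zip(coefficients, eps_predictions):
        x = c_x * x + c_eps * eps_hat
    return x

x0, eps = 0.8, -1.3
x_start = math.sqrt(0.2) * x0 + math.sqrt(0.8) * eps
recovered = run_deterministic_chain(x_start, [eps, eps, eps])
```

Feeding the true noise at every step recovers `x0`, which confirms that the folded coefficients are equivalent to the two-stage "predict x̂_0, then re-noise" form.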

**Quality Degradation:**
- Too few DDIM steps can cause artifacts
- Check that noise prediction model is properly trained
- Verify that timestep embeddings are correctly normalized
- Ensure numerical stability in coefficient computations

---

