# Lab 4: ELBO for Diffusion Models - Learning to Reverse Chaos
**Course: Diffusion Models: Theory and Applications**  
**Duration: 90 minutes**  
**Team Size: 2 students (same teams from Labs 1-3)**

---

## Learning Objectives
By the end of this lab, students will be able to:
1. **Implement** the complete ELBO derivation for diffusion models from first principles
2. **Build** the three-forces decomposition: reconstruction, prior matching, and denoising
3. **Create** the tractable reverse distribution using Bayes' rule
4. **Construct** the noise prediction reparameterization
5. **Connect** complex ELBO theory to simple practical training algorithms
6. **Demonstrate** how mathematical elegance enables state-of-the-art generation

---

## Lab Setup and Diffusion Framework

### Part 1: Team Setup & Sequential Latent Mission (10 minutes)

In [None]:
# Diffusion ELBO implementation setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.distributions import Normal, MultivariateNormal
from typing import Tuple, Dict, List
import time

# Set seeds for reproducible mathematics
torch.manual_seed(42)
np.random.seed(42)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Diffusion mathematics on: {device}")

# Create test data for diffusion experiments
def create_diffusion_test_data(n_samples: int = 300) -> torch.Tensor:
    """Create 2D data for testing diffusion ELBO implementations"""
    # Mixture of three Gaussians whose centers form a triangle of radius 2
    n_per_cluster = n_samples // 3
    angles = torch.tensor([0.0, 2*math.pi/3, 4*math.pi/3])
    centers = 2.0 * torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)  # (3, 2)
    
    # Isotropic Gaussian noise (std 0.3) around each fixed center
    clusters = [center + 0.3 * torch.randn(n_per_cluster, 2) for center in centers]
    
    data = torch.cat(clusters, dim=0)
    return data.to(device)

# Generate test data
test_data = create_diffusion_test_data(300)
print(f"Test data shape: {test_data.shape}")

# Visualize test data
plt.figure(figsize=(8, 6))
plt.scatter(test_data[:, 0].cpu(), test_data[:, 1].cpu(), alpha=0.7, s=30, c='blue')
plt.title('Test Data: Three-Cluster Distribution')
plt.xlabel('X')
plt.ylabel('Y')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.show()

# Define noise schedule for diffusion process
def create_noise_schedule(T: int = 100, beta_start: float = 1e-4, beta_end: float = 2e-2) -> Dict[str, torch.Tensor]:
    """
    Create linear noise schedule for diffusion process
    
    Returns:
        Dictionary containing β_t, α_t, and ᾱ_t sequences
    """
    # Linear schedule
    betas = torch.linspace(beta_start, beta_end, T).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    
    return {
        'betas': betas,
        'alphas': alphas, 
        'alphas_cumprod': alphas_cumprod,
        'T': T
    }

# Create noise schedule
noise_schedule = create_noise_schedule(T=50)  # Smaller T for faster computation
T = noise_schedule['T']
print(f"Using T = {T} timesteps")
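
As a quick sanity check on the schedule, the sketch below rebuilds the same linear β schedule in a self-contained form and reports how much signal survives at the final step. `linear_schedule` and `remaining_signal` are hypothetical helpers, not part of the lab code, and the lab's default `beta_start` and `beta_end` are assumed. With only T = 50 steps, ᾱ_T stays near 0.6, so x_T is still far from pure N(0, I) noise; much longer schedules drive ᾱ_T toward zero.

```python
import torch

def linear_schedule(T: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
    """Hypothetical stand-alone copy of the lab's linear schedule (CPU only)."""
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float64)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return betas, alphas, alphas_cumprod

def remaining_signal(T: int) -> float:
    """Return ᾱ_T, the squared signal coefficient left in x_T."""
    _, _, alphas_cumprod = linear_schedule(T)
    return alphas_cumprod[-1].item()

# Compare a short schedule with longer ones
for T_check in (50, 200, 1000):
    abar_T = remaining_signal(T_check)
    print(f"T={T_check:4d}  alpha_bar_T={abar_T:.4f}  noise variance={1 - abar_T:.4f}")
```

This matters later in the lab: the prior-matching term is only negligible when ᾱ_T is close to zero.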

---

## Part 2: The Intractable Sequential Likelihood (15 minutes)

### Task 2.1: Experience the Marginal Likelihood Crisis

**Your Mission**: Implement the intractable marginal likelihood for sequential latents and see why direct optimization fails.

In [None]:
class SequentialLikelihoodDemo:
    """
    Demonstrate why direct likelihood computation fails for sequential generative models.
    You'll implement the mathematics to see the exponential complexity explosion.
    """
    
    def __init__(self, data_dim: int = 2, T: int = 50):
        self.data_dim = data_dim
        self.T = T
        self.noise_schedule = create_noise_schedule(T)
        
        # Simple reverse process model (will be learned)
        # For demo purposes, we'll use a simple linear model
        self.reverse_model = nn.ModuleList([
            nn.Linear(data_dim, data_dim) for _ in range(T)
        ]).to(device)
        
    def forward_step(self, x_prev: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement single forward diffusion step
        
        Apply q(x_t | x_{t-1}) = N(x_t; √α_t x_{t-1}, β_t I)
        
        Args:
            x_prev: Previous state (batch_size, data_dim)
            t: Timestep (0-indexed)
            
        Returns:
            x_t: Next state after adding noise
        """
        # TODO: Your implementation here
        # Step 1: Get α_t and β_t from noise schedule
        # Step 2: Compute mean: √α_t * x_{t-1}
        # Step 3: Sample noise with variance β_t
        # Step 4: Return x_t = mean + noise
        pass
    
    def forward_trajectory(self, x0: torch.Tensor) -> List[torch.Tensor]:
        """
        TODO: Implement complete forward diffusion trajectory
        
        Generate x_1, x_2, ..., x_T from x_0 using the forward process
        
        Args:
            x0: Clean data (batch_size, data_dim)
            
        Returns:
            trajectory: List of states [x_0, x_1, ..., x_T]
        """
        # TODO: Your implementation here
        # Step 1: Initialize trajectory with x_0
        # Step 2: For each timestep t, apply forward_step
        # Step 3: Store each intermediate state
        # Step 4: Return complete trajectory
        pass
    
    def direct_jump_forward(self, x0: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement direct jump to timestep t
        
        Use the analytical form: q(x_t | x_0) = N(x_t; √ᾱ_t x_0, (1-ᾱ_t) I)
        
        This is much more efficient than sequential forward steps!
        
        Args:
            x0: Clean data (batch_size, data_dim)
            t: Target timestep
            
        Returns:
            x_t: State at timestep t
        """
        # TODO: Your implementation here
        # Step 1: Get ᾱ_t from noise schedule
        # Step 2: Compute mean: √ᾱ_t * x_0
        # Step 3: Compute variance: (1 - ᾱ_t)
        # Step 4: Sample from N(mean, variance * I)
        pass
    
    def marginal_likelihood_approximation(self, x0: torch.Tensor, n_trajectories: int = 100) -> float:
        """
        Attempt Monte Carlo approximation of p(x_0)
        
        This will demonstrate the computational intractability.
        
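        p(x_0) = ∫ p(x_{0:T}) dx_{1:T}
               = E_q[ p(x_{0:T}) / q(x_{1:T}|x_0) ]
               ≈ (1/K) Σ_k p(x_{0:T}^{(k)}) / q(x_{1:T}^{(k)}|x_0),  x_{1:T}^{(k)} ~ q(x_{1:T}|x_0)
        
        Note: for simplicity, the demo below drops the q weights and the Gaussian
        normalizing constants, so its numbers are only a rough illustration.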
        """
        print(f"Attempting marginal likelihood with {n_trajectories} trajectories...")
        start_time = time.time()
        
        log_probs = []
        for k in range(n_trajectories):
            # Generate forward trajectory (after students implement forward_trajectory)
            try:
                trajectory = self.forward_trajectory(x0[:1])  # Single sample
                
                # Compute joint probability p(x_{0:T}) = p(x_T) * ∏ p(x_{t-1}|x_t)
                # For demo, we'll use simple Gaussian approximations
                log_prob = 0.0
                
                # Prior term: p(x_T) ≈ N(0, I) since x_T should be noise
                x_T = trajectory[-1]
                log_prob += -0.5 * (x_T**2).sum()
                
                # Reverse terms (using our simple model)
                for t in range(self.T, 0, -1):
                    x_t = trajectory[t]
                    x_t_minus_1 = trajectory[t-1]
                    
                    # Simple reverse model prediction
                    predicted_mean = self.reverse_model[t-1](x_t)
                    log_prob += -0.5 * ((x_t_minus_1 - predicted_mean)**2).sum()
                
                log_probs.append(log_prob.item())
            except Exception:
                print("Forward trajectory not implemented yet")
                return 0.0
        
        computation_time = time.time() - start_time
        
        # Results analysis
        log_probs = np.array(log_probs)
        print(f"Computation time: {computation_time:.3f}s")
        print(f"Log probability estimates: mean={log_probs.mean():.3f}, std={log_probs.std():.3f}")
        print(f"Probability range: [{np.exp(log_probs.min()):.2e}, {np.exp(log_probs.max()):.2e}]")
        
        print("\nWhy this approach fails:")
        print("❌ Integration dimension grows with T; naive quadrature cost grows exponentially")
        print("❌ High variance in estimates")
        print("❌ Requires learned reverse model (circular dependency)")
        print("❌ Completely impractical for realistic T (1000+ steps)")
        
        return log_probs.mean()
    
    def demonstrate_complexity_explosion(self):
        """Show how complexity grows with T"""
        print("=== Demonstrating Complexity Explosion ===\n")
        
        complexity_analysis = []
        for T_test in [5, 10, 20, 50]:
            print(f"Sequence length T = {T_test}:")
            print(f"  Number of random variables: {T_test + 1}")
            print(f"  Marginal integration dimensions: {T_test * self.data_dim}")
            print(f"  Forward process evaluations: {T_test}")
            print(f"  Reverse process evaluations: {T_test}")
            
            # Estimate computational cost (hypothetical): naive grid quadrature
            # with 10 points per integration dimension grows exponentially in T
            forward_cost = T_test
            reverse_cost = T_test
            integration_cost = 10.0 ** (T_test * self.data_dim)
            total_cost = forward_cost + reverse_cost + integration_cost
            
            complexity_analysis.append((T_test, total_cost))
            print(f"  Estimated computational cost: {total_cost:.2e}")
            print()
        
        # Visualize complexity growth
        T_values, costs = zip(*complexity_analysis)
        plt.figure(figsize=(10, 6))
        plt.plot(T_values, costs, 'ro-', linewidth=2, markersize=8)
        plt.xlabel('Sequence Length T')
        plt.ylabel('Computational Cost (arbitrary units)')
        plt.title('Computational Complexity vs Sequence Length')
        plt.yscale('log')
        plt.grid(True, alpha=0.3)
        plt.show()
        
        print("💡 Solution: ELBO provides a tractable alternative!")

# Test sequential likelihood demo (uncomment after implementing TODOs)
# demo = SequentialLikelihoodDemo(data_dim=2, T=10)
# demo.demonstrate_complexity_explosion()
# demo.marginal_likelihood_approximation(test_data[:1], n_trajectories=50)
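
The bound that replaces this intractable integral can be written compactly. The following is the standard importance-weighting and Jensen's-inequality argument, stated in the notation of this lab:

```latex
\begin{aligned}
\log p_\theta(x_0)
  &= \log \int p_\theta(x_{0:T})\, dx_{1:T} \\
  &= \log \mathbb{E}_{q(x_{1:T}\mid x_0)}\!\left[\frac{p_\theta(x_{0:T})}{q(x_{1:T}\mid x_0)}\right] \\
  &\ge \mathbb{E}_{q(x_{1:T}\mid x_0)}\!\left[\log p_\theta(x_{0:T}) - \log q(x_{1:T}\mid x_0)\right]
   = \mathrm{ELBO}.
\end{aligned}
```

The inequality moves the logarithm inside the expectation, which turns a product over timesteps into a sum of per-step terms.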

### Task 2.2: The Forward Process Implementation

In [None]:
def demonstrate_forward_process():
    """
    Show the forward diffusion process in action
    """
    print("=== Forward Diffusion Process ===\n")
    
    # Take a single data point
    x0 = test_data[0:1]  # Single sample
    print(f"Original data point: {x0.squeeze().cpu().numpy()}")
    
    # Show sequential vs direct jump
    timesteps_to_show = [0, 5, 10, 20, 30, 49]
    
    fig, axes = plt.subplots(2, len(timesteps_to_show), figsize=(15, 6))
    
    demo = SequentialLikelihoodDemo(data_dim=2, T=50)
    
    # Sequential forward process
    print("\nSequential forward process:")
    try:
        trajectory = demo.forward_trajectory(x0)
        for i, t in enumerate(timesteps_to_show):
            if trajectory is not None:
                x_t = trajectory[t].cpu().numpy().squeeze()
                axes[0, i].scatter(x_t[0], x_t[1], c='blue', s=100)
                axes[0, i].set_title(f't={t}')
                axes[0, i].set_xlim(-4, 4)
                axes[0, i].set_ylim(-4, 4)
                axes[0, i].grid(True, alpha=0.3)
        axes[0, 0].set_ylabel('Sequential\nForward')
    except Exception:
        print("Implement forward_trajectory first")
    
    # Direct jump forward process
    print("Direct jump forward process:")
    try:
        for i, t in enumerate(timesteps_to_show):
            x_t = demo.direct_jump_forward(x0, t).cpu().numpy().squeeze()
            axes[1, i].scatter(x_t[0], x_t[1], c='red', s=100)
            axes[1, i].set_title(f't={t}')
            axes[1, i].set_xlim(-4, 4)
            axes[1, i].set_ylim(-4, 4)
            axes[1, i].grid(True, alpha=0.3)
        axes[1, 0].set_ylabel('Direct Jump\nForward')
    except Exception:
        print("Implement direct_jump_forward first")
    
    plt.tight_layout()
    plt.show()
    
    print("Key insight: the direct jump draws x_t in one step instead of t sequential steps.")
    print("This efficiency will be crucial for ELBO computation.")

# Run forward process demonstration
demonstrate_forward_process()
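
One way to check that the sequential process and the direct jump describe the same distribution is to compare sample statistics. The sketch below is self-contained and uses hypothetical helper names (`step_forward`, `jump_forward`) rather than the lab's class methods. It assumes a 1-indexed timestep t, so that step t uses `betas[t - 1]`; the lab's own indexing convention may differ.

```python
import torch

torch.manual_seed(0)

T_demo = 50
betas = torch.linspace(1e-4, 2e-2, T_demo)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def step_forward(x_prev: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ N(sqrt(alpha_t) x_{t-1}, beta_t I); t is 1-indexed (assumption)."""
    noise = torch.randn_like(x_prev)
    return torch.sqrt(alphas[t - 1]) * x_prev + torch.sqrt(betas[t - 1]) * noise

def jump_forward(x0: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ N(sqrt(abar_t) x_0, (1 - abar_t) I) in a single step."""
    abar_t = alphas_cumprod[t - 1]
    return torch.sqrt(abar_t) * x0 + torch.sqrt(1.0 - abar_t) * torch.randn_like(x0)

# Push many copies of one data point through both routes
x0 = torch.tensor([[2.0, 0.0]]).repeat(20000, 1)
x_seq = x0.clone()
for t in range(1, T_demo + 1):
    x_seq = step_forward(x_seq, t)
x_jump = jump_forward(x0, T_demo)

print("sequential mean/var:", x_seq.mean(0), x_seq.var(0))
print("direct     mean/var:", x_jump.mean(0), x_jump.var(0))
```

Both routes should agree up to Monte Carlo error, while the direct jump needs one noise draw per sample instead of T.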

---

## Part 3: ELBO Derivation for Diffusion Models (25 minutes)

### Task 3.1: Implement the ELBO Decomposition

**Your Mission**: Derive the three-forces decomposition of the diffusion ELBO step by step.
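
For reference, the target of the derivation in its standard form, with signs and the summation range written out explicitly:

```latex
\mathrm{ELBO}
 = \underbrace{\mathbb{E}_{q(x_1\mid x_0)}\big[\log p_\theta(x_0\mid x_1)\big]}_{\text{reconstruction}}
 - \underbrace{D_{\mathrm{KL}}\big(q(x_T\mid x_0)\,\|\,p(x_T)\big)}_{\text{prior matching}}
 - \sum_{t=2}^{T}\underbrace{\mathbb{E}_{q(x_t\mid x_0)}\Big[D_{\mathrm{KL}}\big(q(x_{t-1}\mid x_t,x_0)\,\|\,p_\theta(x_{t-1}\mid x_t)\big)\Big]}_{\text{denoising}}
```

Only the reconstruction and denoising terms depend on θ; the prior-matching term is fixed by the noise schedule.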

In [None]:
class DiffusionELBODerivation:
    """
    Implement the complete ELBO derivation for diffusion models.
    You'll walk through each algebraic step to transform the intractable likelihood
    into three interpretable terms.
    """
    
    def __init__(self, noise_schedule: Dict[str, torch.Tensor]):
        self.noise_schedule = noise_schedule
        self.T = noise_schedule['T']
        self.betas = noise_schedule['betas']
        self.alphas = noise_schedule['alphas']
        self.alphas_cumprod = noise_schedule['alphas_cumprod']
    
    def log_probability_factorizations(self, x_trajectory: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement the Markovian factorizations
        
        Compute both:
        1. Generative model: p_θ(x_{0:T}) = p(x_T) ∏ p_θ(x_{t-1}|x_t)
        2. Forward process: q(x_{1:T}|x_0) = ∏ q(x_t|x_{t-1})
        
        This is the foundation for ELBO derivation.
        
        Args:
            x_trajectory: List of states [x_0, x_1, ..., x_T]
            
        Returns:
            Dictionary with log probability components
        """
        # TODO: Your implementation here
        # For forward process:
        # Step 1: Compute log q(x_t|x_{t-1}) for each t
        # Step 2: Sum all forward steps
        
        # For generative model (simplified for demonstration):
        # Step 1: Compute log p(x_T) assuming N(0, I)
        # Step 2: Compute log p_θ(x_{t-1}|x_t) for each t (simplified)
        # Step 3: Combine prior and reverse terms
        
        # Return both factorizations for ELBO computation
        pass
    
    def implement_elbo_step1_separation(self, x_trajectory: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement ELBO Step 1 - Strategic Term Separation
        
        Starting from: ELBO = E_q[log p_θ(x_{0:T}) - log q(x_{1:T}|x_0)]
        
        Separate into:
        1. log p(x_T) - prior term
        2. log p_θ(x_0|x_1) - reconstruction term  
        3. Σ log p_θ(x_{t-1}|x_t) - reverse denoising terms
        4. Σ log q(x_t|x_{t-1}) - forward noising terms
        
        Args:
            x_trajectory: Complete trajectory [x_0, ..., x_T]
            
        Returns:
            Dictionary with separated terms
        """
        # TODO: Your implementation here
        # Step 1: Extract boundary terms: log p(x_T) and log p_θ(x_0|x_1)
        # Step 2: Identify bulk reverse terms (t=2 to T)
        # Step 3: Group terms for strategic pairing
        # Step 4: Return organized components
        pass
    
    def implement_elbo_step2_reindexing(self, separated_terms: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement ELBO Step 2 - Strategic Reindexing
        
        Align the indices of forward and reverse sums for proper comparison.
        Transform forward sum to match reverse sum indexing.
        
        Args:
            separated_terms: Output from step 1
            
        Returns:
            Dictionary with aligned indices
        """
        # TODO: Your implementation here
        # Step 1: Split forward sum: separate last term log q(x_T|x_{T-1})
        # Step 2: Reindex remaining forward terms to align with reverse terms
        # Step 3: Verify both sums now run from t=2 to T
        # Step 4: Return aligned terms for comparison
        pass
    
    def implement_elbo_step3_bayes_rule(self, aligned_terms: Dict[str, torch.Tensor], 
                                       x_trajectory: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement ELBO Step 3 - Bayes Rule Transformation
        
        Transform mismatched comparisons into proper reverse process comparisons.
        
        Key insight: We need to compare p_θ(x_{t-1}|x_t) with q(x_{t-1}|x_t, x_0)
        not with q(x_{t-1}|x_{t-2}).
        
        Use: q(x_{t-1}|x_t, x_0) = q(x_t|x_{t-1})q(x_{t-1}|x_0) / q(x_t|x_0)
        
        Args:
            aligned_terms: Output from step 2
            x_trajectory: Complete trajectory for Bayes rule application
            
        Returns:
            Dictionary with proper reverse comparisons
        """
        # TODO: Your implementation here
        # Step 1: Apply Bayes rule to create q(x_{t-1}|x_t, x_0)
        # Step 2: Substitute Bayes rule result into aligned terms
        # Step 3: Rearrange using logarithm properties
        # Step 4: Group terms to create KL divergences
        pass
    
    def implement_elbo_step4_final_form(self, bayes_terms: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement ELBO Step 4 - Final KL Divergence Form
        
        Transform grouped terms into the final three-forces decomposition,
        ELBO = (1) - (2) - (3):
        1. Reconstruction: E_q[log p_θ(x_0|x_1)]
        2. Prior matching: KL(q(x_T|x_0) || p(x_T))
        3. Denoising: Σ_{t=2}^{T} E_q[KL(q(x_{t-1}|x_t, x_0) || p_θ(x_{t-1}|x_t))]
        
        Args:
            bayes_terms: Output from step 3
            
        Returns:
            Dictionary with final ELBO decomposition
        """
        # TODO: Your implementation here
        # Step 1: Identify difference-of-logs patterns
        # Step 2: Convert to KL divergence form
        # Step 3: Organize into three interpretable terms
        # Step 4: Return final ELBO decomposition
        pass
    
    def complete_elbo_derivation(self, x_trajectory: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Execute the complete ELBO derivation pipeline
        
        Walk through all four steps to transform intractable likelihood
        into tractable three-forces decomposition.
        """
        print("=== Complete ELBO Derivation ===\n")
        
        # Step 1: Strategic separation
        print("Step 1: Strategic term separation...")
        separated = self.implement_elbo_step1_separation(x_trajectory)
        
        # Step 2: Index alignment  
        print("Step 2: Strategic reindexing...")
        aligned = self.implement_elbo_step2_reindexing(separated)
        
        # Step 3: Bayes rule transformation
        print("Step 3: Bayes rule transformation...")
        bayes_transformed = self.implement_elbo_step3_bayes_rule(aligned, x_trajectory)
        
        # Step 4: Final KL form
        print("Step 4: Final KL divergence form...")
        final_elbo = self.implement_elbo_step4_final_form(bayes_transformed)
        
        print("✓ ELBO derivation complete!")
        return final_elbo
    
    def visualize_derivation_steps(self):
        """
        Create visualization showing the derivation progression
        """
        fig, ax = plt.subplots(1, 1, figsize=(14, 8))
        
        steps = [
            "Intractable\nLikelihood",
            "Strategic\nSeparation", 
            "Index\nAlignment",
            "Bayes Rule\nTransform",
            "Final KL\nForm"
        ]
        
        descriptions = [
            "log p(x₀) = ?\nMarginal integral",
            "Boundary terms\nvs bulk terms",
            "Align forward\nand reverse sums",
            "Create proper\nreverse comparisons", 
            "Three forces:\nReconstruction,\nPrior, Denoising"
        ]
        
        colors = ['red', 'orange', 'yellow', 'lightgreen', 'green']
        
        # Draw progression
        y_pos = 0.7
        x_positions = [0.1, 0.3, 0.5, 0.7, 0.9]
        
        for i, (step, desc, color, x) in enumerate(zip(steps, descriptions, colors, x_positions)):
            # Step box
            ax.text(x, y_pos, step, ha='center', va='center', fontsize=11, weight='bold',
                   bbox=dict(boxstyle="round,pad=0.5", facecolor=color, alpha=0.8))
            
            # Description
            ax.text(x, y_pos - 0.25, desc, ha='center', va='center', fontsize=9,
                   bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8))
            
            # Arrow to next step
            if i < len(steps) - 1:
                ax.arrow(x + 0.05, y_pos, 0.1, 0, head_width=0.02, head_length=0.02,
                        fc='black', ec='black')
        
        # Add mathematical expressions
        math_expressions = [
            "∫ p(x₀:T) dx₁:T",
            "log p(xT) +\nlog pθ(x₀|x₁) +\nΣ log pθ(x_{t-1}|xt)",
            "Aligned sums\nt=2 to T",
            "Compare\npθ(x_{t-1}|xt) vs\nq(x_{t-1}|xt,x₀)",
            "E[log pθ(x₀|x₁)] -\nKL(q(xT|x₀)||p(xT)) -\nΣ KL(...)"
        ]
        
        for i, (expr, x) in enumerate(zip(math_expressions, x_positions)):
            ax.text(x, 0.2, expr, ha='center', va='center', fontsize=8,
                   bbox=dict(boxstyle="round,pad=0.2", facecolor='lightblue', alpha=0.6))
        
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title('ELBO Derivation: From Intractable to Tractable', fontsize=14, weight='bold')
        ax.axis('off')
        
        plt.tight_layout()
        plt.show()

# Test ELBO derivation (uncomment after implementing TODOs)
# elbo_derivation = DiffusionELBODerivation(noise_schedule)
# elbo_derivation.visualize_derivation_steps()

# # Test with sample trajectory
# demo = SequentialLikelihoodDemo(data_dim=2, T=50)
# try:
#     sample_trajectory = demo.forward_trajectory(test_data[0:1])
#     if sample_trajectory is not None:
#         elbo_result = elbo_derivation.complete_elbo_derivation(sample_trajectory)
#         print("ELBO derivation successful!")
# except:
#     print("Implement trajectory methods first")
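
Of the three terms, prior matching is the only one with no learnable parameters, and for Gaussians it has a closed form. The sketch below is a stand-alone illustration under the assumption p(x_T) = N(0, I). `prior_matching_kl` is a hypothetical helper name, and the result is cross-checked against `torch.distributions.kl_divergence`.

```python
import torch
from torch.distributions import Normal, kl_divergence

def prior_matching_kl(x0: torch.Tensor, alpha_bar_T: torch.Tensor) -> torch.Tensor:
    """KL( N(sqrt(abar_T) x0, (1 - abar_T) I) || N(0, I) ), summed over data dimensions."""
    mean = torch.sqrt(alpha_bar_T) * x0
    var = 1.0 - alpha_bar_T
    # Closed form for diagonal Gaussians: 0.5 * sum(var + mean^2 - 1 - log var)
    return 0.5 * (var + mean**2 - 1.0 - torch.log(var)).sum(dim=-1)

# Same linear schedule as the lab defaults, with T = 50 (assumption)
betas = torch.linspace(1e-4, 2e-2, 50)
alpha_bar_T = torch.cumprod(1.0 - betas, dim=0)[-1]
x0 = torch.tensor([[2.0, 0.0], [-1.0, 1.7]])

kl_closed = prior_matching_kl(x0, alpha_bar_T)

# Cross-check with PyTorch's registered Normal/Normal KL (per dimension, then summed)
q = Normal(torch.sqrt(alpha_bar_T) * x0, torch.sqrt(1.0 - alpha_bar_T))
p = Normal(torch.zeros_like(x0), torch.ones_like(x0))
kl_torch = kl_divergence(q, p).sum(dim=-1)

print("closed form:", kl_closed)
print("torch check:", kl_torch)
```

With a short schedule this term is noticeably above zero, because x_T still carries information about x_0.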

### Task 3.2: The Tractable Reverse Distribution

**Your Mission**: Implement the key insight that makes diffusion training possible.

In [None]:
class TractableReverseDistribution:
    """
    Implement the tractable reverse distribution q(x_{t-1}|x_t, x_0).
    This is the mathematical breakthrough that enables practical diffusion training.
    """
    
    def __init__(self, noise_schedule: Dict[str, torch.Tensor]):
        self.noise_schedule = noise_schedule
        self.betas = noise_schedule['betas']
        self.alphas = noise_schedule['alphas']
        self.alphas_cumprod = noise_schedule['alphas_cumprod']
    
    def bayes_rule_transformation(self, x_t: torch.Tensor, x_0: torch.Tensor, t: int) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement Bayes rule for reverse distribution
        
        Compute q(x_{t-1}|x_t, x_0) using:
        q(x_{t-1}|x_t, x_0) = q(x_t|x_{t-1}) * q(x_{t-1}|x_0) / q(x_t|x_0)
        
        All terms are Gaussian, so the result is also Gaussian!
        
        Args:
            x_t: Current noisy state
            x_0: Original clean data
            t: Current timestep
            
        Returns:
            Dictionary with mean and variance of q(x_{t-1}|x_t, x_0)
        """
        # TODO: Your implementation here
        # Step 1: Extract noise schedule parameters for timestep t
        # Step 2: Apply the analytical Gaussian formula (provided in lecture)
        # Step 3: Compute optimal mean using the interpolation formula
        # Step 4: Compute optimal variance (fixed, no learning required!)
        # Step 5: Return both mean and variance
        pass
    
    def optimal_reverse_mean(self, x_t: torch.Tensor, x_0: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement the optimal reverse mean
        
        Use the beautiful interpolation formula:
        μ̃_t(x_t, x_0) = (√ᾱ_{t-1} β_t)/(1-ᾱ_t) * x_0 + (√α_t (1-ᾱ_{t-1}))/(1-ᾱ_t) * x_t
        
        This tells us exactly what the optimal denoising step should predict!
        
        Args:
            x_t: Current noisy state
            x_0: Original clean data  
            t: Current timestep
            
        Returns:
            Optimal mean for the reverse step
        """
        # TODO: Your implementation here
        # Step 1: Get ᾱ_t, ᾱ_{t-1}, α_t, β_t from noise schedule
        # Step 2: Compute coefficient for x_0 term
        # Step 3: Compute coefficient for x_t term
        # Step 4: Sanity check: coeff_x0 + √ᾱ_t * coeff_xt = √ᾱ_{t-1} (the coefficients do not sum to 1 in general)
        # Step 5: Return weighted combination
        pass
    
    def optimal_reverse_variance(self, t: int) -> torch.Tensor:
        """
        TODO: Implement the optimal reverse variance
        
        Use the formula: σ̃²_t = (1-ᾱ_{t-1})/(1-ᾱ_t) * β_t
        
        This is fixed and requires no learning!
        
        Args:
            t: Current timestep
            
        Returns:
            Optimal variance for the reverse step
        """
        # TODO: Your implementation here
        # Step 1: Get ᾱ_t, ᾱ_{t-1}, β_t from noise schedule
        # Step 2: Apply the variance formula
        # Step 3: Return fixed variance (no parameters to learn!)
        pass
    
    def demonstrate_interpolation_weights(self):
        """
        Show how the interpolation weights change across timesteps
        """
        print("=== Optimal Mean Interpolation Analysis ===\n")
        
        timesteps = torch.arange(1, self.noise_schedule['T'])
        
        # Compute weights for each timestep
        weight_x0_list = []
        weight_xt_list = []
        
        for t in timesteps:
            # Extract parameters
            alpha_t = self.alphas[t]
            alpha_cumprod_t = self.alphas_cumprod[t]
            alpha_cumprod_t_minus_1 = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
            beta_t = self.betas[t]
            
            # Compute weights
            weight_x0 = (torch.sqrt(alpha_cumprod_t_minus_1) * beta_t) / (1 - alpha_cumprod_t)
            weight_xt = (torch.sqrt(alpha_t) * (1 - alpha_cumprod_t_minus_1)) / (1 - alpha_cumprod_t)
            
            weight_x0_list.append(weight_x0.item())
            weight_xt_list.append(weight_xt.item())
        
        # Visualize weight evolution
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
        
        # Plot 1: Individual weights
        ax1.plot(timesteps, weight_x0_list, 'b-', linewidth=2, label='Weight for x₀ (clean)')
        ax1.plot(timesteps, weight_xt_list, 'r-', linewidth=2, label='Weight for xₜ (noisy)')
        ax1.set_xlabel('Timestep t')
        ax1.set_ylabel('Weight')
        ax1.set_title('Interpolation Weights')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: Weight sum (at most 1; slightly below 1 once noise has been added)
        weight_sums = [w0 + wt for w0, wt in zip(weight_x0_list, weight_xt_list)]
        ax2.plot(timesteps, weight_sums, 'g-', linewidth=2)
        ax2.axhline(y=1.0, color='k', linestyle='--', alpha=0.7, label='Upper bound = 1')
        ax2.set_xlabel('Timestep t')
        ax2.set_ylabel('Sum of Weights')
        ax2.set_title('Weight Sum (≤ 1)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Relative dominance
        weight_ratios = [w0 / wt if wt > 0 else 0 for w0, wt in zip(weight_x0_list, weight_xt_list)]
        ax3.plot(timesteps, weight_ratios, 'purple', linewidth=2)
        ax3.axhline(y=1.0, color='k', linestyle='--', alpha=0.7, label='Equal weights')
        ax3.set_xlabel('Timestep t')
        ax3.set_ylabel('Weight Ratio (x₀/xₜ)')
        ax3.set_title('Clean vs Noisy Dominance')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        ax3.set_yscale('log')
        
        plt.tight_layout()
        plt.show()
        
        print("Key insights (for this linear schedule):")
        print("• Weight on x₀ is largest at the first timesteps and decays as t grows")
        print("• Weight on xₜ dominates at late timesteps (each reverse step changes the state only slightly)")
        print("• Weights sum to slightly less than 1; the exact identity is w_x0 + √ᾱ_t·w_xt = √ᾱ_{t-1}")
        print("• The balance adapts to the accumulated noise level")
    
    def visualize_reverse_distribution(self, x_0: torch.Tensor, timesteps_to_show: List[int] = [5, 15, 25, 35, 45]):
        """
        Visualize the tractable reverse distribution at different timesteps
        """
        print("=== Tractable Reverse Distribution Visualization ===\n")
        
        fig, axes = plt.subplots(2, len(timesteps_to_show), figsize=(15, 8))
        
        # Create forward trajectory
        demo = SequentialLikelihoodDemo(data_dim=2, T=50)
        
        for i, t in enumerate(timesteps_to_show):
            try:
                # Get x_t using direct jump
                x_t = demo.direct_jump_forward(x_0, t)
                
                # Compute optimal reverse distribution
                optimal_mean = self.optimal_reverse_mean(x_t, x_0, t)
                optimal_var = self.optimal_reverse_variance(t)
                
                # Plot current state x_t
                axes[0, i].scatter(x_t[0, 0].cpu(), x_t[0, 1].cpu(), c='red', s=100, label=f'xₜ (t={t})')
                axes[0, i].scatter(x_0[0, 0].cpu(), x_0[0, 1].cpu(), c='blue', s=100, label='x₀ (clean)')
                
                if optimal_mean is not None:
                    axes[0, i].scatter(optimal_mean[0, 0].cpu(), optimal_mean[0, 1].cpu(), 
                                     c='green', s=100, label='Optimal μ̃')
                
                axes[0, i].set_title(f'Timestep {t}')
                axes[0, i].legend(fontsize=8)
                axes[0, i].grid(True, alpha=0.3)
                axes[0, i].set_xlim(-4, 4)
                axes[0, i].set_ylim(-4, 4)
                
                # Plot variance evolution
                if optimal_var is not None:
                    axes[1, i].bar(0, optimal_var.item(), color='orange', alpha=0.7)
                    axes[1, i].set_title(f'Variance: {optimal_var.item():.4f}')
                    axes[1, i].set_ylim(0, 0.1)
                
            except Exception as e:
                axes[0, i].text(0.5, 0.5, 'Implement\nTODOs first', ha='center', va='center')
                axes[1, i].text(0.5, 0.5, 'Implement\nTODOs first', ha='center', va='center')
        
        axes[0, 0].set_ylabel('Spatial Distribution')
        axes[1, 0].set_ylabel('Optimal Variance')
        
        plt.tight_layout()
        plt.show()

# Test tractable reverse distribution (uncomment after implementing TODOs)
# reverse_dist = TractableReverseDistribution(noise_schedule)
# reverse_dist.demonstrate_interpolation_weights()
# reverse_dist.visualize_reverse_distribution(test_data[0:1])
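
The closed-form posterior can be checked numerically against Bayes' rule. The following 1D sketch is self-contained and is not the lab's reference solution. `posterior_params` is a hypothetical helper, and timesteps are assumed to be 1-indexed with ᾱ_0 = 1.

```python
import torch
from torch.distributions import Normal

betas = torch.linspace(1e-4, 2e-2, 50, dtype=torch.float64)
alphas = 1.0 - betas
# Prepend abar_0 = 1 so that abar[t] is the cumulative product up to step t (1-indexed)
abar = torch.cat([torch.ones(1, dtype=torch.float64), torch.cumprod(alphas, dim=0)])

def posterior_params(x_t, x_0, t: int):
    """Mean, variance, and weights of q(x_{t-1} | x_t, x_0) for 1-indexed t >= 2."""
    beta_t, alpha_t = betas[t - 1], alphas[t - 1]
    w_x0 = torch.sqrt(abar[t - 1]) * beta_t / (1.0 - abar[t])
    w_xt = torch.sqrt(alpha_t) * (1.0 - abar[t - 1]) / (1.0 - abar[t])
    var = (1.0 - abar[t - 1]) / (1.0 - abar[t]) * beta_t
    return w_x0 * x_0 + w_xt * x_t, var, (w_x0, w_xt)

t_check = 20
x_0 = torch.tensor(1.5, dtype=torch.float64)
x_t = torch.tensor(0.7, dtype=torch.float64)
x_prev = torch.tensor(0.9, dtype=torch.float64)  # arbitrary evaluation point for x_{t-1}

mean, var, (w_x0, w_xt) = posterior_params(x_t, x_0, t_check)
log_closed = Normal(mean, var.sqrt()).log_prob(x_prev)

# Bayes' rule in log space: log q(x_t|x_{t-1}) + log q(x_{t-1}|x_0) - log q(x_t|x_0)
log_bayes = (
    Normal(alphas[t_check - 1].sqrt() * x_prev, betas[t_check - 1].sqrt()).log_prob(x_t)
    + Normal(abar[t_check - 1].sqrt() * x_0, (1 - abar[t_check - 1]).sqrt()).log_prob(x_prev)
    - Normal(abar[t_check].sqrt() * x_0, (1 - abar[t_check]).sqrt()).log_prob(x_t)
)

print("closed form log-density:", log_closed.item())
print("Bayes' rule log-density:", log_bayes.item())
print("w_x0 + w_xt =", (w_x0 + w_xt).item())
```

The two log-densities should agree to numerical precision. The printed weight sum also shows that the two coefficients do not add up to exactly 1.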

---

## Part 4: The Noise Prediction Reparameterization (20 minutes)

### Task 4.1: Implement the Noise Prediction Breakthrough

**Your Mission**: Implement the reparameterization that transforms diffusion training from image prediction to noise prediction.
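
The algebra behind this reparameterization can be verified numerically: substituting x_0 = (x_t - √(1-ᾱ_t) ε) / √ᾱ_t into the posterior mean should give the noise-based expression. The sketch below is a self-contained check with hypothetical variable names and an assumed 1-indexed timestep; it is not the reference solution for the TODOs.

```python
import torch

torch.manual_seed(0)

betas = torch.linspace(1e-4, 2e-2, 50, dtype=torch.float64)
alphas = 1.0 - betas
# abar[0] = 1 by convention, abar[t] = product of alphas up to step t (1-indexed)
abar = torch.cat([torch.ones(1, dtype=torch.float64), torch.cumprod(alphas, dim=0)])

t_check = 30
alpha_t, beta_t = alphas[t_check - 1], betas[t_check - 1]
abar_t, abar_prev = abar[t_check], abar[t_check - 1]

x_0 = torch.randn(4, 2, dtype=torch.float64)
eps = torch.randn_like(x_0)

# Forward process with explicit noise
x_t = abar_t.sqrt() * x_0 + (1 - abar_t).sqrt() * eps

# Knowing the noise recovers the clean data
x_0_rec = (x_t - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()

# Posterior mean written in terms of x_0 ...
mean_from_x0 = (abar_prev.sqrt() * beta_t * x_0
                + alpha_t.sqrt() * (1 - abar_prev) * x_t) / (1 - abar_t)
# ... and in terms of the noise
mean_from_eps = (x_t - (1 - alpha_t) / (1 - abar_t).sqrt() * eps) / alpha_t.sqrt()

print("max |difference| between the two means:",
      (mean_from_x0 - mean_from_eps).abs().max().item())
```

In training, the true ε in the last expression is replaced by a network's prediction, which is what makes noise prediction a drop-in target.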

In [None]:
class NoisePredictionReparameterization:
    """
    Implement the noise prediction reparameterization.
    This transforms complex denoising into simple noise prediction!
    """
    
    def __init__(self, noise_schedule: Dict[str, torch.Tensor]):
        self.noise_schedule = noise_schedule
        self.betas = noise_schedule['betas']
        self.alphas = noise_schedule['alphas']
        self.alphas_cumprod = noise_schedule['alphas_cumprod']
    
    def forward_process_with_noise(self, x_0: torch.Tensor, t: int, epsilon: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        TODO: Implement forward process with explicit noise tracking
        
        Use: x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε, where ε ~ N(0,I)
        
        This explicit form will enable noise prediction training.
        
        Args:
            x_0: Clean data
            t: Timestep
            epsilon: Noise vector (if None, sample new noise)
            
        Returns:
            Tuple of (x_t, epsilon_used)
        """
        # TODO: Your implementation here
        # Step 1: Sample noise ε ~ N(0,I) if not provided
        # Step 2: Get ᾱ_t from noise schedule
        # Step 3: Compute mean coefficient: √ᾱ_t
        # Step 4: Compute noise coefficient: √(1-ᾱ_t)
        # Step 5: Apply the formula: x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
        # Step 6: Return both x_t and the noise used
        pass
    
    def solve_for_x0(self, x_t: torch.Tensor, epsilon: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement solving for x_0 given x_t and noise
        
        From x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
        Solve: x_0 = (x_t - √(1-ᾱ_t) * ε) / √ᾱ_t
        
        This shows how knowing the noise allows us to recover clean data!
        
        Args:
            x_t: Noisy observation
            epsilon: Noise vector
            t: Timestep
            
        Returns:
            Recovered x_0
        """
        # TODO: Your implementation here
        # Step 1: Get ᾱ_t from noise schedule
        # Step 2: Apply the inversion formula
        # Step 3: Return recovered x_0
        pass
    
    def reparameterize_optimal_mean(self, x_t: torch.Tensor, epsilon: torch.Tensor, t: int) -> torch.Tensor:
        """
        TODO: Implement noise-parameterized optimal mean
        
        Transform the optimal mean formula to use noise instead of x_0:
        μ̃_t(x_t, ε) = (1/√α_t) * (x_t - (1-α_t)/√(1-ᾱ_t) * ε)
        
        This is the key insight: optimal denoising = noise prediction + simple arithmetic!
        
        Args:
            x_t: Current noisy state
            epsilon: True noise that was added
            t: Current timestep
            
        Returns:
            Optimal mean in terms of noise
        """
        # TODO: Your implementation here
        # Step 1: Get α_t and ᾱ_t from noise schedule
        # Step 2: Compute the noise coefficient: (1-α_t)/√(1-ᾱ_t)
        # Step 3: Compute the scaling factor: 1/√α_t
        # Step 4: Apply the reparameterized formula
        # Step 5: Return noise-parameterized optimal mean
        pass
    
    def demonstrate_noise_prediction_equivalence(self, x_0: torch.Tensor, t: int):
        """
        Demonstrate that noise prediction is equivalent to optimal denoising
        """
        print(f"=== Noise Prediction Equivalence Demo (t={t}) ===\n")
        
        # Generate noisy sample with known noise
        try:
            x_t, true_epsilon = self.forward_process_with_noise(x_0, t)
            
            print(f"Original x_0: {x_0.squeeze().cpu().numpy()}")
            print(f"Noisy x_t: {x_t.squeeze().cpu().numpy()}")
            print(f"True noise ε: {true_epsilon.squeeze().cpu().numpy()}")
            
            # Method 1: Direct optimal mean (using x_0)
            reverse_dist = TractableReverseDistribution(self.noise_schedule)
            optimal_mean_direct = reverse_dist.optimal_reverse_mean(x_t, x_0, t)
            
            # Method 2: Noise-parameterized optimal mean
            optimal_mean_noise = self.reparameterize_optimal_mean(x_t, true_epsilon, t)
            
            if optimal_mean_direct is not None and optimal_mean_noise is not None:
                print(f"Optimal mean (direct): {optimal_mean_direct.squeeze().cpu().numpy()}")
                print(f"Optimal mean (noise): {optimal_mean_noise.squeeze().cpu().numpy()}")
                
                # Check equivalence
                difference = torch.abs(optimal_mean_direct - optimal_mean_noise).max()
                print(f"Maximum difference: {difference.item():.8f}")
                print(f"Equivalent: {torch.allclose(optimal_mean_direct, optimal_mean_noise, atol=1e-6)}")
            
            # Test noise recovery
            recovered_x0 = self.solve_for_x0(x_t, true_epsilon, t)
            if recovered_x0 is not None:
                recovery_error = torch.abs(x_0 - recovered_x0).max()
                print(f"x_0 recovery error: {recovery_error.item():.8f}")
            
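        except Exception as e:
            # Report the underlying error so real bugs are not mistaken for missing TODOs
            print(f"Implement the TODO methods first ({type(e).__name__}: {e})")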
    
    def visualize_noise_prediction_training(self):
        """
        Visualize how noise prediction training works
        """
        print("=== Noise Prediction Training Visualization ===\n")
        
        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        
        # Take a sample point
        x_0 = test_data[0:1]
        timesteps = [10, 20, 30, 40]
        
        for i, t in enumerate(timesteps):
            try:
                # Generate training sample
                x_t, true_epsilon = self.forward_process_with_noise(x_0, t)
                
                # Show the training data
                axes[0, i].scatter(x_0[0, 0].cpu(), x_0[0, 1].cpu(), c='blue', s=100, label='x₀ (clean)')
                axes[0, i].scatter(x_t[0, 0].cpu(), x_t[0, 1].cpu(), c='red', s=100, label='xₜ (noisy)')
                
                # Show the noise vector
                axes[0, i].arrow(x_t[0, 0].cpu(), x_t[0, 1].cpu(), 
                               true_epsilon[0, 0].cpu(), true_epsilon[0, 1].cpu(),
                               head_width=0.1, head_length=0.1, fc='green', ec='green', alpha=0.7)
                axes[0, i].text(x_t[0, 0].cpu() + true_epsilon[0, 0].cpu()/2, 
                               x_t[0, 1].cpu() + true_epsilon[0, 1].cpu()/2, 
                               'ε', fontsize=14, color='green', weight='bold')
                
                axes[0, i].set_title(f't={t}')
                axes[0, i].legend(fontsize=8)
                axes[0, i].grid(True, alpha=0.3)
                axes[0, i].set_xlim(-4, 4)
                axes[0, i].set_ylim(-4, 4)
                
                # Show the norm of the sampled unit-variance noise ε
                # (its contribution to x_t is scaled by √(1-ᾱ_t), which grows with t)
                noise_magnitude = torch.norm(true_epsilon).item()
                axes[1, i].bar(0, noise_magnitude, color='green', alpha=0.7)
                axes[1, i].set_title(f'||ε||: {noise_magnitude:.3f}')
                axes[1, i].set_ylim(0, 3)
                
            except Exception:
                axes[0, i].text(0.5, 0.5, 'Implement\nTODOs first', ha='center', va='center')
                axes[1, i].text(0.5, 0.5, 'Implement\nTODOs first', ha='center', va='center')
        
        axes[0, 0].set_ylabel('Training Sample')
        axes[1, 0].set_ylabel('Noise Magnitude')
        
        plt.tight_layout()
        plt.show()
        
        print("Training insight:")
        print("• Network learns: ε_θ(x_t, t) ≈ ε")
        print("• Loss: ||ε - ε_θ(x_t, t)||²")
        print("• Simple MSE on noise prediction!")

# Test noise prediction reparameterization (uncomment after implementing TODOs)
# noise_reparam = NoisePredictionReparameterization(noise_schedule)

# # Test equivalence
# for t in [5, 15, 25, 35]:
#     noise_reparam.demonstrate_noise_prediction_equivalence(test_data[0:1], t)
#     print()

# noise_reparam.visualize_noise_prediction_training()

### Task 4.2: Implement the Simple Training Algorithm

**Your Mission**: Transform the complex ELBO into simple noise prediction training.
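
The part of this algorithm that most often trips people up is handling a different timestep for every sample in the batch. The sketch below shows one way to do that with tensor indexing and broadcasting. It is an illustration under stated assumptions, not the lab's reference solution: `noise_prediction_loss` is a hypothetical helper, timesteps are 0-based, the schedule endpoints are assumed, and everything runs on the CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def noise_prediction_loss(model, x_0, alphas_cumprod, generator=None):
    """One Monte Carlo estimate of the simplified loss E_{t,eps} ||eps - eps_theta(x_t, t)||^2.

    Hypothetical helper: assumes 0-based timesteps and a model that takes
    [x_t, t / T] concatenated along the last dimension.
    """
    batch_size = x_0.shape[0]
    T = alphas_cumprod.shape[0]

    # One independent timestep per sample in the batch
    t = torch.randint(0, T, (batch_size,), generator=generator)
    eps = torch.randn(x_0.shape, generator=generator)

    # Index the schedule with the batch of timesteps; (B, 1) broadcasts over data dims
    abar_t = alphas_cumprod[t].unsqueeze(-1)
    x_t = torch.sqrt(abar_t) * x_0 + torch.sqrt(1.0 - abar_t) * eps

    # Normalized timestep as a one-dimensional embedding
    t_embed = (t.float() / T).unsqueeze(-1)
    eps_pred = model(torch.cat([x_t, t_embed], dim=-1))

    # mse_loss averages over batch and dimensions, i.e. ||.||^2 up to a constant factor
    return F.mse_loss(eps_pred, eps)

g = torch.Generator().manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 50)          # assumed linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

model = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 2))
x_0 = torch.randn(16, 2, generator=g)

loss = noise_prediction_loss(model, x_0, alphas_cumprod, generator=g)
loss.backward()
print(f"loss = {loss.item():.4f}")
```

Sampling a fresh `t` per sample, rather than one `t` per batch, gives a lower-variance estimate of the expectation over timesteps for the same batch size.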

In [None]:
class SimpleDiffusionTraining:
    """
    Implement the simplified diffusion training algorithm.
    Show how complex ELBO theory reduces to elegant practice.
    """
    
    def __init__(self, noise_schedule: Dict[str, torch.Tensor], data_dim: int = 2):
        self.noise_schedule = noise_schedule
        self.T = noise_schedule['T']
        self.data_dim = data_dim
        
        # Simple noise prediction network
        self.noise_network = nn.Sequential(
            nn.Linear(data_dim + 1, 64),  # +1 for timestep embedding
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, data_dim)  # Predict noise
        ).to(device)
        
        self.noise_reparam = NoisePredictionReparameterization(noise_schedule)
    
    def timestep_embedding(self, t: torch.Tensor) -> torch.Tensor:
        """
        Simple timestep embedding (just normalized timestep)
        """
        return (t.float() / self.T).unsqueeze(-1)
    
    def predict_noise(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Predict noise using the network
        
        Args:
            x_t: Noisy input
            t: Timestep
            
        Returns:
            Predicted noise ε_θ(x_t, t)
        """
        # Embed timestep
        t_embed = self.timestep_embedding(t)
        
        # Concatenate input and timestep
        network_input = torch.cat([x_t, t_embed.expand(x_t.shape[0], -1)], dim=-1)
        
        # Predict noise
        epsilon_pred = self.noise_network(network_input)
        return epsilon_pred
    
    def simple_diffusion_loss(self, x_0: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement the simple diffusion training loss
        
        Algorithm:
        1. Sample random timestep t
        2. Sample noise ε ~ N(0,I)
        3. Compute x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
        4. Predict ε_θ(x_t, t)
        5. Compute loss ||ε - ε_θ(x_t, t)||²
        
        This is the elegant simplification of the complex ELBO!
        
        Args:
            x_0: Batch of clean data
            
        Returns:
            Dictionary with loss components
        """
        # TODO: Your implementation here
        # Step 1: Sample random timesteps for each sample in batch
        # Step 2: Sample noise vectors ε ~ N(0,I)
        # Step 3: Create noisy samples x_t using forward_process_with_noise
        # Step 4: Predict noise using the network
        # Step 5: Compute MSE loss between true and predicted noise
        # Step 6: Return loss and auxiliary information
        pass
    
    def train_simple_diffusion(self, data: torch.Tensor, epochs: int = 100, lr: float = 1e-3):
        """
        Train diffusion model using simple noise prediction
        """
        print("=== Training Simple Diffusion Model ===\n")
        
        optimizer = torch.optim.Adam(self.noise_network.parameters(), lr=lr)
        losses = []
        
        for epoch in range(epochs):
            # Training step
            optimizer.zero_grad()
            
            # Compute loss (after students implement simple_diffusion_loss)
            try:
                loss_dict = self.simple_diffusion_loss(data)
                if loss_dict is not None:
                    loss = loss_dict['loss']
                    
                    # Backpropagation
                    loss.backward()
                    optimizer.step()
                    
                    losses.append(loss.item())
                    
                    if epoch % 20 == 0:
                        print(f"Epoch {epoch}: Loss = {loss.item():.6f}")
                else:
                    print("Implement simple_diffusion_loss first")
                    break
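            except Exception as e:
                # Catch Exception (not a bare except) so Ctrl-C still interrupts training
                print(f"Training stopped ({type(e).__name__}: {e}); implement simple_diffusion_loss first")
                break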
        
        # Plot training curve
        if losses:
            plt.figure(figsize=(10, 6))
            plt.plot(losses, 'b-', linewidth=2)
            plt.xlabel('Epoch')
            plt.ylabel('Noise Prediction Loss')
            plt.title('Simple Diffusion Training')
            plt.grid(True, alpha=0.3)
            plt.show()
            
            print(f"Final loss: {losses[-1]:.6f}")
            print("Training complete!")
        
        return losses
    
    def test_noise_prediction_accuracy(self, test_data: torch.Tensor, n_samples: int = 10):
        """
        Test how well the trained network predicts noise
        """
        print("=== Testing Noise Prediction Accuracy ===\n")
        
        self.noise_network.eval()
        total_error = 0.0
        
        with torch.no_grad():
            for i in range(n_samples):
                # Sample random test case
                x_0 = test_data[i:i+1]
                t = torch.randint(1, self.T, (1,)).to(device)
                
                # Generate noisy sample with known noise
                try:
                    x_t, true_epsilon = self.noise_reparam.forward_process_with_noise(x_0, t.item())
                    
                    # Predict noise
                    pred_epsilon = self.predict_noise(x_t, t)
                    
                    # Compute error
                    error = F.mse_loss(pred_epsilon, true_epsilon)
                    total_error += error.item()
                    
                    if i < 3:  # Show first few examples
                        print(f"Sample {i+1}: t={t.item()}")
                        print(f"  True noise: {true_epsilon.squeeze().cpu().numpy()}")
                        print(f"  Pred noise: {pred_epsilon.squeeze().cpu().numpy()}")
                        print(f"  MSE error: {error.item():.6f}")
                        print()
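                except Exception as e:
                    print(f"Noise prediction test failed ({type(e).__name__}: {e}); "
                          "implement forward_process_with_noise first")
                    self.noise_network.train()  # restore training mode before exiting early
                    return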
        
        avg_error = total_error / n_samples
        print(f"Average noise prediction error: {avg_error:.6f}")
        
        self.noise_network.train()
    
    def demonstrate_elbo_to_practice_connection(self):
        """
        Show the connection between ELBO theory and practical training
        """
        print("=== ELBO Theory → Practice Connection ===\n")
        
        connections = [
            ("Complex ELBO", "Simple Practice"),
            ("E[log p_θ(x_0|x_1)]", "Reconstruction handled automatically"),
            ("KL(q(x_T|x_0) || p(x_T))", "≈ 0 by design (no learning needed)"),
            ("Σ KL(q(x_{t-1}|x_t,x_0) || p_θ(x_{t-1}|x_t))", "||ε - ε_θ(x_t,t)||² (noise prediction)"),
            ("Intractable optimization", "Simple MSE training"),
            ("Complex mathematics", "Elegant algorithm")
        ]
        
        fig, ax = plt.subplots(1, 1, figsize=(14, 8))
        
        y_positions = [0.85, 0.7, 0.55, 0.4, 0.25, 0.1]
        
        for i, (theory, practice) in enumerate(connections):
            y = y_positions[i]
            
            # Theory side
            ax.text(0.15, y, theory, ha='center', va='center', fontsize=11,
                   bbox=dict(boxstyle="round,pad=0.5", facecolor='lightblue', alpha=0.8))
            
            # Arrow
            ax.arrow(0.4, y, 0.2, 0, head_width=0.02, head_length=0.03,
                    fc='black', ec='black')
            
            # Practice side
            ax.text(0.85, y, practice, ha='center', va='center', fontsize=11,
                   bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgreen', alpha=0.8))
        
        ax.text(0.15, 0.95, 'ELBO Theory', ha='center', fontsize=14, weight='bold', color='blue')
        ax.text(0.85, 0.95, 'Practical Training', ha='center', fontsize=14, weight='bold', color='green')
        
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title('The Mathematical Bridge: From Theory to Practice', fontsize=16, weight='bold')
        ax.axis('off')
        
        plt.tight_layout()
        plt.show()
        
        print("The profound achievement:")
        print("🎯 Complex ELBO mathematics → Simple MSE training")
        print("🚀 Intractable optimization → Practical algorithm")
        print("💎 Theoretical elegance → State-of-the-art results")

# Test simple diffusion training (uncomment after implementing TODOs)
# simple_trainer = SimpleDiffusionTraining(noise_schedule, data_dim=2)
# simple_trainer.demonstrate_elbo_to_practice_connection()

# # Train on subset of data
# losses = simple_trainer.train_simple_diffusion(test_data[:50], epochs=50)

# # Test noise prediction
# simple_trainer.test_noise_prediction_accuracy(test_data[:10])

---

## Part 5: The Three Forces Analysis (15 minutes)

### Task 5.1: Implement the Three Forces Decomposition

**Your Mission**: Analyze the three forces that shape diffusion learning.
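
The prior matching force has a closed form because both distributions are Gaussian: for q = N(μ, σ²I) and p = N(0, I), the KL divergence is ½ Σ_d (σ² + μ_d² − 1 − log σ²). The sketch below evaluates it for two schedule lengths and cross-checks it against `torch.distributions`. The helper names and the linear schedule endpoints (1e-4 and 0.02) are assumptions for illustration only and may not match the lab's `noise_schedule`.

```python
import torch
from torch.distributions import Normal, kl_divergence

def final_alpha_bar(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """ᾱ_T for a hypothetical linear beta schedule (assumed endpoints)."""
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float64)
    return torch.cumprod(1.0 - betas, dim=0)[-1]

def prior_matching_kl(x_0: torch.Tensor, abar_T: torch.Tensor) -> torch.Tensor:
    """KL( N(√ᾱ_T x_0, (1-ᾱ_T) I) || N(0, I) ), summed over data dimensions."""
    mean = torch.sqrt(abar_T) * x_0
    var = (1.0 - abar_T).expand_as(x_0)
    kl_per_dim = 0.5 * (var + mean**2 - 1.0 - torch.log(var))
    return kl_per_dim.sum(dim=-1)

x_0 = torch.tensor([[1.5, -0.5]], dtype=torch.float64)

# Same endpoints, different numbers of steps
kl_short = prior_matching_kl(x_0, final_alpha_bar(50))
kl_long = prior_matching_kl(x_0, final_alpha_bar(1000))

# Cross-check the closed form against torch.distributions for T = 50
abar_T = final_alpha_bar(50)
q = Normal(torch.sqrt(abar_T) * x_0, torch.sqrt(1.0 - abar_T))
p = Normal(torch.zeros_like(x_0), torch.ones_like(x_0))
kl_reference = kl_divergence(q, p).sum(dim=-1)

print(f"KL with T=50:   {kl_short.item():.6f}")
print(f"KL with T=1000: {kl_long.item():.6f}")
print(f"closed form matches torch.distributions: {torch.allclose(kl_short, kl_reference)}")
```

With these assumed endpoints, 50 steps leave ᾱ_T well above zero, so the prior matching term is not negligible, while 1000 steps drive it close to zero. The term is "free" only when the schedule actually pushes ᾱ_T toward 0, which is worth checking for whatever schedule the lab uses.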

In [None]:
class ThreeForcesAnalysis:
    """
    Implement analysis of the three forces in diffusion ELBO:
    1. Reconstruction: E[log p_θ(x_0|x_1)]
    2. Prior Matching: KL(q(x_T|x_0) || p(x_T))
    3. Denoising: Σ KL(q(x_{t-1}|x_t,x_0) || p_θ(x_{t-1}|x_t))
    """
    
    def __init__(self, noise_schedule: Dict[str, torch.Tensor]):
        self.noise_schedule = noise_schedule
        self.T = noise_schedule['T']
        self.reverse_dist = TractableReverseDistribution(noise_schedule)
        self.noise_reparam = NoisePredictionReparameterization(noise_schedule)
    
    def compute_reconstruction_term(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement reconstruction term E[log p_θ(x_0|x_1)]
        
        This measures how well we can recover original data from slight noise.
        Model: p_θ(x_0|x_1) = N(x_0; μ_θ(x_1), σ²I)
        
        Args:
            x_0: Clean data
            x_1: Slightly noisy data
            
        Returns:
            Reconstruction log-likelihood
        """
        # TODO: Your implementation here
        # Step 1: For simplicity, assume the optimal mean (would be learned in practice)
        # Step 2: Use reverse distribution to get optimal mean
        # Step 3: Compute Gaussian log-likelihood
        # Step 4: Return reconstruction term
        pass
    
    def compute_prior_matching_term(self, x_0: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement prior matching term KL(q(x_T|x_0) || p(x_T))
        
        This measures how well the forward endpoint matches pure noise.
        Should be ≈ 0 for well-designed noise schedules!
        
        Args:
            x_0: Clean data
            
        Returns:
            KL divergence to prior
        """
        # TODO: Your implementation here
        # Step 1: Get x_T distribution parameters: q(x_T|x_0) = N(√ᾱ_T x_0, (1-ᾱ_T)I)
        # Step 2: Prior is p(x_T) = N(0, I)
        # Step 3: Compute KL divergence between these Gaussians
        # Step 4: Should be very small when ᾱ_T ≈ 0
        pass
    
    def compute_denoising_terms(self, x_0: torch.Tensor, timesteps: List[int]) -> List[torch.Tensor]:
        """
        TODO: Implement denoising terms KL(q(x_{t-1}|x_t,x_0) || p_θ(x_{t-1}|x_t))
        
        These are the heart of diffusion learning - matching optimal and learned reverse steps.
        
        Args:
            x_0: Clean data
            timesteps: List of timesteps to analyze
            
        Returns:
            List of KL divergences for each timestep
        """
        # TODO: Your implementation here
        # For each timestep t:
        # Step 1: Generate x_t using direct jump
        # Step 2: Compute optimal reverse distribution q(x_{t-1}|x_t,x_0)
        # Step 3: For demonstration, use a simple learned model p_θ(x_{t-1}|x_t)
        # Step 4: Compute KL divergence between optimal and learned
        # Step 5: Return list of KL values
        pass
    
    def analyze_three_forces(self, x_0: torch.Tensor):
        """
        Complete analysis of the three forces
        """
        print("=== Three Forces Analysis ===\n")
        
        try:
            # Force 1: Reconstruction
            x_1, _ = self.noise_reparam.forward_process_with_noise(x_0, 1)
            recon_term = self.compute_reconstruction_term(x_0, x_1)
            
            # Force 2: Prior matching
            prior_term = self.compute_prior_matching_term(x_0)
            
            # Force 3: Denoising (sample a few timesteps)
            sample_timesteps = [5, 15, 25, 35, 45]
            denoising_terms = self.compute_denoising_terms(x_0, sample_timesteps)
            
            if all(term is not None for term in [recon_term, prior_term] + denoising_terms):
                print("Force 1 - Reconstruction:")
                print(f"  E[log p_θ(x_0|x_1)] = {recon_term.mean().item():.6f}")
                print("  What it does: Ensures perfect recovery from slight noise")
                print("  Intuition: 'Given a slightly grainy photo, restore it perfectly'")
                print()
                
                print("Force 2 - Prior Matching:")
                print(f"  KL(q(x_T|x_0) || p(x_T)) = {prior_term.mean().item():.6f}")
                print("  What it does: Ensures forward endpoint matches pure noise")
                print("  Beautiful insight: Should be ≈ 0 by design!")
                print()
                
                print("Force 3 - Denoising:")
                for i, (t, kl_val) in enumerate(zip(sample_timesteps, denoising_terms)):
                    print(f"  t={t}: KL = {kl_val.mean().item():.6f}")
                print("  What it does: Learns optimal reverse steps")
                print("  This is where the actual learning happens!")
                
                # Visualize force magnitudes
                self.visualize_three_forces(recon_term, prior_term, denoising_terms, sample_timesteps)
            
        except Exception as e:
            print(f"Implement the TODO methods first ({type(e).__name__}: {e})")
    
    def visualize_three_forces(self, recon_term, prior_term, denoising_terms, timesteps):
        """
        Visualize the relative magnitudes of the three forces
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        
        # Plot 1: Force magnitudes
        forces = ['Reconstruction', 'Prior Matching', 'Avg Denoising']
        magnitudes = [
            -recon_term.mean().item(),  # Negative because it's a likelihood term
            prior_term.mean().item(),
            sum(term.mean().item() for term in denoising_terms) / len(denoising_terms)
        ]
        colors = ['blue', 'red', 'green']
        
        bars = ax1.bar(forces, magnitudes, color=colors, alpha=0.7)
        ax1.set_ylabel('Magnitude')
        ax1.set_title('Three Forces Magnitude Comparison')
        ax1.grid(True, alpha=0.3)
        
        # Add value labels on bars
        for bar, mag in zip(bars, magnitudes):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.001,
                    f'{mag:.4f}', ha='center', va='bottom')
        
        # Plot 2: Denoising terms across timesteps
        denoising_mags = [term.mean().item() for term in denoising_terms]
        ax2.plot(timesteps, denoising_mags, 'go-', linewidth=2, markersize=8)
        ax2.set_xlabel('Timestep t')
        ax2.set_ylabel('KL Divergence')
        ax2.set_title('Denoising Force Across Timesteps')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
    
    def demonstrate_force_balance(self):
        """
        Show how the three forces balance in the complete ELBO
        """
        print("=== Three Forces Balance ===\n")
        
        # Create conceptual visualization
        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        
        # Draw the three forces as a triangle
        force_positions = {
            'Reconstruction': (0.5, 0.8),
            'Prior Matching': (0.2, 0.3),
            'Denoising': (0.8, 0.3)
        }
        
        force_colors = {
            'Reconstruction': 'blue',
            'Prior Matching': 'red', 
            'Denoising': 'green'
        }
        
        force_descriptions = {
            'Reconstruction': 'E[log pθ(x₀|x₁)]\n• Perfect recovery\n• Final output quality\n• Single timestep',
            'Prior Matching': 'KL(q(xT|x₀) || p(xT))\n• Endpoint = noise\n• Free by design\n• No learning needed',
            'Denoising': 'Σ KL(q(x_{t-1}|xt,x₀) || pθ(x_{t-1}|xt))\n• Heart of learning\n• T-1 terms\n• Most computation'
        }
        
        # Draw force nodes
        for force, (x, y) in force_positions.items():
            color = force_colors[force]
            description = force_descriptions[force]
            
            # Force circle
            circle = plt.Circle((x, y), 0.08, color=color, alpha=0.7)
            ax.add_patch(circle)
            
            # Force label
            ax.text(x, y, force.split()[0], ha='center', va='center', 
                   fontsize=10, weight='bold', color='white')
            
            # Description box
            if force == 'Reconstruction':
                desc_pos = (x, y + 0.15)
            elif force == 'Prior Matching':
                desc_pos = (x - 0.15, y - 0.15)
            else:  # Denoising
                desc_pos = (x + 0.15, y - 0.15)
            
            ax.text(desc_pos[0], desc_pos[1], description, ha='center', va='center',
                   fontsize=9, bbox=dict(boxstyle="round,pad=0.5", facecolor=color, alpha=0.3))
        
        # Draw connections showing balance
        connections = [
            (force_positions['Reconstruction'], force_positions['Prior Matching']),
            (force_positions['Prior Matching'], force_positions['Denoising']),
            (force_positions['Denoising'], force_positions['Reconstruction'])
        ]
        
        for (x1, y1), (x2, y2) in connections:
            ax.plot([x1, x2], [y1, y2], 'k--', alpha=0.5, linewidth=2)
        
        # Central balance point
        center_x = sum(pos[0] for pos in force_positions.values()) / 3
        center_y = sum(pos[1] for pos in force_positions.values()) / 3
        ax.plot(center_x, center_y, 'ko', markersize=10)
        ax.text(center_x, center_y - 0.1, 'ELBO\nBalance', ha='center', va='center',
               fontsize=12, weight='bold')
        
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title('The Three Forces of Diffusion Learning', fontsize=16, weight='bold')
        ax.axis('off')
        
        plt.tight_layout()
        plt.show()
        
        print("Key insights:")
        print("🎯 Reconstruction: Quality control for final output")
        print("🎁 Prior Matching: Free by construction (smart design)")
        print("💪 Denoising: Where the learning happens (T-1 terms)")
        print("⚖️  Balance: All three forces shape the final model")

# Test three forces analysis (uncomment after implementing TODOs)
# three_forces = ThreeForcesAnalysis(noise_schedule)
# three_forces.demonstrate_force_balance()
# three_forces.analyze_three_forces(test_data[0:1])

---

## Part 6: Mathematical Validation and Connection (10 minutes)

### Task 6.1: Validate Complete Implementation

**Your Mission**: Test that all mathematical components work together correctly.
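
One point to keep in mind for the forward process test: a sequential trajectory and a direct jump use different random draws, so single samples will not match. What should agree is their distribution. A minimal sketch of such a statistical check follows; the linear schedule endpoints, the 0-based indexing, and the sample count are assumptions chosen for illustration, and the code is independent of the lab's `SequentialLikelihoodDemo` class.

```python
import torch

g = torch.Generator().manual_seed(0)

T = 50
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)  # assumed linear schedule
alphas = 1.0 - betas
abar = torch.cumprod(alphas, dim=0)

x_0 = torch.tensor([1.5, -0.5], dtype=torch.float64)
n = 100_000  # number of independent samples per method

# Sequential: apply x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * noise for every step
x_seq = x_0.expand(n, 2).clone()
for t in range(T):
    noise = torch.randn(n, 2, generator=g, dtype=torch.float64)
    x_seq = torch.sqrt(alphas[t]) * x_seq + torch.sqrt(betas[t]) * noise

# Direct jump: x_T = sqrt(abar_T) * x_0 + sqrt(1 - abar_T) * noise
noise = torch.randn(n, 2, generator=g, dtype=torch.float64)
x_direct = torch.sqrt(abar[-1]) * x_0 + torch.sqrt(1.0 - abar[-1]) * noise

# Compare empirical moments of the two sample sets
mean_gap = (x_seq.mean(dim=0) - x_direct.mean(dim=0)).abs().max().item()
var_gap = (x_seq.var(dim=0) - x_direct.var(dim=0)).abs().max().item()

# Compare against the analytic mean sqrt(abar_T) * x_0 as well
analytic_gap = (x_seq.mean(dim=0) - torch.sqrt(abar[-1]) * x_0).abs().max().item()

print(f"max mean gap (sequential vs direct): {mean_gap:.4f}")
print(f"max variance gap:                    {var_gap:.4f}")
print(f"max gap to analytic mean:            {analytic_gap:.4f}")
```

The gaps shrink roughly like 1/√n, so a tolerance should be chosen with the sample count in mind rather than expecting exact equality.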

In [None]:
def comprehensive_diffusion_validation():
    """
    Comprehensive validation of all diffusion ELBO components
    """
    print("=== Comprehensive Diffusion Validation ===\n")
    
    # Test 1: Forward process consistency
    print("Test 1: Forward Process Consistency")
    demo = SequentialLikelihoodDemo(data_dim=2, T=50)
    x_0 = test_data[0:1]
    
    try:
        # Sequential vs direct jump should be equivalent in distribution
        trajectory = demo.forward_trajectory(x_0)
        direct_x_T = demo.direct_jump_forward(x_0, 49)
        
        if trajectory is not None and direct_x_T is not None:
            seq_x_T = trajectory[-1]
            print(f"Sequential x_T: {seq_x_T.squeeze().cpu().numpy()}")
            print(f"Direct jump x_T: {direct_x_T.squeeze().cpu().numpy()}")
            print("✓ Both forward process methods run (single samples differ; they agree only in distribution)")
        else:
            print("❌ Implement forward process methods")
    except Exception as e:
        print(f"❌ Forward process implementation needed ({type(e).__name__}: {e})")
    
    # Test 2: Noise prediction equivalence
    print(f"\nTest 2: Noise Prediction Equivalence")
    noise_reparam = NoisePredictionReparameterization(noise_schedule)
    reverse_dist = TractableReverseDistribution(noise_schedule)
    
    try:
        t = 20
        x_t, epsilon = noise_reparam.forward_process_with_noise(x_0, t)
        
        # Two ways to compute optimal mean
        mean_direct = reverse_dist.optimal_reverse_mean(x_t, x_0, t)
        mean_noise = noise_reparam.reparameterize_optimal_mean(x_t, epsilon, t)
        
        if mean_direct is not None and mean_noise is not None:
            diff = torch.abs(mean_direct - mean_noise).max()
            print(f"Mean difference: {diff.item():.8f}")
            # float32 round-off can approach 1e-6, so allow a slightly looser tolerance
            print("✓ Noise prediction reparameterization correct" if diff < 1e-5 else "❌ Reparameterization error")
        else:
            print("❌ Implement optimal mean methods")
    except Exception as e:
        print(f"❌ Noise prediction implementation needed ({type(e).__name__}: {e})")
    
    # Test 3: Three forces implementation
    print(f"\nTest 3: Three Forces Components")
    three_forces = ThreeForcesAnalysis(noise_schedule)
    
    try:
        # Test that prior matching is small
        prior_kl = three_forces.compute_prior_matching_term(x_0)
        if prior_kl is not None:
            print(f"Prior matching KL: {prior_kl.mean().item():.6f}")
            print("✓ Prior matching small" if prior_kl.mean() < 0.1 else "❌ Prior matching too large")
        else:
            print("❌ Implement prior matching term")
    except Exception as e:
        print(f"❌ Three forces implementation needed ({type(e).__name__}: {e})")
    
    # Test 4: Training algorithm
    print(f"\nTest 4: Training Algorithm")
    simple_trainer = SimpleDiffusionTraining(noise_schedule, data_dim=2)
    
    try:
        # Test loss computation
        loss_dict = simple_trainer.simple_diffusion_loss(test_data[:5])
        if loss_dict is not None:
            print(f"Training loss: {loss_dict['loss'].item():.6f}")
            print("✓ Training algorithm implemented")
        else:
            print("❌ Implement training loss")
    except Exception as e:
        print(f"❌ Training algorithm implementation needed ({type(e).__name__}: {e})")
    
    print(f"\n🎓 Validation Summary:")
    print("• Forward process: Essential for noise scheduling")
    print("• Noise prediction: Key insight for practical training")
    print("• Three forces: Mathematical foundation of ELBO")
    print("• Training: Bridge from theory to practice")

def demonstrate_elbo_power():
    """
    Demonstrate the power of the ELBO framework
    """
    print("=== The Power of ELBO Framework ===\n")
    
    achievements = [
        "🚫 Intractable Likelihood Problem",
        "🧮 Complex Sequential Dependencies", 
        "📐 Variational Inference Solution",
        "⚡ Tractable Lower Bound",
        "🔧 Three Interpretable Forces",
        "🎯 Simple Noise Prediction",
        "🚀 State-of-the-Art Results"
    ]
    
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    
    # Create a flow diagram
    y_positions = [0.9, 0.8, 0.65, 0.5, 0.35, 0.2, 0.05]
    colors = ['red', 'orange', 'yellow', 'lightgreen', 'green', 'blue', 'purple']
    
    for i, (achievement, y, color) in enumerate(zip(achievements, y_positions, colors)):
        # Achievement box
        box_style = "round,pad=0.8" if i == 2 else "round,pad=0.5"  # Larger box highlights the ELBO step
        ax.text(0.5, y, achievement, ha='center', va='center', fontsize=12, weight='bold',
               bbox=dict(boxstyle=box_style, facecolor=color, alpha=0.8))
        
        # Arrow to next achievement
        if i < len(achievements) - 1:
            ax.arrow(0.5, y - 0.05, 0, -0.05, head_width=0.02, head_length=0.01,
                    fc='gray', ec='gray', alpha=0.7)
    
    # Add side annotations
    annotations = [
        (0.15, 0.9, "Mathematical\nChallenge"),
        (0.85, 0.65, "ELBO\nBreakthrough"),
        (0.15, 0.35, "Practical\nImplementation"),
        (0.85, 0.05, "Real-World\nSuccess")
    ]
    
    for x, y, text in annotations:
        ax.text(x, y, text, ha='center', va='center', fontsize=10,
               bbox=dict(boxstyle="round,pad=0.3", facecolor='lightblue', alpha=0.5))
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('ELBO: From Mathematical Challenge to Practical Success', fontsize=14, weight='bold')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Run comprehensive validation
comprehensive_diffusion_validation()
demonstrate_elbo_power()

---

## Part 7: Connection to Modern Diffusion Models (5 minutes)

### Task 7.1: Bridge to State-of-the-Art

In [None]:
def connect_to_modern_diffusion():
    """
    Connect lab implementations to modern diffusion models
    """
    print("=== Connection to Modern Diffusion Models ===\n")
    
    lab_to_sota = [
        ("Our Implementation", "State-of-the-Art Models"),
        ("Simple 2D data", "High-resolution images (1024×1024)"),
        ("Linear noise schedule", "Cosine/learned schedules"),
        ("Basic MLP network", "U-Net architectures with attention"),
        ("T=50 timesteps", "T=1000+ timesteps"),
        ("MSE noise prediction", "Advanced loss functions"),
        ("No conditioning", "Text/class conditioning"),
        ("Direct sampling", "Advanced samplers (DDIM, DPM)")
    ]
    
    print("From Lab to Real World:")
    for lab, sota in lab_to_sota:
        print(f"  {lab:25s} → {sota}")
    
    print(f"\nThe mathematical foundation you implemented enables:")
    print("  🎨 DALL-E 2, Midjourney, Stable Diffusion")
    print("  🎵 Audio generation models")
    print("  🧬 Protein structure generation")
    print("  📹 Video synthesis")
    print("  🔬 Scientific data generation")
    
    print(f"\nKey insight: The ELBO framework scales!")
    print("  • Same three forces at any scale")
    print("  • Noise prediction principle universal")
    print("  • Mathematical elegance → practical power")

def preview_advanced_topics():
    """
    Preview advanced diffusion topics
    """
    print("=== Preview: Advanced Diffusion Topics ===\n")
    
    advanced_topics = {
        "Faster Sampling": [
            "DDIM (deterministic sampling)",
            "DPM-Solver (ODE-based)",
            "Progressive distillation",
            "Consistency models"
        ],
        "Better Training": [
            "Classifier-free guidance",
            "Score matching interpretation",
            "Improved noise schedules",
            "Loss function innovations"
        ],
        "Conditional Generation": [
            "Text-to-image synthesis",
            "Inpainting and editing",
            "Style transfer",
            "Controllable generation"
        ],
        "Architecture Advances": [
            "U-Net with attention",
            "Transformer-based diffusion",
            "Latent diffusion models",
            "Cascaded diffusion"
        ]
    }
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    
    colors = ['lightblue', 'lightgreen', 'lightcoral', 'lightyellow']
    
    for i, (topic, items) in enumerate(advanced_topics.items()):
        ax = axes[i]
        
        # Create topic visualization
        ax.text(0.5, 0.9, topic, ha='center', va='center', fontsize=14, weight='bold',
               bbox=dict(boxstyle="round,pad=0.3", facecolor=colors[i]))
        
        for j, item in enumerate(items):
            y_pos = 0.7 - j * 0.15
            ax.text(0.5, y_pos, f"• {item}", ha='center', va='center', fontsize=10)
        
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
    
    plt.suptitle('Advanced Diffusion Model Topics', fontsize=16, weight='bold')
    plt.tight_layout()
    plt.show()
    
    print("Your mathematical foundation enables all of these advances!")

# Connect to modern developments
connect_to_modern_diffusion()
preview_advanced_topics()

---

## Part 8: Reflection and Mathematical Mastery (5 minutes)

### Task 8.1: Diffusion Mathematics Summary

**Discussion Questions** (Work with your partner):

1. **ELBO Derivation Insights**:
   - Which step in the ELBO derivation was most challenging to understand?
   - How does the three-forces decomposition clarify the learning process?

2. **Noise Prediction Breakthrough**:
   - Why is noise prediction easier than image prediction for neural networks?
   - How does this reparameterization connect to the optimal reverse distribution?

3. **Practical Impact**:
   - How does complex mathematical theory enable simple practical algorithms?
   - What would diffusion models be like without the ELBO framework?

In [None]:
def summarize_diffusion_mathematics():
    """
    Reflect on the diffusion mathematics mastery achieved
    """
    print("=== Your Diffusion Mathematics Journey ===\n")
    
    concepts_mastered = [
        "🔄 Sequential latent variable models",
        "📐 Diffusion ELBO derivation (four algebraic steps)",
        "⚖️  Three forces: Reconstruction, Prior Matching, Denoising",
        "🎯 Tractable reverse distribution via Bayes rule",
        "⚡ Noise prediction reparameterization",
        "🚀 Simple training algorithm from complex theory",
        "🧮 Mathematical validation and testing",
        "🌉 Bridge from theory to state-of-the-art practice"
    ]
    
    print("Mathematical concepts you implemented:")
    for concept in concepts_mastered:
        print(f"  {concept}")
    
    print(f"\n🎓 You now understand the mathematical heart of:")
    print(f"   • Diffusion Models (DDPM, DDIM, Score-based)")
    print(f"   • Modern generative AI (DALL-E, Midjourney, Stable Diffusion)")
    print(f"   • Variational inference for sequential models")
    print(f"   • The connection between theory and practice")
    
    print(f"\n🔬 Key mathematical insights achieved:")
    print(f"   • Why sequential latents avoid approximation errors")
    print(f"   • How ELBO transforms intractable to tractable")
    print(f"   • Why noise prediction is an effective parameterization of the reverse mean")
    print(f"   • How mathematical elegance enables practical success")
    
    # Create mastery visualization
    fig, ax = plt.subplots(1, 1, figsize=(14, 10))
    
    # Mathematical journey progression
    journey_steps = [
        ("Intractable\nLikelihood", 0.1, 0.8, 'red'),
        ("Sequential\nELBO", 0.3, 0.8, 'orange'),
        ("Three Forces\nDecomposition", 0.5, 0.8, 'yellow'),
        ("Tractable\nReverse", 0.7, 0.8, 'lightgreen'),
        ("Noise\nPrediction", 0.9, 0.8, 'green'),
        ("Simple\nTraining", 0.3, 0.5, 'lightblue'),
        ("Mathematical\nValidation", 0.7, 0.5, 'blue'),
        ("State-of-the-Art\nConnection", 0.5, 0.2, 'purple')
    ]
    
    # Draw journey progression
    for i, (step, x, y, color) in enumerate(journey_steps):
        ax.text(x, y, step, ha='center', va='center', fontsize=11, weight='bold',
               bbox=dict(boxstyle="round,pad=0.5", facecolor=color, alpha=0.8))
        
        # Add connections
        if i < 4:  # Top row: arrow to the next step on the right
            ax.arrow(x + 0.08, y, 0.12, 0, head_width=0.02, head_length=0.02,
                    fc='gray', ec='gray', alpha=0.7)
        elif i == 5:  # Down from step 2
            ax.arrow(0.3, 0.75, 0, -0.2, head_width=0.02, head_length=0.02,
                    fc='gray', ec='gray', alpha=0.7)
        elif i == 6:  # Down from step 4  
            ax.arrow(0.7, 0.75, 0, -0.2, head_width=0.02, head_length=0.02,
                    fc='gray', ec='gray', alpha=0.7)
        elif i == 7:  # Final convergence
            ax.arrow(0.3, 0.45, 0.15, -0.2, head_width=0.02, head_length=0.02,
                    fc='gray', ec='gray', alpha=0.7)
            ax.arrow(0.7, 0.45, -0.15, -0.2, head_width=0.02, head_length=0.02,
                    fc='gray', ec='gray', alpha=0.7)
    
    # Add achievement badges
    achievements = [
        (0.1, 0.6, "🔧\nDerivation\nMastery"),
        (0.9, 0.6, "⚡\nImplementation\nPower"),
        (0.1, 0.3, "🧮\nValidation\nSkills"),
        (0.9, 0.3, "🌉\nTheory-Practice\nBridge")
    ]
    
    for x, y, achievement in achievements:
        ax.text(x, y, achievement, ha='center', va='center', fontsize=10,
               bbox=dict(boxstyle="round,pad=0.3", facecolor='gold', alpha=0.6))
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Your Diffusion Mathematics Mastery Journey', fontsize=16, weight='bold')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Summarize the mathematical journey
summarize_diffusion_mathematics()

---

## Implementation Checklist

### Core Mathematical Functions (Students Implement):

**✅ Essential TODOs:**
- [ ] `forward_step()` - Single forward diffusion step
- [ ] `forward_trajectory()` - Complete forward trajectory
- [ ] `direct_jump_forward()` - Efficient analytical forward jump
- [ ] `implement_elbo_step1_separation()` - Strategic term separation
- [ ] `implement_elbo_step2_reindexing()` - Index alignment
- [ ] `implement_elbo_step3_bayes_rule()` - Bayes rule transformation
- [ ] `implement_elbo_step4_final_form()` - Final KL divergence form
- [ ] `bayes_rule_transformation()` - Tractable reverse distribution
- [ ] `optimal_reverse_mean()` - Optimal interpolation formula
- [ ] `optimal_reverse_variance()` - Fixed optimal variance
- [ ] `forward_process_with_noise()` - Explicit noise tracking
- [ ] `solve_for_x0()` - Noise inversion formula
- [ ] `reparameterize_optimal_mean()` - Noise-parameterized mean
- [ ] `simple_diffusion_loss()` - Simple training algorithm
- [ ] `compute_reconstruction_term()` - First ELBO force
- [ ] `compute_prior_matching_term()` - Second ELBO force
- [ ] `compute_denoising_terms()` - Third ELBO force

**✅ Provided Starter Code:**
- [ ] All visualization functions with complete plotting code
- [ ] Training loops and optimization infrastructure
- [ ] Mathematical validation and testing frameworks
- [ ] Connection to modern diffusion models
- [ ] Noise schedule creation and management

---

## Submission Requirements

### What to Submit

Submit your completed Jupyter notebook (.ipynb file) with:

**✅ ELBO Derivation:**
- Complete four-step algebraic derivation implemented correctly
- Strategic separation, reindexing, Bayes rule, and final KL form
- Clear mathematical comments explaining each transformation

**✅ Tractable Reverse Distribution:**
- Optimal mean and variance implementations
- Interpolation weight analysis and visualization
- Connection to forward process via Bayes rule

**✅ Noise Prediction Reparameterization:**
- Forward process with explicit noise tracking
- Noise inversion and reparameterization formulas
- Equivalence demonstrations between parameterizations

**✅ Three Forces Analysis:**
- Implementation of all three ELBO terms
- Force magnitude analysis and interpretation
- Connection between forces and learning dynamics

**✅ Simple Training Algorithm:**
- Complete noise prediction training implementation
- Connection between complex ELBO and simple MSE loss
- Mathematical validation of all components

**✅ Documentation and Insights:**
- Clear explanations of mathematical derivations
- Discussion of implementation challenges and solutions
- Connection between theory and practical state-of-the-art models

---

## Quick Reference: Key Mathematical Formulas

### For Implementation Reference:

**Forward Process Direct Jump:**

In [None]:
# q(x_t | x_0) = N(√ᾱ_t * x_0, (1-ᾱ_t) * I)
epsilon = torch.randn_like(x_0)  # ε ~ N(0, I)
mean = torch.sqrt(alpha_cumprod_t) * x_0
variance = 1 - alpha_cumprod_t
x_t = mean + torch.sqrt(variance) * epsilon

**Optimal Reverse Mean:**

In [None]:
# μ̃_t(x_t, x_0) = weight_x0 * x_0 + weight_xt * x_t
weight_x0 = (torch.sqrt(alpha_cumprod_t_minus_1) * beta_t) / (1 - alpha_cumprod_t)
weight_xt = (torch.sqrt(alpha_t) * (1 - alpha_cumprod_t_minus_1)) / (1 - alpha_cumprod_t)
optimal_mean = weight_x0 * x_0 + weight_xt * x_t

**Noise Reparameterization:**

In [None]:
# μ̃_t(x_t, ε) = (1/√α_t) * (x_t - (1-α_t)/√(1-ᾱ_t) * ε)
coeff = (1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)
optimal_mean = (x_t - coeff * epsilon) / torch.sqrt(alpha_t)
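
The two mean formulas above should agree whenever `x_t` is produced by the forward jump from `x_0` and `ε`. A minimal scalar sketch of that check, using plain Python and made-up values (the numbers for `beta_t`, `alpha_bar_prev`, `x_0`, and `eps` are assumptions for illustration, not the lab's schedule):

```python
import math

# Illustrative scalar values (assumptions, not the lab's schedule).
beta_t = 0.02
alpha_t = 1.0 - beta_t
alpha_bar_prev = 0.90                    # ᾱ_{t-1}
alpha_bar_t = alpha_bar_prev * alpha_t   # ᾱ_t = ᾱ_{t-1} · α_t

x_0, eps = 1.5, -0.7
# Forward jump: x_t = √ᾱ_t · x_0 + √(1-ᾱ_t) · ε
x_t = math.sqrt(alpha_bar_t) * x_0 + math.sqrt(1.0 - alpha_bar_t) * eps

# x_0 parameterization of the posterior mean
w_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
w_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
mean_from_x0 = w_x0 * x_0 + w_xt * x_t

# Noise parameterization of the same mean
coeff = (1.0 - alpha_t) / math.sqrt(1.0 - alpha_bar_t)
mean_from_eps = (x_t - coeff * eps) / math.sqrt(alpha_t)

# The two values should differ only by floating-point round-off.
print(mean_from_x0, mean_from_eps)
```

If the two numbers disagree in your own implementation, a common cause is mixing up `alpha_t` with `alpha_cumprod_t` in one of the coefficients.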

**Simple Training Loss:**

In [None]:
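# Sample t ~ Uniform{1, ..., T}
# Sample ε ~ N(0, I)
# Compute x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
# Loss = ||ε - ε_θ(x_t, t)||²
t = torch.randint(1, T + 1, (batch_size,))   # mathematical timesteps 1..T
epsilon = torch.randn_like(x_0)
a_bar = alpha_cumprod[t - 1].unsqueeze(-1)   # 0-indexed lookup into a (T,) schedule; shape (batch_size, 1)
x_t = torch.sqrt(a_bar) * x_0 + torch.sqrt(1 - a_bar) * epsilon
epsilon_pred = model(x_t, t)                 # `model` is a placeholder for your noise-prediction network ε_θ
loss = F.mse_loss(epsilon_pred, epsilon)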

**Optimal Reverse Variance:**

In [None]:
# σ̃²_t = (1-ᾱ_{t-1})/(1-ᾱ_t) * β_t
optimal_variance = ((1 - alpha_cumprod_t_minus_1) / (1 - alpha_cumprod_t)) * beta_t
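
One way to sanity-check the variance formula is through Bayes' rule: completing the square in the product q(x_t | x_{t-1}) · q(x_{t-1} | x_0) gives a posterior precision of α_t/β_t + 1/(1-ᾱ_{t-1}), and its inverse should match the closed form above. A small scalar sketch with assumed values:

```python
import math

# Illustrative scalar values (assumptions, not the lab's schedule).
beta_t = 0.02
alpha_t = 1.0 - beta_t
alpha_bar_prev = 0.90                    # ᾱ_{t-1}
alpha_bar_t = alpha_bar_prev * alpha_t   # ᾱ_t

# Closed form from the quick reference
closed_form = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t

# Same quantity as the inverse of the summed precisions
precision = alpha_t / beta_t + 1.0 / (1.0 - alpha_bar_prev)
from_precisions = 1.0 / precision

print(closed_form, from_precisions)
# The posterior variance is smaller than β_t because 1-ᾱ_{t-1} is smaller than 1-ᾱ_t.
print(closed_form < beta_t)
```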

---

## Common Implementation Issues & Solutions

### Debugging Tips:

**Index Confusion:**
- Forward process: x_0 → x_1 → ... → x_T (T+1 states total)
- Reverse process: x_T → x_{T-1} → ... → x_0
- Python indexing: math is 1-indexed but Python is 0-indexed, so with schedule tensors of shape (T,), `alphas[t-1]` holds the value for mathematical timestep t
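
A small sketch of one way to keep the two index conventions apart, assuming the schedule is stored 0-indexed with length T (the schedule values and the helper name `alpha_bar` are hypothetical):

```python
T = 5
# Hypothetical linear schedule, stored 0-indexed: betas[k] belongs to timestep k + 1.
betas = [0.1 * (k + 1) / T for k in range(T)]
alphas = [1.0 - b for b in betas]

# Running product ᾱ_t = α_1 · α_2 · ... · α_t
alpha_cumprod = []
running = 1.0
for a in alphas:
    running *= a
    alpha_cumprod.append(running)

def alpha_bar(t: int) -> float:
    """Return ᾱ_t for mathematical timestep t in {0, ..., T}; ᾱ_0 = 1 by convention."""
    if not 0 <= t <= T:
        raise IndexError(f"timestep {t} is outside 0..{T}")
    return 1.0 if t == 0 else alpha_cumprod[t - 1]

print(alpha_bar(0), alpha_bar(1), alpha_bar(T))
```

Routing every lookup through one helper like this keeps the `t - 1` offset in a single place, which tends to make off-by-one errors easier to spot.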

**Numerical Stability:**
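- `torch.sqrt(x)` and `x ** 0.5` both have an unbounded gradient at x = 0; clamp the argument before taking the root, e.g. `torch.sqrt(torch.clamp(x, min=1e-8))`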
- Add small epsilon (1e-8) when computing log probabilities
- Clamp alpha values to avoid division by zero: `torch.clamp(alpha, min=1e-8)`
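
One place where the epsilon matters: with ᾱ_0 = 1, the optimal reverse variance at t = 1 evaluates to exactly zero, so a Gaussian log-density with that variance would call log(0). A minimal sketch in plain Python (the helper name `gaussian_log_prob` and the value of `beta_1` are assumptions for illustration):

```python
import math

def gaussian_log_prob(x: float, mean: float, var: float, eps: float = 1e-8) -> float:
    """Log-density of a scalar Gaussian, with a small floor added to the variance."""
    var = var + eps
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)

beta_1 = 0.01                                  # hypothetical first-step noise level
alpha_bar_0, alpha_bar_1 = 1.0, 1.0 - beta_1   # ᾱ_0 = 1 by convention

# σ̃²_1 = (1-ᾱ_0)/(1-ᾱ_1) · β_1, which is zero because 1-ᾱ_0 = 0
var_1 = (1.0 - alpha_bar_0) / (1.0 - alpha_bar_1) * beta_1
print(var_1)

# Without the floor, math.log(0.0) would raise ValueError; with it the value is finite.
print(math.isfinite(gaussian_log_prob(0.3, 0.3, var_1)))
```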

**Shape Broadcasting:**
- Noise schedule tensors: shape (T,)
- Batch operations: ensure proper broadcasting with (batch_size, data_dim)
- Use `.view(-1, 1)` or `.unsqueeze(-1)` for proper tensor alignment
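
A short sketch of the broadcasting pattern, assuming a schedule tensor of shape (T,) and data of shape (batch_size, data_dim); the schedule values and sizes are made up for illustration:

```python
import torch

T, batch_size, data_dim = 10, 4, 2
# Hypothetical linear schedule; alpha_cumprod has shape (T,)
alpha_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, T), dim=0)

x_0 = torch.randn(batch_size, data_dim)
epsilon = torch.randn_like(x_0)
t_idx = torch.randint(0, T, (batch_size,))   # 0-indexed positions into the schedule

a_bar = alpha_cumprod[t_idx]                 # shape (batch_size,)
a_bar_col = a_bar.unsqueeze(-1)              # shape (batch_size, 1), broadcasts over data_dim

x_t = torch.sqrt(a_bar_col) * x_0 + torch.sqrt(1.0 - a_bar_col) * epsilon
print(a_bar.shape, a_bar_col.shape, x_t.shape)
```

Multiplying the unexpanded `a_bar` by `x_0` would raise a shape error here, and if `batch_size` happened to equal `data_dim` it would broadcast along the wrong axis without any error, which is why the explicit `unsqueeze(-1)` is worth keeping.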

**Mathematical Sign Conventions:**
- ELBO terms: reconstruction (+), KL divergences (-)
- Loss for optimization: minimize negative ELBO
- Log probabilities vs probabilities: ensure consistent use
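
The sign conventions can be sketched with scalar Gaussians. The per-term numbers below are hypothetical placeholders, and `kl_gaussian_1d` is an illustrative helper rather than part of the lab's starter code:

```python
import math

def kl_gaussian_1d(m1: float, v1: float, m2: float, v2: float) -> float:
    """KL( N(m1, v1) || N(m2, v2) ) for scalar Gaussians with variances v1, v2."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

# Hypothetical values for a single data point
reconstruction = -1.2                                   # a log-probability, enters with (+)
prior_matching = kl_gaussian_1d(0.05, 0.98, 0.0, 1.0)   # KL to the N(0, 1) prior, enters with (-)
denoising = [
    kl_gaussian_1d(0.40, 0.01, 0.35, 0.01),             # equal variances: (m1 - m2)^2 / (2 v)
    kl_gaussian_1d(0.10, 0.02, 0.16, 0.02),
]

elbo = reconstruction - prior_matching - sum(denoising)
loss = -elbo                                            # minimize the negative ELBO

print(elbo, loss)
```

With matched variances each denoising KL reduces to a scaled squared difference of means, which is the step that connects the ELBO to the MSE noise-prediction loss.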

---

