# Lab 3: Mathematical Foundations of Generative Models - Hands-On Implementation
**Course: Diffusion Models: Theory and Applications**  
**Duration: 90 minutes**  
**Team Size: 2 students (same teams from Labs 1-2)**

---

## Learning Objectives
By the end of this lab, students will be able to:
1. **Implement** the complete ELBO derivation from first principles
2. **Build** KL divergence calculations and explore their asymmetric properties
3. **Create** Jensen's inequality demonstrations showing how lower bounds work
4. **Construct** the two-forces analysis of reconstruction vs regularization
5. **Connect** mathematical theory to practical optimization algorithms
6. **Prepare** the foundation for understanding diffusion model mathematics

---

## Lab Setup and Mathematical Framework

### Part 1: Team Setup & Mathematical Mission (10 minutes)

In [None]:
# Mathematical foundations setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.distributions import Normal, MultivariateNormal
from typing import Tuple, Dict
import time

# Set seeds for reproducible mathematics
torch.manual_seed(42)
np.random.seed(42)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Mathematical computing on: {device}")

# Create synthetic 2D data for testing
def create_test_data(n_samples: int = 500) -> torch.Tensor:
    """Create 2D spiral data for testing our mathematical implementations"""
    t = torch.linspace(0, 3*math.pi, n_samples)
    x = t * torch.cos(t) + 0.1 * torch.randn(n_samples)
    y = t * torch.sin(t) + 0.1 * torch.randn(n_samples)
    data = torch.stack([x, y], dim=1)
    return data.to(device)

# Generate test data
test_data = create_test_data(500)
print(f"Test data shape: {test_data.shape}")

# Visualize test data
plt.figure(figsize=(8, 6))
plt.scatter(test_data[:, 0].cpu(), test_data[:, 1].cpu(), alpha=0.6, s=20)
plt.title('Test Data: 2D Spiral')
plt.xlabel('X')
plt.ylabel('Y')
plt.grid(True, alpha=0.3)
plt.show()

---

## Part 2: The Intractable Likelihood Crisis (20 minutes)

### Task 2.1: Experience Why Direct Likelihood Fails

**Your Mission**: Implement Monte Carlo likelihood estimation and see why it fails.

In [None]:
class IntractableLikelihoodDemo:
    """
    Demonstrate why direct likelihood computation fails for generative models.
    You'll implement the mathematical computations to see the problems firsthand.
    """
    
    def __init__(self, latent_dim: int = 2, data_dim: int = 2):
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        
        # Simple generative model: p(x|z) = N(f_θ(z), σ²I)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 16),
            nn.ReLU(),
            nn.Linear(16, data_dim)
        ).to(device)
        
        self.noise_std = 0.2
        
    def sample_prior(self, n_samples: int) -> torch.Tensor:
        """
        TODO: Implement prior sampling
        
        Sample from p(z) = N(0, I)
        
        Returns:
            z_samples: (n_samples, latent_dim)
        """
        # TODO: Your implementation here
        pass
        
    def likelihood_given_z(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement p(x|z) computation
        
        Compute the likelihood p(x|z) = N(x; f_θ(z), σ²I)
        
        Steps:
        1. Pass z through decoder to get mean μ = f_θ(z)
        2. Compute Gaussian likelihood with fixed variance σ²
        3. Return likelihood values (not log-likelihood)
        
        Args:
            x: Data points (batch_size, data_dim)
            z: Latent codes (batch_size, latent_dim)
            
        Returns:
            likelihoods: p(x|z) for each pair (batch_size,)
        """
        # TODO: Your implementation here
        # Hint: Use Normal distribution from torch.distributions
        pass
    
    def approximate_marginal_likelihood(self, x: torch.Tensor, n_samples: int = 1000) -> torch.Tensor:
        """
        TODO: Implement Monte Carlo approximation of p(x)
        
        Approximate p(x) = ∫ p(x|z)p(z) dz ≈ (1/K) Σ p(x|z_k) where z_k ~ p(z)
        
        The estimator is unbiased, but its variance makes it impractical in high dimensions, which is why we need variational inference!
        
        Steps:
        1. Sample many z values from prior
        2. For each data point x_i, compute p(x_i|z_k) for all sampled z_k
        3. Average these likelihood values
        4. Return the Monte Carlo estimate
        
        Args:
            x: Data points (batch_size, data_dim)
            n_samples: Number of Monte Carlo samples
            
        Returns:
            approx_p_x: Approximated p(x) for each data point (batch_size,)
        """
        # TODO: Your implementation here
        pass
    
    def demonstrate_failure(self, test_data: torch.Tensor):
        """Run the demonstration showing why direct likelihood computation fails"""
        print("=== Demonstrating Why Direct Likelihood Fails ===\n")
        
        # Test with increasing numbers of Monte Carlo samples
        sample_counts = [10, 100, 1000, 5000]
        test_points = test_data[:3]  # Test on 3 points
        
        for n_samples in sample_counts:
            print(f"Using {n_samples} Monte Carlo samples:")
            
            # Time the computation
            start_time = time.time()
            approx_likelihood = self.approximate_marginal_likelihood(test_points, n_samples)
            computation_time = time.time() - start_time
            
            # Run multiple times to show variance
            estimates = []
            for trial in range(5):
                estimate = self.approximate_marginal_likelihood(test_points[:1], n_samples)
                estimates.append(estimate.item())
            
            # Print results and analysis
            print(f"  Time: {computation_time:.3f}s")
            print(f"  Estimates: {approx_likelihood.detach().cpu().numpy()}")
            print(f"  Variance across trials: {np.var(estimates):.6f}")
            print()
        
        print("Analysis:")
        print("❌ Computation time grows linearly with samples")
        print("❌ High variance in estimates")
        print("❌ Most random z give very low p(x|z)")
        print("❌ Completely impractical for high-dimensional z")

# Test your implementation (uncomment after implementing TODOs)
# demo = IntractableLikelihoodDemo()
# demo.demonstrate_failure(test_data)
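
To see what a correct Monte Carlo estimate should look like before trusting the neural version, it can help to run the same estimator on a model whose marginal is known. The sketch below is an assumption-laden toy, not the lab's decoder: it takes a one-dimensional latent with the identity decoder f(z) = z, so that p(x) = N(0, 1 + σ²) in closed form. The helper names `normal_pdf` and `mc_marginal_likelihood` are hypothetical and use only the standard library.

```python
import math
import random

def normal_pdf(x: float, mean: float, std: float) -> float:
    """Density of N(mean, std^2) evaluated at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def mc_marginal_likelihood(x: float, noise_std: float, n_samples: int, rng: random.Random) -> float:
    """Estimate p(x) = E_{z ~ N(0,1)}[p(x|z)] for the toy decoder f(z) = z."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)               # z_k ~ p(z)
        total += normal_pdf(x, z, noise_std)  # p(x | z_k)
    return total / n_samples

rng = random.Random(0)
x_obs, noise_std = 1.5, 0.2
# Closed-form marginal for this toy model only: x = z + noise, so x ~ N(0, 1 + noise_std^2).
exact = normal_pdf(x_obs, 0.0, math.sqrt(1.0 + noise_std ** 2))

for k in (10, 100, 10_000):
    estimates = [mc_marginal_likelihood(x_obs, noise_std, k, rng) for _ in range(5)]
    spread = max(estimates) - min(estimates)
    print(f"K={k:>6}: mean estimate={sum(estimates) / 5:.4f}, spread={spread:.4f}, exact={exact:.4f}")
```

The spread across repeated runs shrinks only at the usual 1/√K Monte Carlo rate, even in one dimension. With a small noise level, most prior samples land far from `x_obs` and contribute almost nothing to the average.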

### Task 2.2: Understand the Bayes Rule Circular Dependency

In [None]:
def demonstrate_bayes_circularity():
    """
    Show the circular dependency problem in Bayes' rule
    """
    print("=== The Bayes Rule Circular Dependency ===\n")
    
    print("Bayes' rule: p(z|x) = p(x|z) * p(z) / p(x)")
    print()
    
    print("What we can compute:")
    print("  ✓ p(x|z): Decoder/likelihood model")
    print("  ✓ p(z): Prior distribution (we choose this)")
    print()
    print("What we can't compute:")
    print("  ✗ p(x): The marginal likelihood = ∫ p(x|z)p(z) dz")
    print()
    
    print("The circular dependency:")
    print("  1. We want p(z|x) to learn about latent factors")
    print("  2. Bayes rule requires p(x) in the denominator")
    print("  3. But p(x) is the same intractable integral!")
    print("  4. Can't compute what we need to learn")
    print()
    
    print("💡 Solution: Variational inference!")
    print("   Approximate p(z|x) with tractable q(z|x)")

# Run the demonstration
demonstrate_bayes_circularity()
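
The dependency on p(x) is easiest to see in a case where it is still computable. The sketch below assumes a made-up discrete latent with three values, so the "integral" is a three-term sum; all probabilities are illustrative and not taken from the lab's model.

```python
# Toy model with a discrete latent z in {0, 1, 2}; all numbers are made up for illustration.
prior = [0.5, 0.3, 0.2]           # p(z)
likelihood = [0.10, 0.40, 0.70]   # p(x_obs | z) for one fixed observation x_obs

# The evidence p(x_obs) requires a sum over every latent value.
evidence = sum(pz * px_z for pz, px_z in zip(prior, likelihood))

# Bayes' rule needs that evidence in the denominator.
posterior = [pz * px_z / evidence for pz, px_z in zip(prior, likelihood)]

print(f"p(x_obs) = {evidence:.3f}")
print("p(z | x_obs) =", [round(p, 3) for p in posterior])
```

With three latent values the normalizer costs three multiplications. With a continuous, high-dimensional z and a neural decoder, the same normalizer is the integral that Task 2.1 could not estimate reliably.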

---

## Part 3: KL Divergence Implementation (20 minutes)

### Task 3.1: Build KL Divergence from Scratch

**Your Mission**: Implement KL divergence calculations and understand their properties.

In [None]:
class KLDivergenceBuilder:
    """
    Build KL divergence calculations from first principles.
    You'll implement the core mathematical formulas.
    """
    
    def monte_carlo_kl(self, p_samples: torch.Tensor, p_logprob_fn, q_logprob_fn) -> float:
        """
        TODO: Implement KL(p||q) using Monte Carlo estimation
        
        Formula: KL(p||q) = E_p[log p(x) - log q(x)]
        Monte Carlo: KL(p||q) ≈ (1/N) Σ [log p(x_i) - log q(x_i)] where x_i ~ p
        
        Args:
            p_samples: Samples from distribution p
            p_logprob_fn: Function to compute log p(x)
            q_logprob_fn: Function to compute log q(x)
            
        Returns:
            kl_estimate: Estimated KL(p||q)
        """
        # TODO: Implement the Monte Carlo KL estimation
        # Step 1: Compute log p(x) for all samples
        # Step 2: Compute log q(x) for all samples  
        # Step 3: Compute log p(x) - log q(x)
        # Step 4: Take the mean
        pass
    
    def gaussian_kl_closed_form(self, mu1: torch.Tensor, sigma1: torch.Tensor,
                                mu2: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement closed-form KL for Gaussians
        
        For p = N(μ₁, σ₁²) and q = N(μ₂, σ₂²):
        KL(p||q) = log(σ₂/σ₁) + (σ₁² + (μ₁-μ₂)²)/(2σ₂²) - 1/2
        
        This closed form is crucial for practical VI.
        
        Args:
            mu1, sigma1: Parameters of first Gaussian
            mu2, sigma2: Parameters of second Gaussian
            
        Returns:
            kl: KL divergence
        """
        # TODO: Implement the closed-form formula
        # Hint: Be careful with the formula - there are multiple equivalent forms
        pass
    
    def standard_normal_kl(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement KL divergence to standard normal N(0,I)
        
        This is the most common case in VAEs.
        For p = N(μ, σ²I) and q = N(0, I):
        KL(p||q) = 0.5 * Σ[μ² + σ² - 1 - log(σ²)]
        
        Args:
            mu: Mean vector (..., latent_dim)
            logvar: Log variance vector (..., latent_dim)  [Note: logvar = log(σ²)]
            
        Returns:
            kl: KL divergence for each sample (...,)
        """
        # TODO: Implement the standard normal KL formula
        # Remember: logvar = log(σ²), so σ² = exp(logvar)
        pass
    
    def explore_kl_asymmetry(self):
        """
        Demonstrate KL divergence asymmetry
        """
        print("=== KL Divergence Asymmetry ===\n")
        
        # Define two different Gaussians
        mu1, sigma1 = torch.tensor(0.0), torch.tensor(0.5)  # Narrow
        mu2, sigma2 = torch.tensor(1.0), torch.tensor(1.5)  # Wide
        
        print(f"Distribution p: N({mu1:.1f}, {sigma1:.1f}²) [narrow]")
        print(f"Distribution q: N({mu2:.1f}, {sigma2:.1f}²) [wide]")
        print()
        
        # Compute both directions (after implementing the TODO above)
        kl_p_q = self.gaussian_kl_closed_form(mu1, sigma1, mu2, sigma2)
        kl_q_p = self.gaussian_kl_closed_form(mu2, sigma2, mu1, sigma1)
        
        print(f"KL(p||q): {kl_p_q:.4f}")
        print(f"KL(q||p): {kl_q_p:.4f}")
        print(f"Asymmetry: {abs(kl_p_q - kl_q_p):.4f}")
        
        # Verify with Monte Carlo
        n_samples = 10000
        p_dist = Normal(mu1, sigma1)
        q_dist = Normal(mu2, sigma2)
        p_samples = p_dist.sample((n_samples,))
        
        kl_mc = self.monte_carlo_kl(p_samples, p_dist.log_prob, q_dist.log_prob)
        print(f"Monte Carlo verification: {kl_mc:.4f}")
        
        # Create visualization showing the asymmetry
        self.visualize_asymmetry(mu1, sigma1, mu2, sigma2)
    
    def visualize_asymmetry(self, mu1, sigma1, mu2, sigma2):
        """Create plots showing KL asymmetry effects"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))
        
        # Plot 1: The two distributions
        x = torch.linspace(-3, 5, 1000)
        p_dist = Normal(mu1, sigma1)
        q_dist = Normal(mu2, sigma2)
        
        p_vals = torch.exp(p_dist.log_prob(x))
        q_vals = torch.exp(q_dist.log_prob(x))
        
        ax1.plot(x, p_vals, 'b-', linewidth=2, label=f'p: N({mu1:.1f}, {sigma1:.1f}²)')
        ax1.plot(x, q_vals, 'r-', linewidth=2, label=f'q: N({mu2:.1f}, {sigma2:.1f}²)')
        ax1.set_title('Distributions p and q')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: Log ratio p/q
        log_ratio_pq = p_dist.log_prob(x) - q_dist.log_prob(x)
        ax2.plot(x, log_ratio_pq, 'g-', linewidth=2)
        ax2.set_title('log(p/q) - used in KL(p||q)')
        ax2.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Log ratio q/p  
        log_ratio_qp = q_dist.log_prob(x) - p_dist.log_prob(x)
        ax3.plot(x, log_ratio_qp, 'orange', linewidth=2)
        ax3.set_title('log(q/p) - used in KL(q||p)')
        ax3.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: Sample histograms from p and q
        p_samples = p_dist.sample((1000,))
        q_samples = q_dist.sample((1000,))
        
        ax4.hist(p_samples, bins=30, alpha=0.5, label='Samples from p', density=True)
        ax4.hist(q_samples, bins=30, alpha=0.5, label='Samples from q', density=True)
        ax4.set_title('Sample distributions')
        ax4.legend()
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        print("\nKey insight:")
        print("• Minimizing KL(p||q) over q: 'mode-covering' - q must cover all of p")
        print("• Minimizing KL(q||p) over q: 'mode-seeking' - q focuses on high-density regions of p")

# Test your KL divergence implementations (uncomment after implementing TODOs)
# kl_builder = KLDivergenceBuilder()
# kl_builder.explore_kl_asymmetry()
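
Asymmetry is not specific to Gaussians. As a quick sanity check that is independent of the TODOs above, the sketch below computes KL in both directions for two small discrete distributions. The function name `kl_discrete` and the probability values are illustrative assumptions.

```python
import math

def kl_discrete(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i), skipping terms where p_i == 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

p = [0.80, 0.15, 0.05]   # peaked distribution (illustrative values)
q = [0.40, 0.30, 0.30]   # flatter distribution (illustrative values)

kl_pq = kl_discrete(p, q)
kl_qp = kl_discrete(q, p)

print(f"KL(p||q) = {kl_pq:.4f}")
print(f"KL(q||p) = {kl_qp:.4f}")
print(f"KL(p||p) = {kl_discrete(p, p):.4f}")
```

Both directions are non-negative, they differ from each other, and the divergence of a distribution from itself is zero. Your Gaussian implementations should show the same three properties.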

### Task 3.2: Validate Your KL Implementations

In [None]:
def test_kl_implementations():
    """
    Test your KL divergence implementations
    
    Verify that your formulas are correct by comparing different methods.
    """
    print("=== Testing KL Implementations ===\n")
    
    kl_builder = KLDivergenceBuilder()
    
    # Test case 1: Simple Gaussians
    mu1, sigma1 = torch.tensor(1.0), torch.tensor(0.8)
    mu2, sigma2 = torch.tensor(0.0), torch.tensor(1.0)
    
    print("Test 1: Gaussian KL")
    print(f"p = N({mu1:.1f}, {sigma1:.1f}²)")
    print(f"q = N({mu2:.1f}, {sigma2:.1f}²)")
    
    # Test your closed-form implementation
    kl_closed = kl_builder.gaussian_kl_closed_form(mu1, sigma1, mu2, sigma2)
    print(f"Your implementation: {kl_closed:.6f}")
    
    # Test against known analytical result
    # Manual calculation for verification
    analytical = 0.5 * ((sigma1/sigma2)**2 + ((mu1-mu2)/sigma2)**2 - 1 - 2*torch.log(sigma1/sigma2))
    print(f"Expected result: {analytical:.6f}")
    print(f"Match: {torch.allclose(kl_closed, analytical, atol=1e-5) if kl_closed is not None else 'Implement TODO first'}")
    
    # Test your Monte Carlo implementation
    p_dist = Normal(mu1, sigma1)
    q_dist = Normal(mu2, sigma2)
    p_samples = p_dist.sample((10000,))
    kl_mc = kl_builder.monte_carlo_kl(p_samples, p_dist.log_prob, q_dist.log_prob)
    print(f"Monte Carlo: {kl_mc}")
    
    # Test case 2: Standard normal KL
    print(f"\nTest 2: KL to standard normal")
    mu_test = torch.tensor([1.0, -0.5])
    logvar_test = torch.tensor([0.5, -0.2])
    
    # Test your standard normal KL
    kl_std = kl_builder.standard_normal_kl(mu_test, logvar_test)
    print(f"Standard normal KL: {kl_std}")
    
    # Expected result for verification
    expected_std = 0.5 * (mu_test**2 + torch.exp(logvar_test) - 1 - logvar_test).sum()
    print(f"Expected: {expected_std:.6f}")
    
    print("\n✓ All tests completed!")

# Run your tests (uncomment after implementing TODOs)
# test_kl_implementations()
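
Besides Monte Carlo, a third independent check is direct numerical integration of the KL integrand. The sketch below is one possible way to do this with the standard library: it compares the closed form stated in Task 3.1 against a trapezoid-rule integral and against the standard-normal special case. The helper names, the integration limits, and the grid size are assumptions chosen for this one-dimensional example.

```python
import math

def gaussian_log_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2

def kl_closed_form(mu1, sigma1, mu2, sigma2):
    """Closed-form KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)), as stated in Task 3.1."""
    return math.log(sigma2 / sigma1) + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * sigma2 ** 2) - 0.5

def kl_quadrature(mu1, sigma1, mu2, sigma2, lo=-12.0, hi=12.0, n=20_001):
    """Trapezoid-rule approximation of the integral of p(x) * (log p(x) - log q(x))."""
    h = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        x = lo + i * h
        log_p = gaussian_log_pdf(x, mu1, sigma1)
        log_q = gaussian_log_pdf(x, mu2, sigma2)
        weight = 0.5 if i in (0, n - 1) else 1.0   # endpoints get half weight
        total += weight * math.exp(log_p) * (log_p - log_q)
    return total * h

mu, sigma = 1.0, 0.8
closed = kl_closed_form(mu, sigma, 0.0, 1.0)
numeric = kl_quadrature(mu, sigma, 0.0, 1.0)
std_normal_form = 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - math.log(sigma ** 2))

print(f"closed form:                  {closed:.6f}")
print(f"quadrature:                   {numeric:.6f}")
print(f"standard-normal special case: {std_normal_form:.6f}")
```

If all three numbers agree, a mismatch in your torch implementation most likely comes from confusing σ with σ² or from a sign error in the log term.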

---

## Part 4: Jensen's Inequality Demonstration (15 minutes)

### Task 4.1: Implement Jensen's Inequality

**Your Mission**: Build the mathematical foundation that makes ELBO possible.

In [None]:
class JensensInequalityDemo:
    """
    Implement Jensen's inequality demonstrations.
    This is the key mathematical tool that makes variational inference work.
    """
    
    def demonstrate_concave_property(self):
        """
        TODO: Show Jensen's inequality for the logarithm
        
        For concave functions like log: f(E[X]) ≥ E[f(X)]
        Specifically: log(E[X]) ≥ E[log(X)]
        """
        print("=== Jensen's Inequality for Logarithm ===\n")
        
        # Test with different distributions
        test_cases = [
            ("Uniform [1,4]", torch.distributions.Uniform(1, 4)),
            ("Exponential(1)", torch.distributions.Exponential(1.0)),
            ("Gamma(2,1)", torch.distributions.Gamma(2.0, 1.0))
        ]
        
        n_samples = 5000
        
        print("Testing Jensen's inequality: log(E[X]) ≥ E[log(X)]")
        print()
        
        for name, dist in test_cases:
            # TODO: Sample from distribution
            samples = dist.sample((n_samples,))
            
            # TODO: Compute both sides of Jensen's inequality
            log_expectation = torch.log(samples.mean())
            expectation_log = torch.log(samples).mean()
            gap = log_expectation - expectation_log
            
            print(f"{name:15s}: log(E[X])={log_expectation:.4f}, E[log(X)]={expectation_log:.4f}, gap={gap:.4f}")
        
        print(f"\nKey insight: E[log(X)] is a lower bound on log(E[X]); the gap measures how loose the bound is.")
        print(f"Applied to log p(x), the same inequality gives the tractable ELBO lower bound.")
    
    def create_jensen_visualization(self):
        """
        Create geometric visualization of Jensen's inequality
        
        Show why concave functions create lower bounds.
        """
        print(f"\n=== Jensen's Inequality Visualization ===")
        
        # Create x values and compute log function
        x = torch.linspace(0.5, 4, 1000)
        log_x = torch.log(x)
        
        # Pick two points and show linear interpolation vs function value
        x1, x2 = 1.0, 3.0
        lambda_val = 0.3
        
        # Compute the key values
        x_mix = lambda_val * x1 + (1 - lambda_val) * x2
        log_x_mix = torch.log(torch.tensor(x_mix))
        mix_log = lambda_val * torch.log(torch.tensor(x1)) + (1 - lambda_val) * torch.log(torch.tensor(x2))
        
        # Create the plot
        plt.figure(figsize=(10, 6))
        plt.plot(x, log_x, 'b-', linewidth=3, label='log(x) [concave]')
        
        # Add points and lines showing Jensen's inequality
        plt.plot([x1, x2], [torch.log(torch.tensor(x1)), torch.log(torch.tensor(x2))], 'ro', markersize=8)
        plt.plot([x1, x2], [torch.log(torch.tensor(x1)), torch.log(torch.tensor(x2))], 'r--', alpha=0.7, label='Linear interpolation')
        plt.plot(x_mix, log_x_mix, 'go', markersize=10, label=f'log(E[X]) = {log_x_mix:.3f}')
        plt.plot(x_mix, mix_log, 'mo', markersize=10, label=f'E[log(X)] = {mix_log:.3f}')
        
        # Add arrow showing the gap
        plt.annotate('', xy=(x_mix, log_x_mix), xytext=(x_mix, mix_log),
                    arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        plt.text(x_mix + 0.1, (log_x_mix + mix_log)/2, f'Gap = {log_x_mix - mix_log:.3f}',
                fontsize=12, color='red', weight='bold')
        
        plt.xlabel('x')
        plt.ylabel('log(x)')
        plt.title('Jensen\'s Inequality: log(E[X]) ≥ E[log(X)]')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()
        
        # Print the verification
        print(f"Jensen verification: {log_x_mix:.4f} ≥ {mix_log:.4f} ? {log_x_mix >= mix_log}")

# Test Jensen's inequality
jensen = JensensInequalityDemo()
jensen.demonstrate_concave_property()
jensen.create_jensen_visualization()
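
The step from Jensen's inequality to the ELBO can be written out as a short derivation. This is the standard argument, sketched here with the notation used in this lab: q(z|x) for the variational posterior and p(z) for the prior.

```latex
\begin{aligned}
\log p(x)
  &= \log \int p(x \mid z)\, p(z)\, dz \\
  &= \log \mathbb{E}_{q(z \mid x)}\!\left[ \frac{p(x \mid z)\, p(z)}{q(z \mid x)} \right] \\
  &\geq \mathbb{E}_{q(z \mid x)}\!\left[ \log \frac{p(x \mid z)\, p(z)}{q(z \mid x)} \right]
     \qquad \text{(Jensen, since } \log \text{ is concave)} \\
  &= \mathbb{E}_{q(z \mid x)}\big[ \log p(x \mid z) \big]
     - \mathrm{KL}\big( q(z \mid x) \,\|\, p(z) \big)
\end{aligned}
```

The slack in the inequality equals KL(q(z|x) || p(z|x)), so the bound is tight exactly when q matches the true posterior.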

---

## Part 5: ELBO Framework Implementation (25 minutes)

### Task 5.1: Build the Complete ELBO

**Your Mission**: Implement the ELBO framework that makes variational inference practical.

In [None]:
class ELBOFramework:
    """
    Implement the complete ELBO framework from scratch.
    This is the universal solution for generative model training.
    """
    
    def __init__(self, latent_dim: int = 2, data_dim: int = 2):
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        
        # VAE components (architectures provided, you'll implement the math)
        self.encoder = nn.Sequential(
            nn.Linear(data_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 2 * latent_dim)  # Output mean and logvar
        ).to(device)
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 32),
            nn.ReLU(),
            nn.Linear(32, data_dim)
        ).to(device)
        
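        # Create the parameter on the same device as the encoder/decoder to avoid device mismatches on GPU
        self.decoder_logvar = nn.Parameter(torch.zeros(1, device=device))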
    
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract variational parameters q(z|x) = N(μ(x), σ²(x))"""
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=-1)
        return mu, logvar
    
    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement the reparameterization trick
        
        Sample z ~ N(μ, σ²I) using reparameterization: z = μ + σ * ε where ε ~ N(0,I)
        
        This makes random sampling differentiable!
        
        Args:
            mu: Mean parameters
            logvar: Log variance parameters (logvar = log(σ²))
            
        Returns:
            z: Reparameterized samples
        """
        # TODO: Implement reparameterization
        # Step 1: Convert logvar to std: σ = exp(0.5 * logvar)
        # Step 2: Sample ε ~ N(0,I)  
        # Step 3: Compute z = μ + σ * ε
        pass
    
    def reconstruction_loss(self, x: torch.Tensor, x_recon_mu: torch.Tensor, 
                           x_recon_logvar: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement reconstruction term E_q[log p(x|z)]
        
        Compute the expected log-likelihood of data under the decoder.
        Model: p(x|z) = N(x; μ_decoder(z), σ²_decoder)
        
        Args:
            x: Original data
            x_recon_mu: Decoder mean output
            x_recon_logvar: Decoder log variance
            
        Returns:
            log_likelihood: single-sample estimate of E_q[log p(x|z)] for each data point
                (higher is better; despite the method name, this value is not negated)
        """
        # TODO: Implement Gaussian log-likelihood
        # Formula: log p(x|z) = -0.5 * [log(2π) + log(σ²) + (x-μ)²/σ²]
        # Remember to sum over data dimensions
        pass
    
    def kl_regularization(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        TODO: Implement KL regularization term KL(q(z|x) || p(z))
        
        Compute KL divergence between variational posterior and prior.
        Assume p(z) = N(0, I) and q(z|x) = N(μ(x), σ²(x)I)
        
        Args:
            mu: Variational mean
            logvar: Variational log variance
            
        Returns:
            kl_loss: KL divergence for each sample
        """
        # TODO: Use your standard_normal_kl implementation from earlier
        # Or implement directly: KL = 0.5 * Σ[μ² + σ² - 1 - log(σ²)]
        pass
    
    def compute_elbo(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement the complete ELBO computation
        
        ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))
             = Reconstruction - Regularization
        
        Walk through the complete mathematical derivation step by step.
        
        Args:
            x: Input data
            
        Returns:
            Dictionary with ELBO components
        """
        # TODO: Step 1 - Encode to get variational parameters
        mu_q, logvar_q = self.encode(x)
        
        # TODO: Step 2 - Sample using reparameterization trick
        z = self.reparameterize(mu_q, logvar_q)
        
        # TODO: Step 3 - Decode to get reconstruction parameters
        x_recon_mu = self.decoder(z)
        x_recon_logvar = self.decoder_logvar.expand_as(x_recon_mu)
        
        # TODO: Step 4 - Compute reconstruction term
        recon_term = self.reconstruction_loss(x, x_recon_mu, x_recon_logvar)
        
        # TODO: Step 5 - Compute regularization term
        kl_term = self.kl_regularization(mu_q, logvar_q)
        
        # TODO: Step 6 - Combine for ELBO
        elbo = recon_term - kl_term
        
        return {
            'elbo': elbo,
            'reconstruction': recon_term,
            'kl_divergence': kl_term,
            'mu': mu_q,
            'logvar': logvar_q,
            'z': z,
            'x_recon': x_recon_mu
        }
    
    def train_vae(self, data: torch.Tensor, epochs: int = 50, lr: float = 1e-3):
        """
        Train VAE using ELBO - demonstration of practical optimization
        
        Show how mathematical theory becomes practical optimization.
        """
        print("=== Training VAE with ELBO ===\n")
        
        # Set up optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        
        losses = {'total': [], 'reconstruction': [], 'kl': []}
        
        for epoch in range(epochs):
            # Training step
            optimizer.zero_grad()
            
            # Compute ELBO (after students implement compute_elbo)
            results = self.compute_elbo(data)
            elbo = results['elbo']
            
            # Loss is negative ELBO (we want to maximize ELBO)
            loss = -elbo.mean()
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            # Track losses
            losses['total'].append(loss.item())
            losses['reconstruction'].append(-results['reconstruction'].mean().item())
            losses['kl'].append(results['kl_divergence'].mean().item())
            
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {loss.item():.3f}, Recon = {losses['reconstruction'][-1]:.3f}, KL = {losses['kl'][-1]:.3f}")
        
        # Plot training curves
        self.plot_training_curves(losses)
        return losses
    
    def plot_training_curves(self, losses):
        """Visualize training progress"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        epochs = range(len(losses['total']))
        
        # Plot 1: Total loss
        ax1.plot(epochs, losses['total'], 'b-', linewidth=2, label='Total Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss (Negative ELBO)')
        ax1.grid(True, alpha=0.3)
        ax1.legend()
        
        # Plot 2: ELBO components
        ax2.plot(epochs, losses['reconstruction'], 'r-', linewidth=2, label='Reconstruction')
        ax2.plot(epochs, losses['kl'], 'g-', linewidth=2, label='KL Divergence')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Loss Component')
        ax2.set_title('ELBO Components')
        ax2.grid(True, alpha=0.3)
        ax2.legend()
        
        plt.tight_layout()
        plt.show()
        
        print("Training insights:")
        print(f"• Final reconstruction loss: {losses['reconstruction'][-1]:.3f}")
        print(f"• Final KL divergence: {losses['kl'][-1]:.3f}")
        print(f"• Balance shows reconstruction vs regularization tradeoff")
    
    def parameters(self):
        """Helper to get all model parameters"""
        params = list(self.encoder.parameters()) + list(self.decoder.parameters()) + [self.decoder_logvar]
        return params

# Test your ELBO implementation (uncomment after implementing TODOs)
# elbo_model = ELBOFramework(latent_dim=2, data_dim=2)

# Test ELBO computation
# print("Testing ELBO computation...")
# results = elbo_model.compute_elbo(test_data[:10])
# print("✓ ELBO computation successful!")

# Train the VAE
# print("Training VAE...")
# losses = elbo_model.train_vae(test_data[:100], epochs=30)
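
Why reparameterization makes sampling differentiable can be checked without autograd. The sketch below is a standard-library illustration under stated assumptions: a one-dimensional latent, the test function E[z²] = μ² + σ², and a finite-difference derivative in place of backpropagation. Because the noise ε is drawn once and reused, the samples become a deterministic function of μ, and the derivative with respect to μ is well defined. The helper names are hypothetical.

```python
import random

def reparameterized_samples(mu, sigma, eps):
    """Map fixed noise eps ~ N(0, 1) to samples z = mu + sigma * eps."""
    return [mu + sigma * e for e in eps]

def mean_square(zs):
    """Monte Carlo estimate of E[z^2]."""
    return sum(z * z for z in zs) / len(zs)

rng = random.Random(0)
eps = [rng.gauss(0.0, 1.0) for _ in range(50_000)]   # noise drawn once, independent of mu and sigma

mu, sigma, delta = 1.5, 0.7, 1e-4
# Central finite difference in mu, reusing the same eps on both sides.
grad_estimate = (
    mean_square(reparameterized_samples(mu + delta, sigma, eps))
    - mean_square(reparameterized_samples(mu - delta, sigma, eps))
) / (2.0 * delta)

print(f"estimated d/dmu E[z^2] = {grad_estimate:.4f}")
print(f"analytic  d/dmu E[z^2] = {2.0 * mu:.4f}")   # since E[z^2] = mu^2 + sigma^2
```

If fresh noise were drawn for each of the two evaluations, the difference would be dominated by sampling noise rather than by the change in μ. That is the problem the reparameterization trick removes.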

### Task 5.2: The Two Forces Analysis

**Your Mission**: Implement the analysis of reconstruction vs regularization tradeoff.
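
Before running the neural experiment, the tradeoff can be previewed in a toy model where the β-weighted optimum has a closed form. The sketch below assumes a one-dimensional linear-Gaussian model, z ~ N(0, 1) and x|z ~ N(z, s²), with q(z|x) = N(m, v). This is not the lab's architecture. Setting the derivatives of E_q[log p(x|z)] − β·KL to zero gives m = x / (1 + βs²) and v = βs² / (1 + βs²) for β > 0. All function names are hypothetical.

```python
import math

def beta_optimal_posterior(x, noise_std, beta):
    """Optimal q(z|x) = N(m, v) for the toy model z ~ N(0, 1), x|z ~ N(z, noise_std^2).

    Maximizes E_q[log p(x|z)] - beta * KL(q || N(0, 1)); requires beta > 0.
    """
    bs2 = beta * noise_std ** 2
    m = x / (1.0 + bs2)
    v = bs2 / (1.0 + bs2)
    return m, v

def reconstruction_term(x, noise_std, m, v):
    """E_q[log p(x|z)] for q = N(m, v) and p(x|z) = N(z, noise_std^2)."""
    s2 = noise_std ** 2
    return -0.5 * math.log(2.0 * math.pi * s2) - ((x - m) ** 2 + v) / (2.0 * s2)

def kl_to_standard_normal(m, v):
    """KL(N(m, v) || N(0, 1)) in one dimension."""
    return 0.5 * (m ** 2 + v - 1.0 - math.log(v))

x_obs, noise_std = 2.0, 0.5   # illustrative observation and decoder noise
rows = []
for beta in (0.1, 1.0, 10.0):
    m, v = beta_optimal_posterior(x_obs, noise_std, beta)
    recon = reconstruction_term(x_obs, noise_std, m, v)
    kl = kl_to_standard_normal(m, v)
    rows.append((beta, m, v, recon, kl))
    print(f"beta={beta:>5}: m={m:.3f}, v={v:.3f}, recon={recon:.3f}, KL={kl:.3f}")
```

In this toy setting, increasing β pulls q(z|x) toward the prior (m toward 0, v toward 1), which lowers the KL term and also lowers the reconstruction term. At β = 1 the optimum coincides with the exact posterior of the linear-Gaussian model.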

In [None]:
class TwoForcesAnalysis:
    """
    Analyze the fundamental tradeoff in variational inference.
    You'll implement the β-VAE analysis to understand force balance.
    """
    
    def __init__(self):
        self.kl_builder = KLDivergenceBuilder()
    
    def analyze_beta_vae(self, data: torch.Tensor, beta_values: Tuple[float, ...] = (0.0, 0.1, 1.0, 5.0, 10.0)):
        """
        TODO: Implement β-VAE analysis
        
        Test different weightings: Loss = -Reconstruction + β * KL
        Show how β controls the reconstruction vs regularization tradeoff.
        
        Args:
            data: Training data
            beta_values: Different β weights to test
        """
        print("=== Two Forces Analysis: β-VAE Experiment ===\n")
        
        results = {}
        
        for beta in beta_values:
            print(f"Training with β = {beta}")
            
            # Create a new model for each β
            model = ELBOFramework(latent_dim=2, data_dim=2)
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            
            # Train with modified loss
            for epoch in range(20):  # Shorter training for comparison
                optimizer.zero_grad()
                
                # TODO: Students implement compute_elbo, then this works
                elbo_results = model.compute_elbo(data)
                reconstruction = elbo_results['reconstruction'].mean()
                kl_divergence = elbo_results['kl_divergence'].mean()
                
                # Modified loss with β weighting
                loss = -reconstruction + beta * kl_divergence
                
                # Optimize
                loss.backward()
                optimizer.step()
            
            # Evaluate final model
            with torch.no_grad():
                final_results = model.compute_elbo(data)
                results[beta] = {
                    'reconstruction': final_results['reconstruction'].mean().item(),
                    'kl_divergence': final_results['kl_divergence'].mean().item(),
                    'latent_samples': final_results['z'].cpu(),
                    'reconstructions': final_results['x_recon'].cpu()
                }
        
        # Visualize the results
        self.visualize_beta_effects(results, data)
    
    def visualize_beta_effects(self, results: dict, original_data: torch.Tensor):
        """
        Create visualizations showing β effects
        
        Show how different β values affect:
        1. Reconstruction vs KL tradeoff
        2. Latent space organization  
        3. Reconstruction quality
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
        
        # Plot 1: Metrics vs β
        betas = list(results.keys())
        # 'reconstruction' stores E_q[log p(x|z)]; negate it so the plotted quantity is a loss
        recon_losses = [-results[b]['reconstruction'] for b in betas]
        kl_losses = [results[b]['kl_divergence'] for b in betas]
        
        ax1.plot(betas, recon_losses, 'ro-', linewidth=2, markersize=8, label='Reconstruction Loss')
        ax1.set_xlabel('β (KL weight)')
        ax1.set_ylabel('Reconstruction Loss', color='r')
        ax1.tick_params(axis='y', labelcolor='r')
        ax1.grid(True, alpha=0.3)
        
        ax1_twin = ax1.twinx()
        ax1_twin.plot(betas, kl_losses, 'bo-', linewidth=2, markersize=8, label='KL Divergence')
        ax1_twin.set_ylabel('KL Divergence', color='b')
        ax1_twin.tick_params(axis='y', labelcolor='b')
        
        ax1.set_title('Two Forces Tradeoff')
        ax1.legend(loc='upper left')
        ax1_twin.legend(loc='upper right')
        
        # Plot 2: Latent space organization for different β
        colors = ['red', 'orange', 'green', 'blue', 'purple']
        for i, (beta, color) in enumerate(zip([0.0, 1.0, 10.0], colors[:3])):
            if beta in results:
                z_samples = results[beta]['latent_samples']
                ax2.scatter(z_samples[:, 0], z_samples[:, 1], alpha=0.6, s=20, 
                           color=color, label=f'β={beta}')
        
        ax2.set_xlabel('Latent Dimension 1')
        ax2.set_ylabel('Latent Dimension 2')
        ax2.set_title('Latent Space Organization')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Original data
        ax3.scatter(original_data[:, 0].cpu(), original_data[:, 1].cpu(), 
                   alpha=0.6, s=20, color='black')
        ax3.set_title('Original Data')
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: Reconstruction examples for different β
        sample_indices = torch.randperm(len(original_data))[:5]
        for i, beta in enumerate([0.0, 1.0, 10.0]):
            if beta in results:
                recons = results[beta]['reconstructions'][sample_indices]
                ax4.scatter(recons[:, 0], recons[:, 1], alpha=0.8, s=60,
                           color=colors[i], label=f'β={beta} recons')
        
        # Show original points for comparison
        orig_sample = original_data[sample_indices]
        ax4.scatter(orig_sample[:, 0].cpu(), orig_sample[:, 1].cpu(), 
                   alpha=0.8, color='black', marker='x', s=100, label='Original')
        
        ax4.set_title('Reconstruction Quality')
        ax4.legend()
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        print("\nKey insights from β-VAE analysis:")
        print("• β = 0: Best reconstruction, but nothing organizes the latent space")
        print("• β = 1: Balanced tradeoff (standard VAE)")  
        print("• β >> 1: Latents pulled toward the prior, poor reconstruction")
        print("• The 'sweet spot' depends on your application goals")
    
    def explain_two_forces(self):
        """
        Explain the fundamental two-forces concept
        """
        print("\n=== Understanding the Two Forces ===\n")
        
        print("ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))")
        print("     = Reconstruction  - Regularization")
        print()
        
        print("Force 1 - Reconstruction: E_q[log p(x|z)]")
        print("  • Wants: Perfect reconstruction of input data")
        print("  • Pushes: Encoder to preserve all information")
        print("  • Says: 'Use whatever latent codes work best!'")
        print()
        
        print("Force 2 - Regularization: KL(q(z|x) || p(z))")
        print("  • Wants: Latent codes close to prior distribution")
        print("  • Pushes: Encoder toward structured representations")
        print("  • Says: 'Stay close to simple distributions!'")
        print()
        
        print("The beautiful tension:")
        print("  ⚖️  Perfect reconstruction vs structured latents")
        print("  💎 This tension creates meaningful representations")
        print("  🎯 Balance determines model behavior")
        print()
        
        # Create a conceptual visualization
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        
        # Draw the two forces as arrows
        ax.arrow(0.2, 0.5, 0.25, 0, head_width=0.05, head_length=0.02, 
                fc='red', ec='red', linewidth=3)
        ax.text(0.32, 0.6, 'Reconstruction\nForce', ha='center', fontsize=12, 
                color='red', weight='bold')
        
        ax.arrow(0.8, 0.5, -0.25, 0, head_width=0.05, head_length=0.02,
                fc='blue', ec='blue', linewidth=3)
        ax.text(0.68, 0.6, 'Regularization\nForce', ha='center', fontsize=12,
                color='blue', weight='bold')
        
        # Show the balance point
        ax.plot(0.5, 0.5, 'go', markersize=15, label='Optimal Balance')
        ax.text(0.5, 0.35, 'ELBO\nOptimum', ha='center', fontsize=12,
                color='green', weight='bold')
        
        ax.set_xlim(0, 1)
        ax.set_ylim(0.2, 0.8)
        ax.set_title('The Two Forces in ELBO', fontsize=16, weight='bold')
        ax.axis('off')
        
        plt.tight_layout()
        plt.show()

# Run two forces analysis (uncomment the analyze_beta_vae line after implementing TODOs)
two_forces = TwoForcesAnalysis()
two_forces.explain_two_forces()
# two_forces.analyze_beta_vae(test_data[:50])
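
The β weighting behind the two-forces tradeoff can be sketched as a single loss function. This is a minimal sketch, not the lab's `ELBOFramework`: the helper name `beta_elbo_loss`, the unit-variance Gaussian decoder, and the toy tensors are assumptions made for illustration only.

```python
import math
import torch

def beta_elbo_loss(x, x_recon, mu, logvar, beta=1.0):
    """Negative beta-weighted ELBO, averaged over the batch (illustrative helper)."""
    # Force 1: log N(x | x_recon, I), summed over data dimensions.
    log_px_z = -0.5 * ((x - x_recon) ** 2 + math.log(2 * math.pi)).sum(dim=-1)
    # Force 2: KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions.
    kl = 0.5 * (mu ** 2 + torch.exp(logvar) - 1 - logvar).sum(dim=-1)
    elbo = log_px_z - beta * kl
    return -elbo.mean(), log_px_z.mean(), kl.mean()

torch.manual_seed(0)
x = torch.randn(8, 2)
x_recon = x + 0.1 * torch.randn(8, 2)          # stand-in for a decoder output
mu, logvar = torch.randn(8, 2), torch.zeros(8, 2)

for beta in (0.0, 1.0, 10.0):
    loss, recon, kl = beta_elbo_loss(x, x_recon, mu, logvar, beta)
    print(f"beta={beta:5.1f}  loss={loss.item():.4f}  "
          f"recon={recon.item():.4f}  KL={kl.item():.4f}")
```

With the encoder outputs held fixed, only the weight on the KL term changes, so the loss grows linearly in β. During training the encoder responds to that pressure, which is what produces the tradeoff curve.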

---

## Part 6: Mathematical Validation and Testing (10 minutes)

### Task 6.1: Validate Your Implementations

**Your Mission**: Test that your mathematical implementations are correct.

In [None]:
def comprehensive_validation():
    """
    Implement comprehensive testing of all mathematical components
    
    Verify that your implementations match theoretical predictions.
    """
    print("=== Comprehensive Mathematical Validation ===\n")
    
    # Test 1: KL Divergence Validation
    print("Test 1: KL Divergence Implementation")
    
    # Create test distributions
    mu1, sigma1 = torch.tensor(0.5), torch.tensor(0.8)
    mu2, sigma2 = torch.tensor(0.0), torch.tensor(1.0)
    
    # Test your implementation against known formula
    kl_builder = KLDivergenceBuilder()
    kl_implemented = kl_builder.gaussian_kl_closed_form(mu1, sigma1, mu2, sigma2)
    
    # Manual calculation for verification
    kl_manual = 0.5 * ((sigma1/sigma2)**2 + ((mu1-mu2)/sigma2)**2 - 1 - 2*torch.log(sigma1/sigma2))
    
    print(f"Your implementation: {kl_implemented}")
    print(f"Manual calculation:  {kl_manual:.6f}")
    print(f"Match: {torch.allclose(kl_implemented, kl_manual, atol=1e-5) if kl_implemented is not None else 'Implement TODO first'}")
    
    # Test 2: Jensen's Inequality
    print(f"\nTest 2: Jensen's Inequality")
    
    # Test with simple example
    values = torch.tensor([1.0, 4.0, 9.0])
    log_mean = torch.log(values.mean())
    mean_log = torch.log(values).mean()
    
    print(f"log(E[X]): {log_mean:.6f}")
    print(f"E[log(X)]: {mean_log:.6f}")
    print(f"Jensen satisfied: {log_mean >= mean_log}")
    print(f"Jensen gap log(E[X]) - E[log(X)]: {log_mean - mean_log:.6f}")
    
    # Test 3: ELBO Decomposition
    print(f"\nTest 3: ELBO Components")
    
    # Test that ELBO = reconstruction - KL
    elbo_model = ELBOFramework(latent_dim=2, data_dim=2)
    test_x = torch.randn(5, 2).to(device)
    
    # Verify ELBO decomposition (after students implement compute_elbo)
    try:
        results = elbo_model.compute_elbo(test_x)
        elbo_direct = results['elbo']
        elbo_components = results['reconstruction'] - results['kl_divergence']
        
        print(f"ELBO direct: {elbo_direct.mean():.6f}")
        print(f"ELBO from components: {elbo_components.mean():.6f}")
        print(f"Match: {torch.allclose(elbo_direct, elbo_components, atol=1e-5)}")
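    except Exception as err:  # avoid a bare except, which would also swallow KeyboardInterrupt
        print(f"ELBO test pending - implement compute_elbo first ({type(err).__name__}: {err})")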
    
    # Test 4: Reparameterization Trick
    print(f"\nTest 4: Reparameterization Trick")
    
    # Test that reparameterized samples have correct statistics
    mu_test = torch.tensor([1.0, -0.5]).to(device)
    logvar_test = torch.tensor([0.5, -0.2]).to(device)
    
    # Generate many samples and check statistics
    n_samples = 10000
    try:
        # Draw all samples in one batched call instead of a 10,000-iteration Python loop
        samples = elbo_model.reparameterize(
            mu_test.repeat(n_samples, 1), logvar_test.repeat(n_samples, 1)
        )
        empirical_mean = samples.mean(dim=0)
        empirical_var = samples.var(dim=0)
        expected_var = torch.exp(logvar_test)
        
        print(f"Expected mean: {mu_test}")
        print(f"Empirical mean: {empirical_mean}")
        print(f"Expected var: {expected_var}")
        print(f"Empirical var: {empirical_var}")
        print(f"Mean close: {torch.allclose(mu_test, empirical_mean, atol=0.1)}")
        print(f"Var close: {torch.allclose(expected_var, empirical_var, atol=0.1)}")
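    except Exception as err:  # avoid a bare except, which would also swallow KeyboardInterrupt
        print(f"Reparameterization test pending - implement reparameterize first ({type(err).__name__}: {err})")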
    
    print("\n🎓 Validation completed!")
    print("If your implementations pass these tests, your mathematical foundation is solid!")

# Run comprehensive validation
comprehensive_validation()

---

## Part 7: Connection to Diffusion Models (5 minutes)

### Task 7.1: Bridge to Diffusion

In [None]:
class DiffusionMathematicalBridge:
    """
    Connect today's mathematical foundations to diffusion models.
    Show how the ELBO framework extends to sequence generation.
    """
    
    def compare_vae_vs_diffusion(self):
        """
        Compare mathematical frameworks
        """
        print("=== VAE vs Diffusion: Mathematical Comparison ===\n")
        
        print("VAE Framework:")
        print("  • Latent space: Single z ~ q(z|x)")
        print("  • Variational distribution: Learn q(z|x) with encoder")
        print("  • ELBO: E[log p(x|z)] - KL(q(z|x) || p(z))")
        print("  • Challenge: Approximation quality of q(z|x)")
        print()
        
        print("Diffusion Framework:")
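        print("  • Latent space: Sequence x₁, x₂, ..., xₜ")
        print("  • Variational distribution: Fixed q(x₁:ₜ|x₀) by design (no learned encoder)")
        print("  • ELBO: More terms, but in the Gaussian setup every KL has a closed form")
        print("  • Advantage: q has no parameters to learn; only the reverse model is trained")
        print()
        
        print("Key insight: Diffusion fixes q in advance, so there is no encoder to fit.")
        print("             The ELBO is still a lower bound on log p(x₀), not an exact likelihood.")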
        
        # Create comparison visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        
        # VAE diagram
        ax1.text(0.5, 0.8, 'VAE', ha='center', fontsize=16, weight='bold')
        ax1.text(0.2, 0.6, 'x', ha='center', fontsize=14, bbox=dict(boxstyle="round", facecolor='lightblue'))
        ax1.text(0.5, 0.6, 'Encoder\nq(z|x)', ha='center', fontsize=12)
        ax1.text(0.8, 0.6, 'z', ha='center', fontsize=14, bbox=dict(boxstyle="round", facecolor='lightgreen'))
        ax1.text(0.5, 0.4, 'Decoder\np(x|z)', ha='center', fontsize=12)
        ax1.text(0.2, 0.2, 'x̂', ha='center', fontsize=14, bbox=dict(boxstyle="round", facecolor='lightcoral'))
        
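        # Top row: x -> Encoder -> z
        ax1.arrow(0.25, 0.6, 0.15, 0, head_width=0.02, head_length=0.02, fc='black', ec='black')
        ax1.arrow(0.60, 0.6, 0.13, 0, head_width=0.02, head_length=0.02, fc='black', ec='black')
        # Return path: z -> Decoder -> x̂
        ax1.arrow(0.78, 0.55, -0.18, -0.1, head_width=0.02, head_length=0.02, fc='black', ec='black')
        ax1.arrow(0.45, 0.35, -0.15, -0.1, head_width=0.02, head_length=0.02, fc='black', ec='black')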
        
        ax1.set_xlim(0, 1)
        ax1.set_ylim(0, 1)
        ax1.axis('off')
        ax1.set_title('Single-Step Generation')
        
        # Diffusion diagram
        ax2.text(0.5, 0.9, 'Diffusion', ha='center', fontsize=16, weight='bold')
        
        positions = [0.1, 0.3, 0.5, 0.7, 0.9]
        labels = ['x₀', 'x₁', '...', 'xₜ₋₁', 'xₜ']
        
        for i, (pos, label) in enumerate(zip(positions, labels)):
            if i == 2:  # dots
                ax2.text(pos, 0.5, label, ha='center', fontsize=14)
            else:
                ax2.text(pos, 0.5, label, ha='center', fontsize=12, 
                        bbox=dict(boxstyle="round", facecolor='lightblue'))
            
            if i < len(positions) - 1 and i != 1:  # no arrow from x₁ into the '...' placeholder
                ax2.arrow(pos + 0.05, 0.5, 0.1, 0, head_width=0.02, head_length=0.02, 
                         fc='black', ec='black')
        
        ax2.set_xlim(0, 1)
        ax2.set_ylim(0.3, 1)
        ax2.axis('off')
        ax2.set_title('Sequential Generation')
        
        plt.tight_layout()
        plt.show()
    
    def preview_diffusion_elbo(self):
        """
        Preview how ELBO extends to sequences
        """
        print("=== Preview: Diffusion ELBO ===\n")
        
        print("VAE ELBO (single step):")
        print("  log p(x) ≥ E_q[log p(x|z)] - KL(q(z|x) || p(z))")
        print()
        
        print("Diffusion ELBO (sequence):")
        print("  log p(x₀) ≥ E_q[log p(x₀|x₁)] - Σₜ KL(...)")
        print("           Three terms we'll learn next time:")
        print("           Reconstruction - Prior Matching - Consistency")
        print()
        
        print("Mathematical progression:")
        print("  1. Today: Single latent variable z")
        print("  2. Next: Sequence of latent variables x₁, x₂, ..., xₜ")
        print("  3. Same ELBO principle, extended to Markov chains")
        print("  4. Fixed variational distribution eliminates approximation error")
        print()
        
        print("Next lab will show:")
        print("  • How to derive the diffusion ELBO")
        print("  • Why noise prediction becomes natural")
        print("  • How your math today enables diffusion training")

# Connect to diffusion models
bridge = DiffusionMathematicalBridge()
bridge.compare_vae_vs_diffusion()
bridge.preview_diffusion_elbo()
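
To make "fixed by design" concrete, the sketch below samples from a forward noising process in closed form, using x_t = sqrt(ᾱ_t)·x₀ + sqrt(1 − ᾱ_t)·ε. The linear variance schedule, the step count `T = 1000`, and the helper name `q_sample` are assumptions borrowed from common DDPM-style setups; this lab does not specify them.

```python
import torch

T = 1000                                    # number of diffusion steps (assumed)
betas = torch.linspace(1e-4, 0.02, T)       # linear variance schedule (assumed)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x0, t):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I); t is 0-indexed."""
    abar = alpha_bars[t].unsqueeze(-1)      # shape (batch, 1) for broadcasting
    eps = torch.randn_like(x0)
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * eps

torch.manual_seed(0)
x0 = torch.full((5000, 2), 3.0)             # toy data: every point sits at (3, 3)
for step in (0, 499, 999):
    t = torch.full((x0.shape[0],), step, dtype=torch.long)
    xt = q_sample(x0, t)
    print(f"t={step:4d}  mean={xt.mean().item():.3f}  var={xt.var().item():.3f}")
```

Nothing in `q_sample` is learned, and it reuses the same μ + σ·ε pattern as the reparameterization trick. By the last step the samples are close to N(0, I), which plays the role of the prior.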

---

## Part 8: Reflection and Summary (5 minutes)

### Task 8.1: Mathematical Insights Summary

**Discussion Questions** (Work with your partner):

1. **Implementation Challenges**:
   - Which mathematical concept was most difficult to implement?
   - How did implementing the formulas change your understanding?

2. **ELBO Insights**:
   - Why does Jensen's inequality create a useful lower bound?
   - How do the two forces shape learned representations?

3. **Practical Connections**:
   - How does mathematical theory connect to practical training?
   - What would happen if we didn't have the reparameterization trick?
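
For the last question, a quick experiment may help the discussion. The sketch below compares a direct draw from `torch.distributions.Normal` with a reparameterized draw; the tensor values and the squared-sample "loss" are arbitrary choices for illustration.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor([0.5], requires_grad=True)
logvar = torch.tensor([0.0], requires_grad=True)
std = torch.exp(0.5 * logvar)

# Direct sampling: .sample() is computed without gradient tracking,
# so the result is disconnected from mu and logvar.
z_direct = Normal(mu, std).sample()
print("direct sample requires grad:", z_direct.requires_grad)

# Reparameterized sampling: z is a deterministic function of (mu, logvar)
# plus independent noise, so gradients can flow back to both parameters.
eps = torch.randn_like(std)
z_reparam = mu + std * eps
(z_reparam ** 2).sum().backward()
print("d loss / d mu    :", mu.grad)
print("d loss / d logvar:", logvar.grad)

# PyTorch exposes the same idea as .rsample()
print("rsample requires grad:", Normal(mu, std).rsample().requires_grad)
```

Without a gradient path through the sample, the encoder parameters would receive no learning signal from the reconstruction term.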

In [None]:
def summarize_mathematical_journey():
    """
    Reflect on the mathematical concepts you've implemented
    """
    print("=== Your Mathematical Journey Today ===\n")
    
    concepts_implemented = [
        "❌ Intractable likelihood computation (and why it fails)",
        "🔄 KL divergence calculations (asymmetric properties)",
        "📐 Jensen's inequality (creating lower bounds)",
        "🧮 ELBO framework (universal variational inference)",
        "⚖️  Two forces analysis (reconstruction vs regularization)",
        "🔧 Reparameterization trick (making randomness differentiable)"
    ]
    
    print("Mathematical concepts you implemented:")
    for concept in concepts_implemented:
        print(f"  {concept}")
    
    print("\n🎓 You now understand the mathematical foundation of:")
    print("   • Variational Autoencoders (VAEs)")
    print("   • Diffusion Models - a fixed, hand-designed variational distribution")
    print("   • Other latent-variable models trained by maximizing an ELBO")
    print("   (GANs take a different, adversarial route that does not use the ELBO)")
    
    print(f"\n🔬 Key mathematical insights:")
    print(f"   • Why direct likelihood optimization fails")
    print(f"   • How Jensen's inequality saves the day")
    print(f"   • Why the two-forces tension creates good representations")
    print(f"   • How mathematical elegance enables practical algorithms")
    
    # Create a summary visualization
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    
    # Mathematical concept hierarchy
    levels = {
        'Problem': ['Intractable p(x)', 'Circular Bayes Rule'],
        'Solution': ['Jensen\'s Inequality', 'KL Divergence'],
        'Framework': ['ELBO = Recon - KL'],
        'Implementation': ['Reparameterization', 'Two Forces Balance'],
        'Applications': ['VAEs', 'Diffusion Models']
    }
    
    y_positions = [0.8, 0.65, 0.5, 0.35, 0.2]
    colors = ['red', 'orange', 'green', 'blue', 'purple']
    
    for i, (level, concepts) in enumerate(levels.items()):
        y = y_positions[i]
        color = colors[i]
        
        # Level label
        ax.text(0.05, y, level, fontsize=14, weight='bold', color=color)
        
        # Concepts
        x_positions = [0.25 + j * 0.2 for j in range(len(concepts))]
        for j, concept in enumerate(concepts):
            ax.text(x_positions[j], y, concept, fontsize=11, ha='center',
                   bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))
    
    # Add arrows showing progression
    for i in range(len(y_positions) - 1):
        ax.arrow(0.5, y_positions[i] - 0.05, 0, -0.05, head_width=0.02, head_length=0.01,
                fc='gray', ec='gray', alpha=0.7)
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Mathematical Foundation Hierarchy', fontsize=16, weight='bold')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Summarize your learning
summarize_mathematical_journey()

---

## Implementation Checklist

### Core Mathematical Functions (Students Implement):

**✅ Essential TODOs:**
- [ ] `sample_prior()` - Prior sampling from N(0,I)
- [ ] `likelihood_given_z()` - Gaussian likelihood p(x|z)
- [ ] `approximate_marginal_likelihood()` - Monte Carlo estimation
- [ ] `monte_carlo_kl()` - KL via sampling
- [ ] `gaussian_kl_closed_form()` - Analytical KL for Gaussians
- [ ] `standard_normal_kl()` - KL to N(0,I)
- [ ] `reparameterize()` - Reparameterization trick
- [ ] `reconstruction_loss()` - ELBO reconstruction term
- [ ] `kl_regularization()` - ELBO KL term
- [ ] `compute_elbo()` - Complete ELBO computation
- [ ] `analyze_beta_vae()` - β-VAE force analysis

**✅ Provided Starter Code:**
- [ ] All visualization functions with complete plotting code
- [ ] Training loops and optimization
- [ ] Data generation and testing infrastructure
- [ ] Mathematical validation and comparison functions

---

## Submission Requirements

### What to Submit

Submit your completed Jupyter notebook (.ipynb file) with:

**✅ Mathematical Implementations:**
- All TODO functions implemented with correct mathematical formulas
- Clear comments explaining each mathematical step
- Proper tensor operations and broadcasting

**✅ Validation Results:**
- Numerical verification that implementations match theory
- Comparison of analytical vs Monte Carlo methods
- Testing edge cases and error handling

**✅ Analysis and Insights:**
- β-VAE analysis showing two forces tradeoff
- KL asymmetry exploration with explanations
- Jensen's inequality demonstrations

**✅ Documentation:**
- Clear explanations of each mathematical concept
- Discussion of implementation challenges encountered
- Connection between theory and practical algorithms

---

## Quick Reference: Key Mathematical Formulas

### For Implementation Reference:

**KL Divergence to Standard Normal:**

In [None]:
# KL(N(μ,σ²) || N(0,I)) = 0.5 * Σ[μ² + σ² - 1 - log(σ²)]
kl = 0.5 * (mu**2 + torch.exp(logvar) - 1 - logvar).sum(dim=-1)
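
One way to sanity-check this formula is to compare it with `torch.distributions.kl_divergence`. The sketch below uses random `mu` and `logvar` tensors as stand-ins for encoder outputs; the shapes are arbitrary.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 3)        # batch of 4, latent dimension 3 (arbitrary)
logvar = torch.randn(4, 3)

# Closed-form KL to N(0, I), summed over latent dimensions
kl_formula = 0.5 * (mu**2 + torch.exp(logvar) - 1 - logvar).sum(dim=-1)

# Reference value from PyTorch's registered Normal-Normal KL
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum(dim=-1)

print("formula  :", kl_formula)
print("reference:", kl_reference)
print("match    :", torch.allclose(kl_formula, kl_reference, atol=1e-5))
```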

**Gaussian Log-Likelihood:**

In [None]:
# log p(x|z) = -0.5 * [log(2π) + log(σ²) + (x-μ)²/σ²]
log_likelihood = -0.5 * (math.log(2*math.pi) + logvar + (x - mu)**2 / torch.exp(logvar))
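
Note that the expression above is per dimension. For a diagonal Gaussian, the per-sample log-likelihood is the sum over data dimensions. The sketch below shows that step and checks it against `Normal.log_prob`, using random tensors as placeholders for real data and decoder outputs.

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
x = torch.randn(4, 2)         # batch of 4 points in 2D (arbitrary)
mu = torch.randn(4, 2)        # stand-in for decoder means
logvar = torch.randn(4, 2)    # stand-in for decoder log-variances

per_dim = -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / torch.exp(logvar))
log_px = per_dim.sum(dim=-1)  # one log-likelihood value per sample

reference = Normal(mu, torch.exp(0.5 * logvar)).log_prob(x).sum(dim=-1)
print("formula  :", log_px)
print("reference:", reference)
print("match    :", torch.allclose(log_px, reference, atol=1e-5))
```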

**Reparameterization Trick:**

In [None]:
# z = μ + σ * ε, where ε ~ N(0,I) and σ = exp(0.5 * log(σ²))
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + std * eps

**Gaussian KL Closed Form:**

In [None]:
# KL(N(μ₁,σ₁²) || N(μ₂,σ₂²))
kl = torch.log(sigma2/sigma1) + (sigma1**2 + (mu1-mu2)**2)/(2*sigma2**2) - 0.5
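
The same closed form also makes the asymmetry of KL easy to see: swap the arguments and the value changes. In the sketch below, `gaussian_kl` is an illustrative helper name (not the lab's `KLDivergenceBuilder` method), and the parameter values are the ones used in the validation cell.

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) for scalar Gaussians."""
    return torch.log(sigma2 / sigma1) + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2) - 0.5

mu1, sigma1 = torch.tensor(0.5), torch.tensor(0.8)
mu2, sigma2 = torch.tensor(0.0), torch.tensor(1.0)

forward = gaussian_kl(mu1, sigma1, mu2, sigma2)   # KL(p || q)
reverse = gaussian_kl(mu2, sigma2, mu1, sigma1)   # KL(q || p)
print(f"KL(p||q) = {forward.item():.5f}")
print(f"KL(q||p) = {reverse.item():.5f}")

# Cross-check both directions against PyTorch's built-in KL
p, q = Normal(mu1, sigma1), Normal(mu2, sigma2)
print("forward matches:", torch.allclose(forward, kl_divergence(p, q), atol=1e-6))
print("reverse matches:", torch.allclose(reverse, kl_divergence(q, p), atol=1e-6))
```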

---

## Common Implementation Issues & Solutions

### Debugging Tips:

**NaN Gradients:**
- Check for `log(0)` in likelihood computations
- Ensure positive variances: use `torch.exp(logvar)` not `logvar`
- Add small epsilon: `torch.exp(logvar) + 1e-8`
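
A possible guard that combines these tips is to clamp `logvar` before exponentiating. The helper name `stable_variance` and the clamp range of [-10, 10] are assumptions for this sketch; a suitable range depends on your data scale.

```python
import torch

def stable_variance(logvar, min_logvar=-10.0, max_logvar=10.0, eps=1e-8):
    """Return a strictly positive, finite variance from a raw log-variance."""
    logvar = torch.clamp(logvar, min_logvar, max_logvar)
    return torch.exp(logvar) + eps

raw = torch.tensor([-200.0, 0.0, 200.0])
print("naive exp :", torch.exp(raw))        # underflows to 0 and overflows to inf in float32
print("stabilized:", stable_variance(raw))
```

Dividing by the naive variance would produce `inf` or `nan` for the extreme entries, while the clamped version stays finite.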

**Dimension Mismatches:**
- KL should sum over latent dimensions: `.sum(dim=-1)`
- Reconstruction should sum over data dimensions: `.sum(dim=-1)`
- Batch operations: keep batch dimension intact

**Training Instability:**
- Start with small learning rates (1e-4 to 1e-3)
- Monitor KL collapse: if KL → 0, decrease β or anneal it up from 0 (KL warm-up); increasing β pushes KL further toward zero
- Check reconstruction scale: normalize data to [-1,1] or [0,1]
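
One common way to reduce the risk of KL collapse is a warm-up schedule that ramps the KL weight from 0 to its target value. The sketch below is a minimal linear version; the function name `kl_weight` and the default of 1000 warm-up steps are assumptions, not values prescribed by this lab.

```python
def kl_weight(step, warmup_steps=1000, beta_max=1.0):
    """Linear KL warm-up: the weight rises from 0 to beta_max over warmup_steps."""
    return beta_max * min(1.0, step / warmup_steps)

for step in (0, 250, 500, 1000, 5000):
    print(f"step {step:5d}: beta = {kl_weight(step):.2f}")

# Inside a training loop the weight would scale the KL term, for example:
#   loss = -(reconstruction - kl_weight(step) * kl_divergence).mean()
```

Early in training the model is free to learn useful latent codes; the regularization pressure then increases gradually.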

---

## Final Implementation Notes

### Expected Behavior After Implementation:

**Intractable Likelihood Demo:**
- Monte Carlo estimates should have high variance
- Computation time should grow with sample count
- Multiple runs should give different estimates

**KL Divergence:**
- `KL(p||q) ≠ KL(q||p)` - should see clear asymmetry
- Closed-form and Monte Carlo should agree to within about 0.01 given enough samples (the Monte Carlo error shrinks like 1/√N)
- Standard normal KL should be positive for non-standard inputs
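
The agreement between Monte Carlo and closed-form KL can be checked with a short experiment. This sketch uses `torch.distributions` directly rather than the lab's `monte_carlo_kl()`; the two Gaussians and the sample counts are arbitrary choices.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
q = Normal(torch.tensor(0.5), torch.tensor(0.8))
p = Normal(torch.tensor(0.0), torch.tensor(1.0))

exact = kl_divergence(q, p)
print(f"closed form: {exact.item():.5f}")

for n in (100, 10_000, 1_000_000):
    z = q.sample((n,))                                   # samples from q
    estimate = (q.log_prob(z) - p.log_prob(z)).mean()    # E_q[log q - log p]
    print(f"N={n:>9,d}  estimate={estimate.item():.5f}  "
          f"abs error={abs(estimate - exact).item():.5f}")
```

The estimate is noisy for small N and tightens as N grows, which is why the tolerance only makes sense for large sample counts.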

**Jensen's Inequality:**
- `log(E[X]) ≥ E[log(X)]` should hold for all test cases
- Gap should be positive and meaningful (> 0.001) for the test cases; it shrinks to zero only when X is constant
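
A small numerical check of Jensen's inequality for the logarithm is sketched below. The uniform samples on [1, 10) and the constant comparison case are arbitrary illustrative choices.

```python
import torch

torch.manual_seed(0)

# Spread-out positive samples: the gap should be clearly positive
x = torch.rand(1000) * 9.0 + 1.0
log_of_mean = torch.log(x.mean())
mean_of_log = torch.log(x).mean()
gap = log_of_mean - mean_of_log
print(f"log(E[X]) = {log_of_mean.item():.5f}")
print(f"E[log X]  = {mean_of_log.item():.5f}")
print(f"gap       = {gap.item():.5f}")

# Constant samples: Jensen's inequality holds with equality
constant = torch.full((1000,), 4.0)
gap_constant = torch.log(constant.mean()) - torch.log(constant).mean()
print(f"gap for constant X = {gap_constant.item():.2e}")
```

The gap grows with the spread of X, which is the same mechanism that makes the ELBO a looser bound when the variational distribution is a poor match.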

**ELBO Training:**
- Loss should decrease over epochs
- Reconstruction and KL should show clear tradeoff
- β-VAE should show dramatic differences in latent organization

**Validation:**
- All mathematical relationships should hold within numerical precision
- Reparameterized samples should match expected statistics
- ELBO decomposition should be internally consistent

---
