# Variational Autoencoders (VAE) for MNIST Generation

## Overview

Variational Autoencoders combine deep learning with probabilistic modeling to learn compressed representations of data while enabling generation of new samples. This notebook implements a VAE for MNIST digit generation.

## Key Concepts

### 1. Probabilistic Latent Variables
Unlike regular autoencoders, VAEs model the **latent space as a probability distribution**. Instead of encoding to a fixed point, we encode to a distribution $q(z|x)$ and sample from it.

### 2. The Reparameterization Trick
The key innovation that makes VAEs trainable:

- **Problem**: We cannot backpropagate through a random sampling step.
- **Solution**: Reparameterize the random variable so that the randomness comes from a fixed noise source.

$$z \sim \mathcal{N}(\mu, \sigma^2) \quad \Rightarrow \quad z = \mu + \sigma \odot \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, I)$$

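Because the randomness now lives entirely in $\epsilon$, which does not depend on the network's parameters, $z$ is a deterministic, differentiable function of $\mu$ and $\sigma$. Gradients can therefore flow through $\mu$ and $\sigma$ while $z$ remains a random sample.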
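As a quick numerical check of why this works, the NumPy sketch below estimates $\nabla_\mu \mathbb{E}[f(z)]$ and $\nabla_\sigma \mathbb{E}[f(z)]$ by differentiating through $z = \mu + \sigma \epsilon$ and compares them with the exact values. The test function $f(z) = z^2$ and the values of $\mu$ and $\sigma$ are arbitrary choices for illustration, not part of the notebook's model.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8                       # example parameters of q(z) = N(μ, σ²)
eps = rng.standard_normal(200_000)         # ε ~ N(0, 1), independent of μ and σ
z = mu + sigma * eps                       # reparameterized samples

# Test function f(z) = z², with E[f(z)] = μ² + σ², so the exact gradients are 2μ and 2σ.
# Pathwise (reparameterization) gradients via the chain rule through z = μ + σε:
#   ∂f/∂μ = f'(z) · ∂z/∂μ = 2z        ∂f/∂σ = f'(z) · ∂z/∂σ = 2z · ε
grad_mu = np.mean(2 * z)
grad_sigma = np.mean(2 * z * eps)

print(f"∂E[z²]/∂μ ≈ {grad_mu:.3f} (exact {2 * mu:.3f})")
print(f"∂E[z²]/∂σ ≈ {grad_sigma:.3f} (exact {2 * sigma:.3f})")
```

The per-sample gradients only involve $\epsilon$ and $z$, never a derivative of the sampling operation itself, which is exactly what makes the trick compatible with backpropagation.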

### 3. VAE Loss Function: ELBO
VAEs maximize the **Evidence Lower BOund (ELBO)**:

$$\mathcal{L} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) || p(z))$$

**Components:**
- **Reconstruction Loss**: $\mathbb{E}_{q(z|x)}[\log p(x|z)]$ - How well can we reconstruct input?
- **KL Divergence**: $D_{KL}(q(z|x) || p(z))$ - How close is our encoding to the prior?

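### 4. Why the Loss Is Bounded Below
Unlike the negative log-likelihood of a flow model on continuous data, which can go negative because a probability density can exceed 1, the VAE loss used here is **bounded below by 0**:
- **Reconstruction Loss**: BCE ≥ 0 for any targets in [0, 1]. With continuous (dequantized) targets its minimum is the per-pixel Bernoulli entropy rather than 0, and it has no useful upper bound.
- **KL Divergence**: KL ≥ 0, with equality only when $q(z|x)$ equals the prior $p(z)$.
- **Total Loss**: Summed over the 784 pixels, it starts near $784 \ln 2 \approx 543$ per image (an untrained decoder outputs roughly 0.5 for every pixel) and falls into the low hundreds as training progresses on MNIST.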
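A small NumPy sketch of these bounds, using a hypothetical `bce` helper and three example target values chosen for illustration. It shows that summed BCE stays positive even at its optimum $\hat{x} = x$ when targets are not exactly 0 or 1, that a constant prediction of 0.5 costs exactly $\ln 2$ per pixel, and that a narrow Gaussian density peaks above 1, so its negative log-density is negative.

```python
import numpy as np

def bce(x, x_hat):
    """Summed binary cross-entropy: -Σ [x log x̂ + (1 - x) log(1 - x̂)] (hypothetical helper)."""
    return -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))

x = np.array([0.1, 0.5, 0.9])          # continuous, dequantized-style targets

best = bce(x, x)                        # the minimum over x̂ is reached at x̂ = x ...
print(f"min BCE = {best:.4f}")          # ... and is still positive (the Bernoulli entropy)
init = bce(x, np.full_like(x, 0.5))     # an untrained decoder outputting 0.5 everywhere
print(f"BCE at x̂ = 0.5: {init:.4f} (= 3·ln 2)")

# A continuous density can exceed 1, so a flow's negative log-density can drop below 0
sigma = 0.01
peak_density = 1.0 / (np.sqrt(2 * np.pi) * sigma)   # N(0, σ²) density at its mean
print(f"-log p at the peak of N(0, {sigma}²): {-np.log(peak_density):.3f}")
```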

### 5. Sampling and Generation
- **Training**: $x \rightarrow \text{Encoder} \rightarrow q(z|x) \rightarrow z \rightarrow \text{Decoder} \rightarrow \hat{x}$
- **Generation**: $z \sim p(z) \rightarrow \text{Decoder} \rightarrow x$

The prior $p(z) = \mathcal{N}(0, I)$ enables easy sampling for generation.

## Implementation: Imports and Data Preprocessing

VAEs can model both discrete and continuous data; for consistency with the other models, we apply the same uniform dequantization, adding noise to the integer pixel values so they become continuous:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from keras.datasets.mnist import load_data
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

torch.manual_seed(0)

# Load MNIST and apply uniform DEQUANTIZATION
print("Loading MNIST dataset...")
(trainX, trainY), (testX, testY) = load_data()

# Dequantization: add U[0, 1) noise to the integer pixels {0, ..., 255}, then divide by 256
# (the number of intensity levels) so every value stays within [0, 1]; dividing by 255
# would push bright pixels above 1, which BCE targets must not exceed
trainX = (trainX.astype(np.float32) + torch.rand(trainX.shape).numpy()) / 256.
trainX = torch.tensor(trainX.reshape(-1, 28 * 28))

print(f"Dataset shape: {trainX.shape}")
print(f"Pixel value range: [{trainX.min():.3f}, {trainX.max():.3f}]")

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
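
A small NumPy sketch of the dequantization step, using a handful of example pixel values chosen for illustration. It checks two properties of dividing by 256 rather than 255: the results stay in [0, 1), and each original integer level can be recovered by flooring, so the noise never moves a pixel into a neighboring level's interval.

```python
import numpy as np

rng = np.random.default_rng(0)
pixels = np.array([0, 1, 128, 254, 255])        # example 8-bit intensities

# Uniform dequantization: add u ~ U[0, 1) and divide by 256 (the number of levels)
x = (pixels + rng.random(pixels.shape)) / 256.0
print(x)

# Each interval [k/256, (k+1)/256) belongs to exactly one integer level k,
# so flooring recovers the original pixel values
recovered = np.floor(x * 256).astype(int)
print(recovered)

# Dividing by 255 instead can push the brightest pixels above 1
print((255 + 0.5) / 255 > 1)
```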

## VAE Architecture: Encoder and Decoder

The VAE consists of two main components:
1. **Encoder**: Maps input $x$ to latent distribution parameters $(\mu, \log\sigma^2)$
2. **Decoder**: Maps latent sample $z$ back to reconstruction $\hat{x}$

### Key Design Choices:
- **Latent Dimension**: 20 (much smaller than 784 input dimensions)
- **Architecture**: Fully connected layers with ReLU activations
- **Output**: Sigmoid activation for pixel values in [0,1]
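To make the sizes concrete, the arithmetic below counts trainable parameters for the layer sizes used in the `VAE` class that follows: two 400-unit hidden layers in both encoder and decoder, plus separate μ and log σ² heads. ReLU and Sigmoid have no parameters, so this total should agree with the `sum(p.numel() ...)` count printed after the model is initialized. The helper `linear_params` is a hypothetical name used only here.

```python
def linear_params(n_in, n_out):
    # A fully connected layer has an n_in × n_out weight matrix plus n_out biases
    return n_in * n_out + n_out

input_dim, hidden_dim, latent_dim = 784, 400, 20

encoder = linear_params(input_dim, hidden_dim) + linear_params(hidden_dim, hidden_dim)
heads = 2 * linear_params(hidden_dim, latent_dim)           # fc_mu and fc_logvar
decoder = (linear_params(latent_dim, hidden_dim)
           + linear_params(hidden_dim, hidden_dim)
           + linear_params(hidden_dim, input_dim))

total = encoder + heads + decoder
print(f"encoder {encoder:,} + heads {heads:,} + decoder {decoder:,} = {total:,}")
```

Most of the parameters sit in the two 784 ↔ 400 layers; the 20-dimensional bottleneck itself is cheap.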

In [None]:
class VAE(nn.Module):
    """
    Variational Autoencoder for MNIST
    
    Architecture:
    - Encoder: x(784) → hidden(400) → hidden(400) → μ, logvar (20 each)
    - Decoder: z(20) → hidden(400) → hidden(400) → x̂(784)
    
    Key Innovation: Reparameterization trick allows backpropagation through sampling
    """
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        
        self.latent_dim = latent_dim
        
        # Encoder: x → μ, log(σ²)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Latent space parameters
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # Mean
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # Log variance
        
        # Decoder: z → x̂
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Output in [0,1] for MNIST
        )
        
        print(f"Initialized VAE:")
        print(f"  Input dimension: {input_dim}")
        print(f"  Latent dimension: {latent_dim}")
        print(f"  Hidden dimension: {hidden_dim}")
        
    def encode(self, x):
        """
        Encode input to latent distribution parameters
        
        Args:
            x: Input data [batch_size, input_dim]
            
        Returns:
            mu: Mean of latent distribution [batch_size, latent_dim]
            logvar: Log variance of latent distribution [batch_size, latent_dim]
        """
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """
        THE REPARAMETERIZATION TRICK
        
        Instead of sampling z ~ N(μ, σ²) directly (non-differentiable),
        we reparameterize: z = μ + σ ⊙ ε where ε ~ N(0, I)
        
        This allows gradients to flow through μ and σ
        """
        std = torch.exp(0.5 * logvar)  # σ = exp(½ log σ²)
        eps = torch.randn_like(std)    # ε ~ N(0, I)
        return mu + eps * std          # z = μ + σ ⊙ ε
    
    def decode(self, z):
        """
        Decode latent sample to reconstruction
        
        Args:
            z: Latent sample [batch_size, latent_dim]
            
        Returns:
            x_recon: Reconstructed data [batch_size, input_dim]
        """
        return self.decoder(z)
    
    def forward(self, x):
        """
        Full VAE forward pass
        
        x → Encoder → q(z|x) → reparameterize → z → Decoder → x̂
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

# Initialize model
vae = VAE().to(device)
print(f"\nModel parameters: {sum(p.numel() for p in vae.parameters()):,}")

## The Reparameterization Trick: A Closer Look

Let's visualize how the reparameterization trick works and why it's crucial for training:

In [None]:
# Demonstrate the reparameterization trick
print("=== Reparameterization Trick Demonstration ===")

# Simulate encoder outputs (requires_grad_() stands in for tensors produced by the encoder)
batch_size, latent_dim = 32, 20
mu = (torch.randn(batch_size, latent_dim) * 0.5).requires_grad_()      # Some mean values
logvar = (torch.randn(batch_size, latent_dim) * 0.2).requires_grad_()  # Some log variance values

print(f"Encoder outputs:")
print(f"  μ range: [{mu.min():.3f}, {mu.max():.3f}]")
print(f"  log σ² range: [{logvar.min():.3f}, {logvar.max():.3f}]")

# Apply reparameterization trick
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

print(f"\nReparameterization components:")
print(f"  σ = exp(½ log σ²) range: [{std.min():.3f}, {std.max():.3f}]")
print(f"  ε ~ N(0,I) range: [{eps.min():.3f}, {eps.max():.3f}]")
print(f"  z = μ + σε range: [{z.min():.3f}, {z.max():.3f}]")

# Show that gradients actually flow from z back to μ and log σ²
z.sum().backward()
print(f"\nGradient flow:")
print(f"  z requires_grad: {z.requires_grad}")
print(f"  ∂(Σz)/∂μ = 1 everywhere: {torch.allclose(mu.grad, torch.ones_like(mu))}")
print(f"  ∂(Σz)/∂log σ² = ½σε: {torch.allclose(logvar.grad, 0.5 * std.detach() * eps)}")

# Visualize the sampling process
plt.figure(figsize=(15, 4))

# Plot 1: Prior distribution N(0,1)
plt.subplot(1, 4, 1)
prior_samples = torch.randn(1000)
plt.hist(prior_samples.numpy(), bins=50, alpha=0.7, color='blue')
plt.title("Prior p(z) ~ N(0,1)")
plt.xlabel("z")
plt.ylabel("Frequency")

# Plot 2: Encoder mean μ
plt.subplot(1, 4, 2)
plt.hist(mu.flatten().detach().numpy(), bins=50, alpha=0.7, color='orange')
plt.title("Encoder Mean μ")
plt.xlabel("μ")
plt.ylabel("Frequency")

# Plot 3: Encoder std σ
plt.subplot(1, 4, 3)
plt.hist(std.flatten().detach().numpy(), bins=50, alpha=0.7, color='green')
plt.title("Encoder Std σ")
plt.xlabel("σ")
plt.ylabel("Frequency")

# Plot 4: Reparameterized samples z
plt.subplot(1, 4, 4)
plt.hist(z.flatten().detach().numpy(), bins=50, alpha=0.7, color='red')
plt.title("Samples z = μ + σε")
plt.xlabel("z")
plt.ylabel("Frequency")

plt.tight_layout()
plt.show()

print("\n✓ The reparameterization trick allows us to sample from q(z|x) while maintaining differentiability!")

## VAE Loss Function: Evidence Lower BOund (ELBO)

The VAE loss combines two terms that balance reconstruction quality and regularization:

### Mathematical Derivation

**Goal**: Maximize $\log p(x)$ (intractable)

**Solution**: Maximize ELBO (tractable lower bound)

$$\log p(x) \geq \mathcal{L} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) || p(z))$$
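To see why this is a lower bound, write $p(z|x)$ for the true (intractable) posterior. A standard rearrangement of $\log p(x)$ gives

$$\log p(x) = \mathcal{L} + D_{KL}(q(z|x) || p(z|x))$$

Since the KL divergence is non-negative, $\log p(x) \geq \mathcal{L}$, with equality exactly when $q(z|x) = p(z|x)$. Maximizing the ELBO therefore raises a bound on the log-likelihood while pulling $q(z|x)$ toward the true posterior.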

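**Implementation**: In practice we *minimize* the negative ELBO, $-\mathcal{L}$, which splits into two terms:
1. **Reconstruction Loss**: For a Bernoulli decoder and a single sample $z \sim q(z|x)$, $-\log p(x|z)$ is the Binary Cross Entropy $-\sum_i \left[x_i \log \hat{x}_i + (1-x_i) \log (1-\hat{x}_i)\right]$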
2. **KL Divergence**: For $q(z|x) = \mathcal{N}(\mu, \sigma^2)$ and $p(z) = \mathcal{N}(0, I)$:

$$D_{KL} = \frac{1}{2} \sum_j (\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1)$$
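The closed form can be sanity-checked against a Monte Carlo estimate of $\mathbb{E}_{q}[\log q(z|x) - \log p(z)]$. The NumPy sketch below uses arbitrary example values for $\mu$ and $\log\sigma^2$ (not taken from a trained encoder). It also confirms that the expression used in the code, $-\frac{1}{2}\sum_j(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2)$, is the same formula rearranged.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0, 0.0])          # example encoder means
logvar = np.array([-0.5, 0.3, 0.0])      # example encoder log variances
sigma = np.exp(0.5 * logvar)

# Closed form: ½ Σ (μ² + σ² − log σ² − 1)
kl_closed = 0.5 * np.sum(mu**2 + sigma**2 - logvar - 1)
# The same quantity written as in vae_loss_function
kl_code = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))

def log_normal(z, mean, std):
    # Elementwise log density of N(mean, std²)
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((z - mean) / std) ** 2

# Monte Carlo: sample z ~ q via the reparameterization, then average log q(z) − log p(z)
z = mu + sigma * rng.standard_normal((200_000, mu.size))
kl_mc = np.mean(np.sum(log_normal(z, mu, sigma) - log_normal(z, 0.0, 1.0), axis=1))

print(f"closed form {kl_closed:.4f}, code form {kl_code:.4f}, Monte Carlo {kl_mc:.4f}")
```

The third latent dimension has $\mu = 0$ and $\log\sigma^2 = 0$, so it matches the prior exactly and contributes nothing to the KL.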

In [None]:
def vae_loss_function(recon_x, x, mu, logvar):
    """
    VAE Loss = Reconstruction Loss + KL Divergence
    
    Args:
        recon_x: Reconstructed data [batch_size, input_dim]
        x: Original data [batch_size, input_dim] 
        mu: Latent mean [batch_size, latent_dim]
        logvar: Latent log variance [batch_size, latent_dim]
        
    Returns:
        total_loss: Negative ELBO (BCE + KL), summed over the batch
        bce_loss: Reconstruction loss component  
        kld_loss: KL divergence component
    """
    
    # Reconstruction Loss: Binary Cross Entropy
    # Measures how well we can reconstruct the input
    bce_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL Divergence Loss: KL(q(z|x) || p(z))
    # Regularizes the latent space to be close to prior N(0,I)
    # For Gaussians: KL = ½ Σ(μ² + σ² - log(σ²) - 1)
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    total_loss = bce_loss + kld_loss
    
    return total_loss, bce_loss, kld_loss

# Demonstrate loss computation with example data
print("=== VAE Loss Function Demonstration ===")

# Create example batch
batch_size = 32
example_x = torch.rand(batch_size, 784).to(device)  # Example input

# Forward pass
with torch.no_grad():
    recon_x, mu, logvar = vae(example_x)
    total_loss, bce_loss, kld_loss = vae_loss_function(recon_x, example_x, mu, logvar)

print(f"Example loss computation:")
print(f"  Reconstruction Loss (BCE): {bce_loss.item():.2f}")
print(f"  KL Divergence: {kld_loss.item():.2f}")
print(f"  Total Loss: {total_loss.item():.2f}")
print(f"  Loss per sample: {total_loss.item()/batch_size:.2f}")

# Rough per-image loss values for MNIST (exact numbers depend on architecture and preprocessing)
print(f"\nRough per-image VAE loss values for MNIST:")
print(f"  At initialization: ≈ 784·ln 2 ≈ {784 * np.log(2):.0f} (decoder outputs ≈ 0.5 for every pixel)")
print(f"  Reconstruction Loss: ~100-300 once training is underway")
print(f"  KL Divergence: 0-50 (regularization strength)")
print(f"  ✓ Both terms are ≥ 0, so the VAE loss is bounded below, unlike a flow model's NLL on continuous data")

## Training and Generation Functions

Now let's implement the training loop and image generation functions:

In [None]:
def generate_images_vae(model, epoch, nb_data=10, latent_dim=20, device='cpu'):
    """
    Generate and display VAE sample images
    
    Process: z ~ N(0,I) → Decoder → x̂
    """
    model.eval()
    with torch.no_grad():
        # Sample from prior distribution N(0,I)
        z = torch.randn(nb_data * nb_data, latent_dim).to(device)
        samples = model.decode(z).cpu().numpy()
        
        fig, axs = plt.subplots(nb_data, nb_data, figsize=(10, 10))
        for i in range(nb_data):
            for j in range(nb_data):
                idx = i * nb_data + j
                axs[i, j].imshow(samples[idx].reshape(28, 28), cmap='gray')
                axs[i, j].set_xticks([])
                axs[i, j].set_yticks([])
        plt.suptitle(f'VAE Generated Images - Epoch {epoch}')
        plt.tight_layout()
        plt.show()
    model.train()


def train_vae(model, optimizer, dataloader, nb_epochs=50, device='cpu'):
    """
    Train the VAE model
    
    Training process:
    1. Forward pass: x → q(z|x) → z → p(x|z) → x̂
    2. Compute ELBO loss
    3. Backpropagate gradients
    4. Update parameters
    """
    training_losses = []
    
    model.train()
    for epoch in tqdm(range(nb_epochs), desc="Training VAE"):
        epoch_loss = 0
        epoch_bce = 0
        epoch_kld = 0
        
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            optimizer.zero_grad()
            
            # Forward pass
            recon_batch, mu, logvar = model(batch)
            
            # Compute loss
            loss, bce, kld = vae_loss_function(recon_batch, batch, mu, logvar)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            epoch_bce += bce.item()
            epoch_kld += kld.item()
            
            # Print per-image loss components for monitoring (losses are summed over the batch)
            if batch_idx % 100 == 0:
                n = batch.size(0)
                print(f"Epoch {epoch+1:3d}, Batch {batch_idx:3d}: "
                      f"BCE={bce.item() / n:7.1f}, KLD={kld.item() / n:6.1f}, "
                      f"Total={loss.item() / n:7.1f} (per image)")
        
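        # Store the average per-image loss and report the epoch's components
        n_data = len(dataloader.dataset)
        avg_loss = epoch_loss / n_data
        training_losses.append(avg_loss)
        print(f"Epoch {epoch+1:3d} average: BCE={epoch_bce / n_data:.1f}, "
              f"KLD={epoch_kld / n_data:.1f}, Total={avg_loss:.1f}")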
        
        # Generate images every 10 epochs
        if (epoch + 1) % 10 == 0:
            print(f"\nGenerating VAE images at epoch {epoch + 1}")
            generate_images_vae(model, epoch + 1, device=device)
    
    return training_losses

print("Training and generation functions defined!")

## Training the VAE Model

Let's train our VAE and observe how the loss components evolve:

In [None]:
# Training setup
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
dataloader = DataLoader(trainX, batch_size=128, shuffle=True)

# Test VAE forward pass
print("Testing VAE forward pass...")
with torch.no_grad():
    test_batch = next(iter(dataloader))[:5].to(device)
    recon, mu, logvar = vae(test_batch)
    test_loss, test_bce, test_kld = vae_loss_function(recon, test_batch, mu, logvar)
    
    print(f"Test batch:")
    print(f"  Input shape: {test_batch.shape}")
    print(f"  Reconstruction shape: {recon.shape}")
    print(f"  Latent μ shape: {mu.shape}")
    print(f"  Latent log σ² shape: {logvar.shape}")
    print(f"  Test loss: {test_loss.item():.2f}")

# Train VAE
print("\nStarting VAE training...")
training_losses = train_vae(vae, optimizer, dataloader, nb_epochs=30, device=device)

print("\nVAE training completed!")

## Latent Space Analysis

Let's explore the learned latent space and demonstrate interpolation:

In [None]:
# Analyze the learned latent space
print("=== Latent Space Analysis ===")

vae.eval()
with torch.no_grad():
    # Encode 1,000 training images (only the training split was preprocessed above)
    samples = trainX[:1000].to(device)
    mu_encoded, logvar_encoded = vae.encode(samples)
    
    print(f"Encoded latent statistics:")
    print(f"  μ mean: {mu_encoded.mean():.3f}, std: {mu_encoded.std():.3f}")
    print(f"  log σ² mean: {logvar_encoded.mean():.3f}, std: {logvar_encoded.std():.3f}")
    
    # Sample from prior and decode
    prior_samples = torch.randn(100, vae.latent_dim).to(device)
    generated = vae.decode(prior_samples)
    
    print(f"Generated samples statistics:")
    print(f"  Generated pixel mean: {generated.mean():.3f}, std: {generated.std():.3f}")
    print(f"  Generated range: [{generated.min():.3f}, {generated.max():.3f}]")

# Demonstrate latent space interpolation
print("\n=== Latent Space Interpolation ===")

def interpolate_latent(vae, z1, z2, steps=10):
    """Interpolate between two latent points"""
    alphas = torch.linspace(0, 1, steps)
    interpolations = []
    
    for alpha in alphas:
        z_interp = (1 - alpha) * z1 + alpha * z2
        x_interp = vae.decode(z_interp.unsqueeze(0))
        interpolations.append(x_interp.cpu().numpy())
    
    return interpolations

# Create interpolation between random points
with torch.no_grad():
    z1 = torch.randn(vae.latent_dim).to(device)
    z2 = torch.randn(vae.latent_dim).to(device)
    
    interpolations = interpolate_latent(vae, z1, z2, steps=10)
    
    # Visualize interpolation
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for i, interp in enumerate(interpolations):
        axes[i].imshow(interp.reshape(28, 28), cmap='gray')
        axes[i].set_xticks([])
        axes[i].set_yticks([])
        axes[i].set_title(f'α={i / (len(interpolations) - 1):.1f}')
    
    plt.suptitle('Latent Space Interpolation: z₁ → z₂')
    plt.tight_layout()
    plt.show()

print("✓ Smooth interpolation demonstrates the continuous latent space!")

## Key Takeaways: VAE vs Other Generative Models

### Advantages of VAEs
1. **Principled Probabilistic Framework**: Based on solid statistical foundations
2. **Stable Training**: More stable than GANs, no adversarial dynamics
3. **Meaningful Latent Space**: Enables interpolation and latent arithmetic
4. **Loss Bounded Below**: Both loss terms are non-negative, which makes the loss scale easier to interpret
5. **Both Inference and Generation**: Can encode data to latent space AND generate new data

### Limitations of VAEs
1. **Blurry Outputs**: Factorized pixel-wise likelihoods such as BCE tend to average over plausible images, producing blurry samples and reconstructions
2. **Limited Expressiveness**: Variational approximation may be too restrictive
3. **No Exact Likelihood**: Unlike flows, provides only lower bound (ELBO)
4. **Posterior Collapse**: KL term may dominate, leading to uninformative latents

### Comparison Summary

| Aspect | VAE | Real NVP | GAN |
|--------|-----|----------|-----|
| **Likelihood** | Lower bound (ELBO) | Exact | None |
| **Training Stability** | Stable | Stable | Unstable |
| **Sample Quality** | Blurry | Good | Excellent |
| **Loss Bounds** | Bounded below (≥ 0) | Can be negative (continuous density) | Bounded below |
| **Latent Space** | Meaningful | Meaningful | Not guaranteed |
| **Inference** | Fast (approximate, amortized) | Fast (exact inverse) | None |

### Best Practices for VAEs
- **β-VAE**: Weight KL term to control disentanglement
- **Warm-up**: Gradually increase KL weight during training
- **Architecture**: Use skip connections or more powerful decoders
- **Loss Functions**: Try alternatives to BCE (e.g., MSE, perceptual loss)
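A minimal sketch of the first two practices, assuming a linear warm-up schedule (one common choice among several). The function names `kl_weight` and `weighted_vae_loss` and the default step counts are hypothetical, not part of this notebook. With `beta=1.0` and the warm-up finished, the weighted loss reduces to the BCE + KL loss used above; `beta > 1` puts extra pressure on the KL term, which is the β-VAE setting.

```python
def kl_weight(step, warmup_steps=10_000, beta=1.0):
    """Hypothetical linear KL warm-up: ramps the KL weight from 0 up to beta over warmup_steps."""
    if warmup_steps <= 0:
        return beta
    return beta * min(1.0, step / warmup_steps)

def weighted_vae_loss(bce, kld, step, warmup_steps=10_000, beta=1.0):
    """Reconstruction term plus an annealed, optionally β-scaled, KL term."""
    return bce + kl_weight(step, warmup_steps, beta) * kld

# Example: how the KL weight evolves for β = 4 with a 1,000-step warm-up
for step in [0, 250, 500, 1_000, 5_000]:
    print(step, kl_weight(step, warmup_steps=1_000, beta=4.0))
```

To use it in `train_vae`, one would replace `loss` with `weighted_vae_loss(bce, kld, global_step)`, where `global_step` is a hypothetical counter incremented once per batch.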

---

**References:**
- Kingma, D. P., & Welling, M. (2013). Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114.
- Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. arXiv preprint arXiv:1401.4082.