## Exercise VAEs

You need `torch` and `torchvision` to run the code in this notebook; it also imports `matplotlib` (for plotting) and `numpy`. You can install everything via pip or conda:

```bash
pip install torch torchvision matplotlib numpy
```
or
```bash
conda install pytorch torchvision matplotlib numpy -c pytorch
```

Import the necessary libraries, set the hyperparameters, and download the MNIST dataset.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Configuration: hyperparameters and device shared by the cells below.
BATCH_SIZE = 100  # Number of images per mini-batch in both data loaders
LEARNING_RATE = 1e-3  # Step size handed to the Adam optimizer
EPOCHS = 15  # Number of full passes over the training set
LATENT_DIM = 2  # We use 2 dimensions to easily visualize the latent space later
# Prefer the GPU when one is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")

# Data Loading (MNIST)
# ToTensor converts each image to a float tensor with pixel values scaled to
# [0, 1], the range expected by the sigmoid output / BCE loss used later.
transform = transforms.ToTensor() # Scales pixel values to [0, 1]
# download=True fetches the dataset into ./data when it is not already there.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffle only the training data; the test order stays fixed so that
# evaluation and visualization are reproducible.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Define the VAE model, its loss function, and the training loop; then train the model and visualize the results.

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent
    vector back to per-pixel values in [0, 1].

    Args:
        input_dim: Number of features of a flattened input (784 for MNIST).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super().__init__()

        # Kept so that forward() flattens inputs to the configured width
        # instead of a hard-coded 784.
        self.input_dim = input_dim

        # --- Encoder ---
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # We output two vectors: mean (mu) and log-variance (log_var)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # --- Decoder ---
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of q(z|x) for a flattened batch ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc_mu(h1), self.fc_logvar(h1)

    def reparameterize(self, mu, logvar):
        """Apply the reparameterization trick: z = mu + sigma * epsilon.

        In training mode a latent sample is drawn so that gradients can flow
        through ``mu`` and ``logvar``. In eval mode the mean is returned
        directly, which makes the output deterministic.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)  # convert log-variance to standard deviation
        eps = torch.randn_like(std)    # sample epsilon from a standard normal
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions with values in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))  # Sigmoid because pixels are [0, 1]

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Input batch; it is flattened to ``(-1, input_dim)``.

        Returns:
            Tuple ``(recon_x, mu, logvar)`` with the flat reconstruction and
            the parameters of the approximate posterior.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

# Build the VAE on the selected device and create the Adam optimizer that
# updates all of its parameters.
model = VAE(latent_dim=LATENT_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(model)  # Show the layer structure as a quick sanity check

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss (negative ELBO) summed over the whole batch.

    Args:
        recon_x: Flat reconstructions with values in [0, 1].
        x: Original inputs; reshaped to the feature width of ``recon_x``.
        mu: Mean of the approximate posterior q(z|x).
        logvar: Log-variance of the approximate posterior q(z|x).

    Returns:
        Scalar tensor: reconstruction loss plus KL divergence, both summed
        (not averaged) over all elements of the batch.
    """
    # 1. Reconstruction loss (binary cross entropy).
    # The target is flattened to the reconstruction's feature width rather
    # than a hard-coded 784, so models with another input_dim also work.
    target = x.reshape(-1, recon_x.size(-1))
    # reduction='sum' adds up every pixel of every sample: a total, not a mean.
    recon_loss = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # 2. KL divergence.
    # Analytical solution for KL(q(z|x) || N(0, 1)):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kl_div

In [None]:
def train(epoch):
    """Train the global ``model`` for one pass over ``train_loader``.

    Every 100th batch the per-sample loss of that batch is printed; at the
    end the loss averaged over the whole training set is printed.

    Args:
        epoch: Epoch number, used only in the progress messages.
    """
    model.train()
    train_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        # Labels are not needed: the VAE is trained to reconstruct its input.
        data = batch[0].to(DEVICE)

        reconstruction, mu, logvar = model(data)
        loss = loss_function(reconstruction, data, mu, logvar)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The loss is a sum over the batch, so accumulate it and divide by
        # the dataset size once the epoch is finished.
        train_loss += loss.item()

        if batch_idx % 100 != 0:
            continue
        print(f'Epoch {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)}] Loss: {loss.item()/len(data):.4f}')

    print(f'====> Epoch {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

# Run training: one call to train() per epoch, numbered from 1 so that the
# progress messages read "Epoch 1" .. "Epoch EPOCHS".
for epoch in range(1, EPOCHS + 1):
    train(epoch)

In [None]:
def visualize_reconstruction(n_images=8):
    """Plot test images above their VAE reconstructions.

    Takes the first batch from ``test_loader``, runs it through the global
    ``model`` in eval mode without tracking gradients, and shows the
    originals in the top row with the matching reconstructions below.

    Args:
        n_images: Number of image pairs to show; capped at the batch size.

    Raises:
        ValueError: If ``n_images`` is smaller than 1.
    """
    if n_images < 1:
        raise ValueError("n_images must be at least 1")

    model.eval()
    with torch.no_grad():
        data, _ = next(iter(test_loader))
        data = data.to(DEVICE)
        recon, _, _ = model(data)

    n_images = min(n_images, data.size(0))

    # The model returns flat vectors, so restore the spatial layout of the
    # input before plotting. Assumes single-channel images such as MNIST.
    image_shape = data.shape[-2:]
    originals = data[:n_images].cpu().reshape(-1, *image_shape).numpy()
    reconstructions = recon[:n_images].cpu().reshape(-1, *image_shape).numpy()

    # squeeze=False keeps `axes` two-dimensional even when n_images == 1.
    fig, axes = plt.subplots(2, n_images, figsize=(1.5 * n_images, 3), squeeze=False)
    for col, (original, reconstructed) in enumerate(zip(originals, reconstructions)):
        for ax, image in ((axes[0, col], original), (axes[1, col], reconstructed)):
            # Fixed vmin/vmax so every panel uses the same gray scale.
            ax.imshow(image, cmap='gray', vmin=0.0, vmax=1.0)
            ax.axis('off')

    fig.suptitle('Top: original test images | Bottom: VAE reconstructions')
    plt.tight_layout()
    plt.show()

# Run the trained model on the first batch of the test set.
visualize_reconstruction()