# Exercise 3: Variational Autoencoders
---
![VAE-cover](images/vae_cover.png)

## Introduction

In this exercise, you will build variational autoencoders (VAEs) and apply them to the [MNIST dataset](https://yann.lecun.com/exdb/mnist/) of hand-written digits. While VAEs no longer lead in generation quality compared with newer frameworks such as diffusion models, they remain an important tool in many applications. VAEs are useful not only for data generation but also for tasks like anomaly detection, semi-supervised learning, and feature selection. In reinforcement learning, for instance, VAEs are used to learn compact, informative representations of the environment, which can simplify state-space representations and improve policy learning. They also play a key role in disentangling latent features, which supports interpretable and controllable generative processes.

> MNIST has been used for many different projects, from handwriting recognition to generative AI (such as this one!). Check out the [link](https://yann.lecun.com/exdb/mnist/) to see some of these projects, and maybe even try a couple yourself.

In [None]:
import torch
import torch.nn as nn

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid

In [None]:
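from torch.utils.data import Dataset
from tqdm.notebook import tqdm

import shutil
import os

# Copy the provided helper modules into the writable working directory
# (Kaggle mounts input datasets under the absolute path /kaggle/input)
shutil.copytree('/kaggle/input/ae4353-3/additional', '/kaggle/working/additional', dirs_exist_ok=True)

from additional.plots import plot_reconstructions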

## Setting up the dataset and training parameters:

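Before building the model, we need to define some basic parameters and understand the shape of our data. MNIST images are ***28x28 pixels*** and, since they are greyscale, ***they have only a single channel***. With a fully connected (MLP) encoder, as in this exercise, each image is flattened into a vector of 28 × 28 = 784 values, which is the input dimension `x_dim`.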

We also need to decide the size of the ***latent space***: the dimensionality of the vector into which the encoder compresses the input. This is the ***bottleneck*** of the VAE, the compressed representation of the data. A common choice is ***20 dimensions***, but you can experiment with smaller or larger latent spaces to see how this affects the quality of the generated images. If the latent space is ***too small***, the model cannot capture the variations in the data and reconstructions lose detail. If it is ***too large***, the latent space may become under-regularized, leading to poor generative properties (e.g., sampling from the latent space produces meaningless images).

> Practical rule for MNIST: 10-50 dimensions often work well (which is why we recommend starting with 20 and then experimenting). Try a range of values, even if your first attempt works as expected, to see how the latent dimension affects a VAE.

You will also want to define other important parameters such as ***batch size***, ***learning rate***, ***hidden dimensions*** and ***epochs***.
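As a quick sanity check on these choices, the sketch below derives the flattened input size from the image shape and computes how many batches one pass over the 60,000 MNIST training images takes for a given batch size. The helper `batches_per_epoch` is hypothetical and only for illustration. Note that when the batch size does not divide 60,000 evenly, the last batch is smaller than the rest.

```python
import math

# MNIST image shape: (channels, height, width)
channels, height, width = 1, 28, 28

# A fully connected encoder sees each image as one flat vector
flat_dim = channels * height * width

N_TRAIN = 60_000  # number of MNIST training images


def batches_per_epoch(n_samples, batch_size):
    """Hypothetical helper: number of batches and size of the final batch."""
    n_batches = math.ceil(n_samples / batch_size)
    last_batch = n_samples - (n_batches - 1) * batch_size
    return n_batches, last_batch


print(flat_dim)                          # 784
print(batches_per_epoch(N_TRAIN, 100))   # (600, 100)
print(batches_per_epoch(N_TRAIN, 128))   # (469, 96)
```

A smaller final batch matters for any code that hard-codes the batch size when reshaping tensors; reshaping with the actual batch dimension avoids the problem.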

<strong style="color:red;">TODO 1.1: Set the hyperparameters for your model:</strong>

In [None]:
dataset_path = '~/datasets'
cuda = False
DEVICE = torch.device("cuda" if cuda else "cpu")

# ---------------------------------------------------------------------------
# TODO 1.1: Set hyperparameters
# ---------------------------------------------------------------------------

batch_size = ...

x_dim = ...
hidden_dim = ...
latent_dim = ...

lr = ...

epochs = ...

# ---------------------------------------------------------------------------
# END TODO 1.1
# ---------------------------------------------------------------------------

## Loading the data:

Because MNIST is so widely used, it is available as one of the built-in `torchvision` datasets, so we use it to load our train and test datasets and wrap them in data loaders. We provide this code for you, but you can also browse the documentation of other `torchvision.datasets` if you want to try other projects using online data! Note that `transforms.ToTensor()` converts each image to a float tensor of shape `(1, 28, 28)` with pixel values scaled to [0, 1], which is the range the binary cross-entropy loss used later expects.

In [None]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

mnist_transform = transforms.Compose([transforms.ToTensor(), ])

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

train_dataset = MNIST(dataset_path, train=True, download=True, transform=mnist_transform)
test_dataset = MNIST(dataset_path, train=False, download=True, transform=mnist_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **kwargs)
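The loaders above yield image batches of shape `(B, 1, 28, 28)`, while a fully connected encoder expects `(B, 784)`. The sketch below uses random stand-in images in a `TensorDataset` (an assumption, so it runs without downloading MNIST) to show that flattening with `x.view(x.size(0), -1)` handles every batch, including a smaller final one.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 10 random "images" with values in [0, 1] and dummy labels
images = torch.rand(10, 1, 28, 28)
labels = torch.zeros(10, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=False)

flat_shapes = []
for x, _ in loader:
    # Use the actual batch size (x.size(0)) rather than a fixed constant,
    # so the last, smaller batch is reshaped correctly too
    x_flat = x.view(x.size(0), -1)
    flat_shapes.append(tuple(x_flat.shape))

print(flat_shapes)  # [(4, 784), (4, 784), (2, 784)]
```

A reshape that hard-codes a fixed batch size would raise an error on the final batch of 2 here.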

## Understanding the encoder-decoder structure in a VAE

---

![VAE Diagram](./images/vae_diagram.jpg)

(Image source: [pyimagesearch](https://pyimagesearch.com/2023/10/02/a-deep-dive-into-variational-autoencoders-with-pytorch/), which also gives an in-depth explanation of how a VAE works.)

### Encoder: Mapping inputs to a distribution

The encoder's role is to compress the input image into a representation in the latent space. But unlike a standard autoencoder, the VAE encoder outputs a **distribution**, not a single vector. For each input ***x***, the encoder produces two vectors:
- **Mean vector** ($\mu$): the center of the latent distribution.
- **Log-variance vector** ($\log \sigma^2$): describes how spread out the distribution is.

Formally:

$(\mu, \log \sigma^2) = \mathrm{Encoder}(x)$

This probabilistic encoding allows the latent space to capture both **which features are important** and **how uncertain the model is** about them. Predicting the log-variance rather than the variance itself lets the network output any real number, while $\sigma^2 = \exp(\log \sigma^2)$ is always positive.
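To make the two outputs concrete, the sketch below passes a batch of hypothetical hidden features through two separate linear heads, one for $\mu$ and one for $\log \sigma^2$, and then recovers $\sigma^2$ and $\sigma$ with the exponential. The sizes are illustrative assumptions, and this is only the output stage of an encoder, not a complete implementation of TODO 2.1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_samples, n_hidden, n_latent = 8, 256, 20  # illustrative sizes

# Pretend these are the hidden features produced by the encoder's earlier layers
h = torch.randn(n_samples, n_hidden)

# Two separate heads map the same features to the mean and the log-variance
mean_head = nn.Linear(n_hidden, n_latent)
log_var_head = nn.Linear(n_hidden, n_latent)

mu = mean_head(h)          # any real value
log_var = log_var_head(h)  # any real value, positive or negative

# Exponentiating recovers a strictly positive variance and standard deviation
var = torch.exp(log_var)
std = torch.exp(0.5 * log_var)

print(mu.shape, log_var.shape)  # torch.Size([8, 20]) torch.Size([8, 20])
print(bool((var > 0).all()))    # True
```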

<strong style="color:red;">TODO 2.1: Define the Encoder:</strong>

In [None]:
# ----------------------------------------------------------------------------
# TODO 2.1: Implement the Encoder
# ----------------------------------------------------------------------------

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        
        pass
        

    def forward(self, x):
        pass
        return z_mean, z_log_var

# ----------------------------------------------------------------------------
# END TODO 2.1
# ----------------------------------------------------------------------------


### Decoder: Reconstructing the input

The decoder takes a point from the latent space and attempts to reconstruct the original image.

$\hat x = \mathrm{Decoder}(z)$

Here $\hat x$ is the reconstructed image, and its similarity to the original input $x$ is measured using the **reconstruction loss**. If the encoder has captured the right features, the decoder can recreate images that look very much like the originals.
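Because the reconstruction loss used later is binary cross-entropy, the decoder's outputs should be valid pixel intensities in [0, 1]. The sketch below shows one common way to achieve this, assuming an MLP decoder: a final linear layer followed by `torch.sigmoid`, then a reshape back to image form for plotting. The layer sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_samples, n_latent, n_hidden, n_pixels = 4, 20, 256, 784  # illustrative sizes

# A tiny stand-in decoder: latent vector -> hidden layer -> flattened image
hidden = nn.Linear(n_latent, n_hidden)
output = nn.Linear(n_hidden, n_pixels)

z = torch.randn(n_samples, n_latent)
h = torch.relu(hidden(z))

# Sigmoid squashes every output into the [0, 1] range expected by BCE
x_hat = torch.sigmoid(output(h))

# Reshape the flat vectors back to (B, 1, 28, 28) images for visualisation
images = x_hat.view(n_samples, 1, 28, 28)
print(images.shape)  # torch.Size([4, 1, 28, 28])
```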

<strong style="color:red;">TODO 2.2: Define the Decoder:</strong>


In [None]:
# ----------------------------------------------------------------------------
# TODO 2.2: Implement the Decoder
# ----------------------------------------------------------------------------

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()

        pass

    def forward(self, x):
        pass
        return x_hat

# ----------------------------------------------------------------------------
# END TODO 2.2
# ----------------------------------------------------------------------------

### Putting it together: the reparameterization trick

To connect the encoder and decoder, we need to sample a latent vector $z$ from the distribution produced by the encoder. However, drawing $z$ directly from $\mathcal{N}(\mu, \sigma^2)$ is a random operation that is not differentiable with respect to $\mu$ and $\sigma$, so gradients could not backpropagate through it to the encoder. The solution is the **reparameterization trick**:

$z = \mu + \sigma \odot \epsilon$

$\epsilon \sim \mathcal{N}(0, I)$

where:
- $\mu$ is the mean from the encoder
- $\sigma = \exp\left(\tfrac{1}{2} \log \sigma^2\right)$ is the standard deviation
- $\epsilon$ is random noise drawn from a standard normal distribution
- $\odot$ denotes element-wise multiplication.

This formulation moves the randomness into $\epsilon$, which does not depend on the model's parameters, so gradients can flow through $\mu$ and $\sigma$ and the model remains trainable. By combining **probabilistic encoding**, **differentiable sampling**, and **decoding**, the VAE learns a **smooth and continuous latent space**. This not only enables faithful reconstruction of digits but also makes it possible to generate **entirely new samples** by drawing random vectors from the latent space.
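The following sketch illustrates why the reparameterization trick keeps the model trainable. It draws $z = \mu + \sigma \odot \epsilon$ with $\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$, backpropagates a toy scalar (the sum of $z$, chosen only for illustration), and checks that gradients reach both $\mu$ and $\log\sigma^2$. The function name `reparameterize` is hypothetical.

```python
import torch

torch.manual_seed(0)


def reparameterize(mu, log_var):
    """Hypothetical helper: sample z = mu + sigma * eps with eps ~ N(0, I)."""
    std = torch.exp(0.5 * log_var)  # sigma from the log-variance
    eps = torch.randn_like(std)     # noise with the same shape/device as std
    return mu + std * eps, eps


# Leaf tensors standing in for the encoder outputs
mu = torch.zeros(3, 2, requires_grad=True)
log_var = torch.zeros(3, 2, requires_grad=True)

z, eps = reparameterize(mu, log_var)
z.sum().backward()  # toy scalar objective, only to trigger backpropagation

# d(sum z)/d(mu) = 1 for every element
print(mu.grad)
# d(sum z)/d(log_var) = 0.5 * sigma * eps; with log_var = 0, sigma = 1
print(torch.allclose(log_var.grad, 0.5 * eps))  # True
```

The noise `eps` is treated as a constant input, which is exactly what lets autograd differentiate through the sample.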

<strong style="color:red;">TODO 2.3: Implement the full VAE:</strong>


In [None]:
# ----------------------------------------------------------------------------
# TODO 2.3: Implement the VAE
# ----------------------------------------------------------------------------

class VAE(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(VAE, self).__init__()
        
        # Define the encoder and decoder
        pass

    def reparameterization(self, mean, std):
        pass
        return z

    def forward(self, x):
        pass
        return x_hat, mean, log_var

# ----------------------------------------------------------------------------
# END TODO 2.3
# ----------------------------------------------------------------------------

In [None]:
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)

model = VAE(Encoder=encoder, Decoder=decoder).to(DEVICE)

## Setting up the optimizer and loss function

Now that the model is defined, we need two more ingredients before we can start training: 
1. A **loss function** that tells the model how well it's doing.
2. An **optimizer** that updates the model's parameters based on that loss.

---

### The VAE loss function

The VAE loss has two parts:

1. **Reconstruction loss**:
    - Measures how close the reconstructed image $\hat x$ is to the original input image $x$.
    - Here, we use **binary cross-entropy (BCE)**, summed over all pixels.
    - Essentially, the model is penalized if it cannot recreate the input digits correctly.

    $\mathrm{recon} = \mathrm{BCE}(x, \hat x) = -\sum_{i} \left[x_i \log \hat x_i + (1 - x_i) \log (1 - \hat x_i)\right]$

2. **KL divergence loss**:
    - Regularizes the latent space so that the encoded distribution $\mathcal{N}(\mu, \sigma^2)$ stays close to a standard normal distribution $\mathcal{N}(0, I)$.
    - This keeps the latent space smooth and helps ensure that sampling random points produces meaningful digits.

    $\mathrm{KL} = -\frac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$

Here $i$ indexes pixels and $j$ indexes latent dimensions; in the code, both sums also run over every sample in the batch.

The final loss is the sum of both terms (with an optional scaling factor $\beta$ to control the strength of the KL term):

$\mathcal{L} = \mathrm{recon} + \beta \cdot \mathrm{KL}$

Dividing by the batch size keeps the loss on a per-sample scale, so its magnitude stays comparable across batch sizes.
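The two formulas above can be checked numerically against PyTorch's own implementations. The sketch below uses random stand-in tensors (an assumption for illustration) to compare the closed-form KL term with `torch.distributions.kl_divergence` between $\mathcal{N}(\mu, \sigma^2)$ and $\mathcal{N}(0, I)$, and the explicit BCE sum with `F.binary_cross_entropy(..., reduction="sum")`.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

n_samples, n_latent, n_pixels = 5, 4, 784

# Random stand-ins for encoder outputs, input pixels and decoder outputs
mu = torch.randn(n_samples, n_latent)
log_var = torch.randn(n_samples, n_latent)
x = torch.rand(n_samples, n_pixels)                              # targets in [0, 1]
x_hat = torch.rand(n_samples, n_pixels).clamp(1e-4, 1 - 1e-4)  # predictions in (0, 1)

# Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and batch
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Same quantity via torch.distributions (Normal takes the standard deviation)
sigma = torch.exp(0.5 * log_var)
posterior = Normal(mu, sigma)
prior = Normal(torch.zeros_like(mu), torch.ones_like(sigma))
kl_lib = kl_divergence(posterior, prior).sum()

# BCE written out explicitly and via the built-in function
bce_manual = -torch.sum(x * torch.log(x_hat) + (1 - x) * torch.log(1 - x_hat))
bce_lib = F.binary_cross_entropy(x_hat, x, reduction="sum")

print(torch.allclose(kl_closed, kl_lib, rtol=1e-4))    # True
print(torch.allclose(bce_manual, bce_lib, rtol=1e-4))  # True

# Total loss with beta, normalized to a per-sample value
beta = 1.0
loss = (bce_lib + beta * kl_closed) / n_samples
```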

### The optimizer:

To train the model, we use the **Adam optimizer**. Adam adapts the step size for each parameter individually, which often makes training faster and less sensitive to the choice of learning rate than plain stochastic gradient descent. To create it, we pass in the model's parameters (`model.parameters()`) and the learning rate (`lr`).
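A minimal sketch of how the optimizer is created and used, assuming a tiny stand-in `nn.Linear` model and a toy squared-output loss rather than the VAE and its loss: each update clears old gradients, backpropagates, and calls `step()`, which changes the parameters.

```python
import torch
import torch.nn as nn
from torch.optim import Adam

torch.manual_seed(0)

toy_model = nn.Linear(4, 2)                            # stand-in for the VAE
toy_optimizer = Adam(toy_model.parameters(), lr=1e-3)  # parameters + learning rate

before = toy_model.weight.detach().clone()

x = torch.randn(8, 4)
loss = toy_model(x).pow(2).mean()  # toy loss, only for illustration

toy_optimizer.zero_grad()  # clear gradients from any previous step
loss.backward()            # compute gradients of the loss w.r.t. the parameters
toy_optimizer.step()       # Adam update using per-parameter adaptive step sizes

print(torch.equal(before, toy_model.weight))  # False: the weights were updated
```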

<strong style="color:red;">TODO 3.1-3.2: Define the loss function and the optimizer:</strong>


In [None]:
from torch.optim import Adam
import torch.nn.functional as F

# ----------------------------------------------------------------------------
# TODO 3.1: Implement the loss function
# ----------------------------------------------------------------------------

# reconstruction + KL divergence losses summed over all elements and batch
def loss_function(x, x_hat, mean, log_var, beta=1.0):
    # reconstruction loss (BCE summed over pixels)
    pass

    # KL divergence term
    pass

    # normalize by batch size for stability
    return (recon + beta * kl) / x.size(0)

# ----------------------------------------------------------------------------
# END TODO 3.1
# ----------------------------------------------------------------------------

# ----------------------------------------------------------------------------
# TODO 3.2: Set up the optimizer
# ----------------------------------------------------------------------------

optimizer = ...

# ----------------------------------------------------------------------------
# END TODO 3.2
# ----------------------------------------------------------------------------


## The training loop:

With the model, loss, and optimizer defined, we can now put everything together in a training loop. Training a VAE looks similar to training other neural networks, but we track three key values:  

1. **Reconstruction loss** – measures how well the decoder reproduces the input images.  
2. **KL divergence** – regularizes the latent space so it stays close to a standard normal distribution.  
3. **Total loss** – the sum of the two, which the optimizer minimizes.  

At each epoch:  
- We loop through the training batches and pass the images through the encoder and decoder.  
- We compute the reconstruction and KL terms separately to better monitor how the model is learning.  
- The gradients are reset (`optimizer.zero_grad()`), the loss is backpropagated (`loss.backward()`), and the optimizer updates the model’s parameters (`optimizer.step()`).  

Using `tqdm`, we also show a progress bar with the current loss values for each batch, making it easier to see improvements during training. At the end of each epoch, we print the average losses so we can track the overall learning progress.  

By monitoring both reconstruction and KL divergence, we ensure the model is **balancing accurate reconstructions with a well-structured latent space**, which is the essence of training a VAE.  
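The sketch below runs the per-batch steps described above for one epoch, assuming a deliberately tiny stand-in VAE (`TinyVAE`, hypothetical) and random images in place of MNIST so that it runs quickly anywhere. It is not the solution to TODO 4.1, but it follows the same pattern: flatten with the actual batch size, compute the reconstruction and KL terms separately, and average each over the number of batches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)


class TinyVAE(nn.Module):
    """Hypothetical minimal VAE: one linear layer in the encoder and decoder."""

    def __init__(self, n_pixels=784, n_latent=2):
        super().__init__()
        self.enc = nn.Linear(n_pixels, 2 * n_latent)  # outputs [mu, log_var]
        self.dec = nn.Linear(n_latent, n_pixels)

    def forward(self, x):
        mu, log_var = self.enc(x).chunk(2, dim=1)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        return torch.sigmoid(self.dec(z)), mu, log_var


# Random stand-in images in [0, 1]; 10 samples with batch size 4 -> batches of 4, 4, 2
data = TensorDataset(torch.rand(10, 1, 28, 28), torch.zeros(10))
demo_loader = DataLoader(data, batch_size=4, shuffle=True)

demo_model = TinyVAE()
demo_optimizer = Adam(demo_model.parameters(), lr=1e-3)


def train_one_epoch(model, loader, optimizer, beta=1.0):
    model.train()
    totals = {"loss": 0.0, "recon": 0.0, "kl": 0.0}
    for x, _ in loader:
        x = x.view(x.size(0), -1)  # flatten using the actual batch size
        x_hat, mu, log_var = model(x)

        # Per-sample reconstruction and KL terms, tracked separately
        recon = F.binary_cross_entropy(x_hat, x, reduction="sum") / x.size(0)
        kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)
        loss = recon + beta * kl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        totals["loss"] += loss.item()
        totals["recon"] += recon.item()
        totals["kl"] += kl.item()
    # Average each quantity over the number of batches in the epoch
    return {k: v / len(loader) for k, v in totals.items()}


averages = train_one_epoch(demo_model, demo_loader, demo_optimizer)
print(averages)
```

Keeping the three averages separate makes it easy to spot, for example, a KL term that collapses towards zero while the reconstruction loss stalls.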

<strong style="color:red;">TODO 4.1: Write and run the training loop:</strong>


In [None]:
from tqdm import tqdm

print("Start training VAE...")

# ----------------------------------------------------------------------------
# TODO 4.1: Implement the training loop
# ----------------------------------------------------------------------------

model.train()

for epoch in range(epochs):
    pass

# ----------------------------------------------------------------------------
# END TODO 4.1
# ----------------------------------------------------------------------------

print("Finish!!")


## Result visualization:

We provide a simple overview of your results: the top row shows the input images and the bottom row shows the network's reconstructions. Feel free to create other meaningful visualisations; doing so is good practice for checking how well your models perform!
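Beyond reconstructions, a natural extra visualisation is to generate new digits by decoding random latent vectors drawn from the prior $\mathcal{N}(0, I)$. The sketch below assumes you pass in a decoder module that maps `(n, latent_dim)` to `(n, 784)` pixel probabilities, as in this exercise. The helper name `sample_digits` is hypothetical, and a stand-in linear decoder is used only so the snippet runs on its own.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # no gradients are needed for generation
def sample_digits(decoder, latent_dim, n=16, device="cpu"):
    """Hypothetical helper: decode n prior samples z ~ N(0, I) into images."""
    z = torch.randn(n, latent_dim, device=device)
    x_hat = decoder(z)               # (n, 784) pixel probabilities
    return x_hat.view(n, 1, 28, 28)  # image layout for plotting or saving


# Stand-in decoder for illustration only: linear layer followed by a sigmoid
stand_in = nn.Sequential(nn.Linear(10, 784), nn.Sigmoid())
samples = sample_digits(stand_in, latent_dim=10, n=16)
print(samples.shape)  # torch.Size([16, 1, 28, 28])
```

With the `VAE` class from the solutions, which stores its decoder in the `Decoder` attribute, this could be called as `sample_digits(model.Decoder, latent_dim, device=DEVICE)`. The resulting batch can be written to disk as a grid with `torchvision.utils.save_image(samples, "samples.png", nrow=4)`.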

In [None]:
from additional.plots import plot_reconstructions

plot_reconstructions(model, test_loader, DEVICE)

## Solutions:

---

In [None]:
dataset_path = '~/datasets'
cuda = False
DEVICE = torch.device("cuda" if cuda else "cpu")

# ---------------------------------------------------------------------------
# TODO 1.1: Set hyperparameters
# ---------------------------------------------------------------------------

batch_size = 100

x_dim = 784
hidden_dim = 256
latent_dim = 10

lr = 1e-3

epochs = 50

# ---------------------------------------------------------------------------
# END TODO 1.1
# ---------------------------------------------------------------------------

In [None]:
# MODEL DEFINITION:

"""
simple Gaussian MLP Encoder and Decoder
"""
# ----------------------------------------------------------------------------
# TODO 2.1: Implement the Encoder
# ----------------------------------------------------------------------------

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

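        # Two hidden layers, then separate heads for the mean and the log-variance
        self.fc_input = nn.Linear(input_dim, hidden_dim)
        self.fc_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc2_log_var = nn.Linear(hidden_dim, latent_dim)

        # nn.ReLU takes no slope parameter; swap in nn.LeakyReLU(0.2) for a leaky variant
        self.ReLU = nn.ReLU()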

    def forward(self, x):
        h = self.ReLU(self.fc_input(x))
        h = self.ReLU(self.fc_input2(h))
        z_mean = self.fc2_mean(h)
        z_log_var = self.fc2_log_var(h)
        return z_mean, z_log_var

# ----------------------------------------------------------------------------
# END TODO 2.1
# ----------------------------------------------------------------------------


In [None]:
# ----------------------------------------------------------------------------
# TODO 2.2: Implement the Decoder
# ----------------------------------------------------------------------------

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()

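        # Mirror of the encoder: latent vector -> two hidden layers -> flattened image
        self.fc_hidden = nn.Linear(latent_dim, hidden_dim)
        self.fc_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_output = nn.Linear(hidden_dim, output_dim)

        # nn.ReLU takes no slope parameter; swap in nn.LeakyReLU(0.2) for a leaky variant
        self.ReLU = nn.ReLU()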

    def forward(self, x):
        h = self.ReLU(self.fc_hidden(x))
        h = self.ReLU(self.fc_hidden2(h))

        x_hat = torch.sigmoid(self.fc_output(h))
        return x_hat

# ----------------------------------------------------------------------------
# END TODO 2.2
# ----------------------------------------------------------------------------

In [None]:
# ----------------------------------------------------------------------------
# TODO 2.3: Implement the VAE
# ----------------------------------------------------------------------------

class VAE(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(VAE, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, std):
        # epsilon ~ N(0, I), created with the same shape, dtype and device as std
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        return z

    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)

        return x_hat, mean, log_var

# ----------------------------------------------------------------------------
# END TODO 2.3
# ----------------------------------------------------------------------------

In [None]:
from torch.optim import Adam
import torch.nn.functional as F

# ----------------------------------------------------------------------------
# TODO 3.1: Implement the loss function
# ----------------------------------------------------------------------------

# reconstruction + KL divergence losses summed over all elements and batch
def loss_function(x, x_hat, mean, log_var, beta=1.0):
    # reconstruction loss (BCE summed over pixels)
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")

    # KL divergence term
    kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    # normalize by batch size for stability
    return (recon + beta * kl) / x.size(0)

# ----------------------------------------------------------------------------
# END TODO 3.1
# ----------------------------------------------------------------------------

# ----------------------------------------------------------------------------
# TODO 3.2: Set up the optimizer
# ----------------------------------------------------------------------------

optimizer = Adam(model.parameters(), lr=lr)

# ----------------------------------------------------------------------------
# END TODO 3.2
# ----------------------------------------------------------------------------


In [None]:
from tqdm import tqdm

print("Start training VAE...")

# ----------------------------------------------------------------------------
# TODO 4.1: Implement the training loop
# ----------------------------------------------------------------------------

model.train()

for epoch in range(epochs):
    overall_loss = 0
    overall_recon = 0
    overall_kl = 0
    
    loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")
    
    for batch_idx, (x, _) in loop:
        # Flatten using the actual batch size so a smaller final batch also works
        x = x.view(x.size(0), x_dim).to(DEVICE)
        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)

        # Compute separate components
        recon = F.binary_cross_entropy(x_hat, x, reduction="sum") / x.size(0)
        kl = (-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())) / x.size(0)
        loss = recon + kl  # optionally multiply kl by beta if needed: recon + beta * kl

        loss.backward()
        optimizer.step()

        overall_loss += loss.item()
        overall_recon += recon.item()
        overall_kl += kl.item()
        
        # Update tqdm with current batch components
        loop.set_postfix(loss=loss.item(), recon=recon.item(), kl=kl.item())
    
    avg_loss = overall_loss / len(train_loader)
    avg_recon = overall_recon / len(train_loader)
    avg_kl = overall_kl / len(train_loader)
    
    print(f"\tEpoch {epoch + 1} complete! "
          f"Average Loss: {avg_loss:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}")

# ----------------------------------------------------------------------------
# END TODO 4.1
# ----------------------------------------------------------------------------

print("Finish!!")
