# Generative AI Exercise: Variational Autoencoder (VAE) on MNIST

In this exercise, you will explore the implementation and training of a Variational Autoencoder (VAE) on the MNIST dataset. The goal is to understand the fundamental principles of generative modeling and latent space representation using a VAE. 

### Objectives:
1. **Build a VAE**:
   - Define the encoder and decoder modules.
   - Train the VAE on the MNIST dataset with a configurable beta hyperparameter that weights the KL-divergence term of the loss.

2. **Latent Space Exploration**:
   - Visualize and interpolate between digits in the latent space.
   - Generate smooth transitions between samples using latent space interpolation.

3. **Conditional VAE**:
   - Extend the VAE to a Conditional VAE (CVAE) by incorporating class labels.
   - Generate digits by conditioning on specific labels.

4. **Image Generation**:
   - Generate new MNIST-like samples by sampling from the latent space of the VAE and CVAE.

By the end of this exercise, you will gain hands-on experience in implementing generative models and applying them to meaningful tasks like image synthesis and conditional generation.


In [None]:
# Import the necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import random

In [None]:
# Download and preprocess the MNIST dataset
def load_mnist(batch_size=128):
    """Create DataLoaders over the MNIST training and test sets.

    Each image is converted to a tensor and normalized with mean 0.5 and
    std 0.5 (i.e. rescaled to [-1, 1]). The dataset is stored under
    ``./data`` and downloaded if it is missing.

    Args:
        batch_size (int): Batch size used by both loaders.

    Returns:
        tuple: ``(train_loader, test_loader)``. Only the training loader
        shuffles its samples.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # rescale to [-1, 1]
    ])

    def _mnist_split(is_train):
        # Same root, transform and download policy for both splits.
        return datasets.MNIST(
            root="./data",
            train=is_train,
            transform=preprocessing,
            download=True,
        )

    # MNIST ships with predefined train and test sets; nothing is split here.
    train_set = _mnist_split(True)
    test_set = _mnist_split(False)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )

# Build the loaders once with the default batch size (128); the first call
# downloads MNIST into ./data because load_mnist passes download=True.
train_loader, test_loader = load_mnist()


In [None]:
# Define the VAE model with encoder and decoder modules
class VAE(nn.Module):
    """Variational autoencoder skeleton for 1x28x28 images.

    The encoder trunk, the two latent-parameter heads and the decoder are
    intentionally left as ``...`` placeholders: they are the part of the
    exercise to be implemented. The methods below only define how those
    pieces are wired together.

    Args:
        latent_dim (int): Dimensionality of the latent space.
    """

    def __init__(self, latent_dim=2):
        super().__init__()

        # Stored on the instance so that sampling code can size draws from
        # the prior, e.g. ``torch.randn(n, model.latent_dim)``.
        self.latent_dim = latent_dim

        # Encoder: Input -> Latent Parameters
        self.encoder = ...
        self.mu_layer = ...  # Mean
        self.log_var_layer = ...  # Log variance

        # Decoder: Latent -> Output
        self.decoder = ... # remember the output is normalized between [-1,1]

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        hidden = self.encoder(x)
        mu = self.mu_layer(hidden)
        log_var = self.log_var_layer(hidden)
        return mu, log_var

    def decode(self, z):
        """Decode latent codes ``z`` and reshape the result to (-1, 1, 28, 28)."""
        return self.decoder(z).view(-1, 1, 28, 28)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect
        to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            tuple: ``(recon, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return recon, mu, log_var


In [None]:
# Define the loss function
def vae_loss(recon, x, mu, log_var, beta=1):
    """Beta-VAE objective: summed reconstruction BCE plus weighted KL term.

    Both ``recon`` and ``x`` are expected in [-1, 1] (the data is normalized
    to that range and the decoder is meant to match it). Binary cross-entropy
    is only defined for values in [0, 1], so both tensors are linearly mapped
    to [0, 1] before it is evaluated.

    Args:
        recon (torch.Tensor): Reconstructed images, values in [-1, 1].
        x (torch.Tensor): Target images, same shape as ``recon``, in [-1, 1].
        mu (torch.Tensor): Mean of the approximate posterior.
        log_var (torch.Tensor): Log-variance of the approximate posterior.
        beta (float): Weight of the KL-divergence term.

    Returns:
        torch.Tensor: Scalar loss summed over the batch.
    """
    # Map [-1, 1] -> [0, 1]; the clamp only guards against rounding drift.
    recon_unit = ((recon + 1) / 2).clamp(0, 1)
    target_unit = ((x + 1) / 2).clamp(0, 1)
    recon_loss = nn.functional.binary_cross_entropy(
        recon_unit, target_unit, reduction="sum"
    )
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and dims.
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + beta * kl_divergence

# Train the VAE
def train_vae(model, train_loader, num_epochs=10, beta=1, lr=1e-3, device="cuda"):
    """Train a VAE with Adam and print the per-sample loss after each epoch.

    Args:
        model (nn.Module): Model whose forward pass returns
            ``(recon, mu, log_var)``.
        train_loader (DataLoader): Yields ``(images, labels)`` batches; the
            labels are ignored.
        num_epochs (int): Number of passes over ``train_loader``.
        beta (float): Weight of the KL term passed to ``vae_loss``.
        lr (float): Adam learning rate.
        device (str): Device the model and batches are moved to.

    Returns:
        nn.Module: The same ``model`` instance, trained in place.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            recon, mu, log_var = model(x)
            loss = vae_loss(recon, x, mu, log_var, beta=beta)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # vae_loss is summed over the batch, so dividing by the dataset size
        # gives the average loss per sample.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss / len(train_loader.dataset):.4f}")

    # The return must live inside the function body; at module level it is a
    # SyntaxError.
    return model

# Build a VAE with a 2-D latent space, train it for 10 epochs, and save its
# weights so that later cells can reload them from "vae_mnist.pth".
vae = VAE(latent_dim=2)
# train_vae optimizes the parameters of `vae` itself, so saving `vae` below
# stores the trained weights. No device is passed, so the default ("cuda") is used.
model = train_vae(vae, train_loader, num_epochs=10, beta=1) #Explore the effect of beta!
torch.save(vae.state_dict(), "vae_mnist.pth")



In [None]:
# Function for latent space interpolation
def interpolate_latent(model, x1, x2, alpha, device="cuda"):
    """Decode a convex combination of the latent means of two inputs.

    Args:
        model: Model exposing ``encode(x) -> (mu, log_var)`` and ``decode(z)``.
        x1 (torch.Tensor): First input batch.
        x2 (torch.Tensor): Second input batch.
        alpha (float): Weight of ``x1``'s latent mean; ``x2`` gets ``1 - alpha``.
        device (str): Device used for the computation.

    Returns:
        torch.Tensor: Output of ``model.decode`` at the interpolated code.
    """
    model.to(device)
    model.eval()
    first = x1.to(device)
    second = x2.to(device)

    with torch.no_grad():
        # Only the posterior means are used; the log-variances are ignored.
        mu_first = model.encode(first)[0]
        mu_second = model.encode(second)[0]
        blended = alpha * mu_first + (1 - alpha) * mu_second
        return model.decode(blended)

# Rebuild the model and restore the weights saved to "vae_mnist.pth".
vae = VAE(latent_dim=2)
vae.load_state_dict(torch.load("vae_mnist.pth"))

# Sample two random data points (x1, x2) from the training dataset
# (each dataset item is an (image, label) pair; the labels are discarded).
x1, _ = random.choice(train_loader.dataset)
x2, _ = random.choice(train_loader.dataset)

# Add batch dimension and move to the appropriate device
# NOTE(review): "cuda" is hard-coded in this cell, so it presumably needs a
# GPU — confirm. interpolate_latent also moves its inputs to `device`.
x1 = x1.unsqueeze(0).to("cuda")  # Assuming the model is on CUDA
x2 = x2.unsqueeze(0).to("cuda")

# Sample alpha from a uniform distribution on [0, 1)
alpha = torch.rand(1).item()  # Random scalar in [0, 1)

# Interpolate in the latent space (alpha weights x1, 1 - alpha weights x2)
interpolated_image = interpolate_latent(vae, x1, x2, alpha, device="cuda")

# Visualize the interpolated image
plt.imshow(interpolated_image.squeeze().cpu().numpy(), cmap="gray")
plt.title(f"Interpolated Image (alpha={alpha:.2f})")
plt.axis("off")
plt.show()

In [None]:
def plot_latent_grid(model, latent_bounds, grid_size, device="cuda"):
    """
    Decodes a regular 2D grid of latent points into images.

    Despite its name, this function does not draw anything; it returns the
    decoded images so the caller can plot them.

    Args:
        model (VAE): The trained VAE model (two-dimensional latent space).
        latent_bounds (list): ``[min, max]``; the same bounds are applied to
            both latent axes.
        grid_size (int): Number of grid points along each axis.
        device (str): Device for computation (e.g., "cuda" or "cpu").

    Returns:
        numpy.ndarray: Decoded images for all ``grid_size ** 2`` grid points,
        in ``torch.cartesian_prod`` order.
    """
    model.to(device)
    model.eval()

    # Both axes share the same evenly spaced coordinates.
    lower = latent_bounds[0]
    upper = latent_bounds[1]
    axis_a = torch.linspace(lower, upper, grid_size)
    axis_b = torch.linspace(lower, upper, grid_size)
    latent_points = torch.cartesian_prod(axis_a, axis_b)
    latent_points = latent_points.to(device)

    with torch.no_grad():
        decoded = model.decode(latent_points)
        return decoded.cpu().numpy()

# Reload the trained weights and decode a 10x10 grid covering [-2, 2] on
# both latent axes.
vae = VAE(latent_dim=2)
vae.load_state_dict(torch.load("vae_mnist.pth"))    
grid_size=10
decoded_images = plot_latent_grid(vae, latent_bounds=[-2, 2], grid_size=grid_size)

# Reshape and plot the grid
# Images are drawn in the order they were decoded. The subplots are filled
# row by row, so the first latent coordinate changes from row to row and the
# second one from column to column.
fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(decoded_images[i].squeeze(), cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Define the Conditional VAE model
class ConditionalVAE(VAE):
    """VAE variant whose encoder and decoder are conditioned on a class label.

    Labels are mapped to 10-dimensional embeddings that are concatenated to
    the flattened image (encoder side) and to the latent code (decoder side).
    As in the base class, the network modules are ``...`` placeholders to be
    implemented as part of the exercise.

    Args:
        latent_dim (int): Dimensionality of the latent space.
        num_classes (int): Number of distinct labels.
    """

    def __init__(self, latent_dim=2, num_classes=10):
        super().__init__(latent_dim)
        self.label_embedding = nn.Embedding(num_classes, 10)
        self.encoder = ...
        self.mu_layer = ...  # Mean
        self.log_var_layer = ...  # Log variance

        self.decoder = ...

    def encode(self, x, labels):
        """Return ``(mu, log_var)`` for images ``x`` given their ``labels``."""
        flat_x = x.view(x.size(0), -1)
        conditioned = torch.cat([flat_x, self.label_embedding(labels)], dim=1)
        features = self.encoder(conditioned)
        return self.mu_layer(features), self.log_var_layer(features)

    def decode(self, z, labels):
        """Decode ``z`` given ``labels``; output is reshaped to (-1, 1, 28, 28)."""
        conditioned = torch.cat([z, self.label_embedding(labels)], dim=1)
        decoded = self.decoder(conditioned)
        return decoded.view(-1, 1, 28, 28)

    def forward(self, x, labels):
        """Encode, sample a latent code and decode, all conditioned on ``labels``.

        Returns:
            tuple: ``(recon, mu, log_var)``.
        """
        mu, log_var = self.encode(x, labels)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent, labels), mu, log_var


In [None]:
#Train the conditional model
...

In [None]:
# Function to generate samples from Conditional VAE
def generate_samples(model, label, num_samples, device="cuda"):
    """Draw latent codes from a standard normal and decode them for one label.

    Args:
        model: Conditional model exposing ``latent_dim`` and
            ``decode(z, labels)``.
        label (int): Class label every sample is conditioned on.
        num_samples (int): Number of samples to generate.
        device (str): Device used for the computation.

    Returns:
        list: ``num_samples`` NumPy arrays, each a squeezed decoded sample.
    """
    model.to(device)
    model.eval()

    with torch.no_grad():
        noise = torch.randn(num_samples, model.latent_dim).to(device)
        # The same label is repeated for every sample.
        label_batch = torch.full((num_samples,), label, dtype=torch.long).to(device)
        generated = model.decode(noise, label_batch)

    images = []
    for idx in range(num_samples):
        images.append(generated[idx].cpu().numpy().squeeze())
    return images

# Example: Generate samples of digit "3"
# NOTE: cond_vae is newly constructed here and no saved weights are loaded in
# this cell, so the samples reflect whatever weights the model holds at this
# point. Train it (or load a state dict) first for meaningful digits.
cond_vae = ConditionalVAE(latent_dim=2)
samples = generate_samples(cond_vae, label=3, num_samples=5)

# Visualize the generated samples side by side in a single row
for i, img in enumerate(samples):
    plt.subplot(1, len(samples), i + 1)
    plt.imshow(img, cmap="gray")
    plt.axis("off")
plt.show()


### Evaluating the Quality of Trained Generative Models

To assess the quality of a trained generative model, we use metrics that compare the distribution of generated samples with the real data distribution. In this exercise, we use the **Maximum Mean Discrepancy (MMD)**, a statistical measure that quantifies the difference between two distributions. Below, we describe two evaluation methods:

---

#### **1. Latent Space Interpolation Evaluation**

Using the `interpolate_latent` function, we generate new digits by:
1. Sampling two random data points ($x_1$ and $x_2$) from the training dataset.
2. Interpolating between these points in the latent space using a randomly sampled $\alpha \in [0, 1]$.
3. Decoding the interpolated latent representations to create new digit images.

We then calculate the MMD between these generated samples and the test dataset to evaluate how well the interpolated samples align with the real data distribution.

---

#### **2. Label-Conditioned Sample Evaluation**

Using the `generate_samples` function, we generate new samples conditioned on a specific label. For each label:
1. Generate a set of synthetic samples using the trained model.
2. Compare the generated samples with the test dataset samples that share the same label using MMD.

This approach evaluates the model's ability to generate realistic and diverse samples for specific classes, ensuring the model captures the conditional data distribution effectively.

---

#### **Why MMD?**

The Maximum Mean Discrepancy (MMD) measures the distance between two distributions by comparing their representations in a reproducing kernel Hilbert space (RKHS). It is defined as:

$$
\mathrm{MMD}^2(p, q) = \mathbb{E}_{x, x' \sim p} [k(x, x')] + \mathbb{E}_{y, y' \sim q} [k(y, y')] - 2 \mathbb{E}_{x \sim p, y \sim q} [k(x, y)],
$$

where $k$ is a kernel function (e.g., a Gaussian kernel). MMD is particularly well-suited for generative model evaluation as it provides a numerical measure of how closely the generated and real data distributions align.

---

In [None]:
def get_samples_by_label(loader, label):
    """
    Collects every image in a loader whose label equals ``label``.

    Args:
        loader (DataLoader): Iterable yielding ``(images, labels)`` batches.
        label (int): The label to select (0-9).

    Returns:
        torch.Tensor: Matching images concatenated along the batch dimension,
        in loader order, or an empty tensor of shape ``(0,)`` if nothing matches.
    """
    matches = []

    for batch_images, batch_labels in loader:
        keep = batch_labels == label
        # Skip batches without any occurrence of the requested label.
        if keep.any():
            matches.append(batch_images[keep])

    if not matches:
        return torch.empty(0)
    return torch.cat(matches, dim=0)

# Example usage:
# Assuming `test_loader` is the test DataLoader
label = 3
samples = get_samples_by_label(test_loader, label)
print(f"Number of samples with label {label}: {samples.shape[0]}")

# Visualize a few examples (at most five, laid out in a single row)
if samples.shape[0] > 0:
    for i in range(min(5, samples.shape[0])):
        plt.subplot(1, 5, i + 1)
        # squeeze() drops size-1 dimensions so imshow receives a 2-D array.
        plt.imshow(samples[i].squeeze().numpy(), cmap="gray")
        plt.axis("off")
    plt.show()


The MMD implementation below is taken from [this tutorial](https://jejjohnson.github.io/research_journal/appendix/similarity/mmd/) on the MMD.

In [None]:
def MMD(x, y, kernel):
    """Empirical (biased) squared maximum mean discrepancy between two samples.

    The lower the result, the more evidence that both samples come from the
    same distribution. The two samples may contain different numbers of
    points; inputs with more than two dimensions (e.g. image batches) are
    flattened to ``(num_samples, num_features)`` first.

    Args:
        x (torch.Tensor): First sample, distribution P, first dim = samples.
        y (torch.Tensor): Second sample, distribution Q, first dim = samples.
            Must have the same number of features as ``x`` and live on the
            same device.
        kernel (str): ``"multiscale"`` (sum of rational-quadratic kernels) or
            ``"rbf"`` (sum of Gaussian kernels).

    Returns:
        torch.Tensor: Scalar ``mean k(x, x') + mean k(y, y') - 2 mean k(x, y)``.

    Raises:
        ValueError: If ``kernel`` is not one of the supported names.
    """
    x = x.reshape(x.size(0), -1)
    y = y.reshape(y.size(0), -1)

    xx, yy, xy = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    sq_x = xx.diag()  # squared norms of the rows of x
    sq_y = yy.diag()  # squared norms of the rows of y

    # Pairwise squared Euclidean distances: |a|^2 + |b|^2 - 2 a.b
    dxx = sq_x.unsqueeze(1) + sq_x.unsqueeze(0) - 2. * xx
    dyy = sq_y.unsqueeze(1) + sq_y.unsqueeze(0) - 2. * yy
    dxy = sq_x.unsqueeze(1) + sq_y.unsqueeze(0) - 2. * xy

    # Each accumulator takes the shape/device/dtype of its own distance
    # matrix, so samples of different sizes are handled correctly.
    XX, YY, XY = torch.zeros_like(dxx), torch.zeros_like(dyy), torch.zeros_like(dxy)

    if kernel == "multiscale":
        bandwidth_range = [0.2, 0.5, 0.9, 1.3]
        for a in bandwidth_range:
            XX += a**2 * (a**2 + dxx)**-1
            YY += a**2 * (a**2 + dyy)**-1
            XY += a**2 * (a**2 + dxy)**-1
    elif kernel == "rbf":
        bandwidth_range = [10, 15, 20, 50]
        for a in bandwidth_range:
            XX += torch.exp(-0.5*dxx/a)
            YY += torch.exp(-0.5*dyy/a)
            XY += torch.exp(-0.5*dxy/a)
    else:
        # An unknown name would otherwise yield a misleading MMD of 0.
        raise ValueError(
            f"Unknown kernel {kernel!r}; expected 'multiscale' or 'rbf'."
        )

    # Averaging each term separately equals mean(XX + YY - 2 * XY) when both
    # samples have the same size and stays valid when they do not.
    return XX.mean() + YY.mean() - 2. * XY.mean()