# Variational Autoencoder (VAE) on CIFAR-10

This notebook trains a standard VAE (β=1) and a β-VAE (β=5) on CIFAR-10, compares their training curves, draws samples from the N(0, I) prior, and visualizes latent-space interpolations.

## Setup

In [None]:
import math
import os
import random
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# Reproducibility: fix the seeds of the PyTorch, NumPy and Python RNGs.
# (The three generators are independent, so the order of seeding does not matter.)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


## Data: CIFAR-10

In [None]:
# Images are only converted to float tensors in [0, 1] (no normalization), which
# matches the decoder's Sigmoid output range used by the reconstruction loss.
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
# NOTE: this test split is also what the training loop reports as "val".
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

batch_size = 128
# Pinned (page-locked) host memory only speeds up host->GPU copies. On a CPU-only
# machine PyTorch warns and ignores the flag, so enable it only for CUDA.
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin_memory)


## Model: Encoder, Decoder, Reparameterization

In [None]:
@dataclass
class VAEConfig:
    """Architecture hyper-parameters shared by the encoder and the decoder."""

    # Length of the latent vector z.
    latent_dim: int = 128
    # Base channel width; the conv stacks use 1x, 2x and 4x this value.
    hidden_channels: int = 64

class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of q(z | x).

    Three stride-2 convolutions shrink a 32x32 image to a 4x4 feature map with
    ``4 * hidden_channels`` channels. Two linear heads then map the flattened
    features to the posterior mean and log-variance, each of size ``latent_dim``.
    """

    def __init__(self, config: VAEConfig):
        super().__init__()
        ch = config.hidden_channels
        layers = []
        # Each Conv2d(k=4, s=2, p=1) halves the spatial size: 32 -> 16 -> 8 -> 4.
        for c_in, c_out in ((3, ch), (ch, ch * 2), (ch * 2, ch * 4)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)
        flat_dim = ch * 4 * 4 * 4  # channels (4 * ch) x height (4) x width (4)
        self.fc_mu = nn.Linear(flat_dim, config.latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, config.latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for a batch of images ``x``.

        Assumes ``x`` is (batch, 3, 32, 32); the heads' input size is fixed to
        the 4x4 map that resolution produces.
        """
        features = self.conv(x)
        features = features.view(features.size(0), -1)
        return self.fc_mu(features), self.fc_logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes back to images.

    A linear layer expands ``z`` into a ``(4 * hidden_channels, 4, 4)`` feature
    map. Three stride-2 transposed convolutions upsample it to ``(3, 32, 32)``,
    and a final Sigmoid keeps pixel values in (0, 1).
    """

    def __init__(self, config: VAEConfig):
        super().__init__()
        ch = config.hidden_channels
        self.fc = nn.Linear(config.latent_dim, ch * 4 * 4 * 4)
        layers = []
        # Each ConvTranspose2d(k=4, s=2, p=1) doubles the spatial size: 4 -> 8 -> 16 -> 32.
        for c_in, c_out in ((ch * 4, ch * 2), (ch * 2, ch)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]
        # Last stage outputs RGB and squashes it into (0, 1) instead of applying ReLU.
        layers += [
            nn.ConvTranspose2d(ch, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent codes ``z`` (expected shape: (batch, latent_dim)) into images."""
        h = self.fc(z)
        # Reshape the flat projection into 4x4 feature maps; the channel count is inferred.
        return self.deconv(h.view(h.size(0), -1, 4, 4))

class VAE(nn.Module):
    """Gaussian-latent VAE: encoder -> reparameterized sample -> decoder."""

    def __init__(self, config: VAEConfig):
        super().__init__()
        # The encoder is built before the decoder; parameter initialisation draws
        # from the global RNG, so this order matters for seeded reproducibility.
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def reparameterize(self, mu, logvar):
        """Sample ``z ~ N(mu, exp(logvar))`` as ``mu + eps * sigma`` with ``eps ~ N(0, I)``.

        Expressing the sample this way keeps it differentiable w.r.t. ``mu`` and ``logvar``.
        """
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images ``x``.

        A fresh latent sample is drawn on every call, in both train and eval mode.
        """
        mu, logvar = self.encoder(x)
        reconstruction = self.decoder(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar


## Loss Function (β-VAE)

In [None]:
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Compute the β-weighted negative ELBO, averaged over the batch.

    The reconstruction term is the summed squared error per image. Up to an
    additive constant, this is the negative log-likelihood of a Gaussian decoder
    with fixed variance 0.5. The KL term is the closed-form
    KL(N(mu, exp(logvar)) || N(0, I)).

    Returns:
        ``(total, reconstruction, kl)``, each a scalar tensor divided by the batch size.
    """
    batch = x.size(0)
    reconstruction = F.mse_loss(recon_x, x, reduction="sum") / batch
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch
    total = reconstruction + beta * kl
    return total, reconstruction, kl


## Training Utilities

In [None]:
@dataclass
class TrainConfig:
    """Optimisation settings for a single VAE training run.

    NOTE(review): nothing in the visible notebook instantiates this class;
    ``train_vae`` takes the same values as separate keyword arguments.
    """

    # Number of passes over the training set.
    epochs: int = 20
    # Learning rate (train_vae's own ``lr`` argument is what reaches Adam).
    lr: float = 2e-4
    # Weight on the KL term of the loss; 1.0 gives the standard VAE objective.
    beta: float = 1.0


def train_epoch(model, loader, optimizer, beta):
    """Run one optimisation pass over ``loader`` and return per-sample mean losses.

    Args:
        model: VAE whose forward returns ``(recon, mu, logvar)``.
        loader: iterable of ``(images, labels)`` batches; the labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        beta: weight of the KL term passed to ``vae_loss``.

    Returns:
        ``(total, recon, kl)`` averaged over all samples. ``vae_loss`` already
        divides by the batch size, so each batch is re-weighted by its size here.
        A plain mean over batches would over-weight a smaller final batch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_recon = 0.0
    total_kl = 0.0
    seen = 0
    for x, _ in loader:
        x = x.to(device)
        bs = x.size(0)
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        loss, recon_loss, kl_div = vae_loss(recon, x, mu, logvar, beta=beta)
        loss.backward()
        optimizer.step()
        # Undo vae_loss's per-batch averaging so the final division is per sample.
        total_loss += loss.item() * bs
        total_recon += recon_loss.item() * bs
        total_kl += kl_div.item() * bs
        seen += bs
    if seen == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / seen, total_recon / seen, total_kl / seen


def eval_epoch(model, loader, beta):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: VAE whose forward returns ``(recon, mu, logvar)``.
        loader: iterable of ``(images, labels)`` batches; the labels are ignored.
        beta: weight of the KL term passed to ``vae_loss``.

    Returns:
        ``(total, recon, kl)`` averaged over all samples, with each batch weighted
        by its size. ``model(x)`` still draws a random latent sample, so this is
        a single-sample stochastic estimate.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_recon = 0.0
    total_kl = 0.0
    seen = 0
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            bs = x.size(0)
            recon, mu, logvar = model(x)
            loss, recon_loss, kl_div = vae_loss(recon, x, mu, logvar, beta=beta)
            # Undo vae_loss's per-batch averaging so the final division is per sample.
            total_loss += loss.item() * bs
            total_recon += recon_loss.item() * bs
            total_kl += kl_div.item() * bs
            seen += bs
    if seen == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / seen, total_recon / seen, total_kl / seen


def train_vae(beta=1.0, epochs=20, lr=2e-4, latent_dim=128):
    """Train a freshly initialised VAE with Adam and log metrics after every epoch.

    Returns:
        ``(model, history)``. ``history["train"]`` and ``history["val"]`` hold one
        ``(total, recon, kl)`` tuple per epoch. The "val" metrics are computed on
        ``test_loader``.
    """
    model = VAE(VAEConfig(latent_dim=latent_dim)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {"train": [], "val": []}
    for epoch in range(1, epochs + 1):
        train_stats = train_epoch(model, train_loader, optimizer, beta)
        val_stats = eval_epoch(model, test_loader, beta)
        history["train"].append(train_stats)
        history["val"].append(val_stats)

        tr_loss, tr_recon, tr_kl = train_stats
        va_loss, va_recon, va_kl = val_stats
        print(
            f"Epoch {epoch:02d} | "
            f"train loss {tr_loss:.4f} (recon {tr_recon:.4f}, kl {tr_kl:.4f}) | "
            f"val loss {va_loss:.4f} (recon {va_recon:.4f}, kl {va_kl:.4f})"
        )

    return model, history


## Train Baseline VAE (β=1)

In [None]:
# Standard VAE objective (β = 1): reconstruction + KL with equal weight.
baseline_model, baseline_history = train_vae(latent_dim=128, lr=2e-4, epochs=20, beta=1.0)


## Train β-VAE (β=5)

In [None]:
# β-VAE: same architecture and schedule, but the KL term is weighted 5x.
beta_model, beta_history = train_vae(latent_dim=128, lr=2e-4, epochs=20, beta=5.0)


## Plot Training Curves

In [None]:
def plot_history(history, title):
    """Plot per-epoch total, reconstruction and KL losses for train and validation.

    Args:
        history: dict with "train" and "val" lists of ``(total, recon, kl)`` tuples.
        title: figure title.
    """
    train = np.array(history["train"])
    val = np.array(history["val"])
    epochs = np.arange(1, len(train) + 1)

    # (column index into the metric tuples, train label, val label), in legend order.
    curves = (
        (0, "Train total", "Val total"),
        (1, "Train recon", "Val recon"),
        (2, "Train KL", "Val KL"),
    )

    plt.figure(figsize=(8, 5))
    for col, train_label, val_label in curves:
        plt.plot(epochs, train[:, col], label=train_label)
        plt.plot(epochs, val[:, col], label=val_label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

# One figure per model: total, reconstruction and KL curves for train and val.
plot_history(history=baseline_history, title="Baseline VAE (β=1)")
plot_history(history=beta_history, title="β-VAE (β=5)")


## Sampling: 16 Generated Images

In [None]:
@torch.no_grad()
def sample_images(model, n=16, latent_dim=128, title="Samples from VAE", nrow=4):
    """Decode ``n`` draws from the prior N(0, I) and display them as an image grid.

    Switches ``model`` to eval mode and leaves it there.

    Args:
        model: VAE whose ``decoder`` maps (n, latent_dim) codes to images.
        n: number of samples to draw.
        latent_dim: size of each latent code; must match the model's latent size.
        title: figure title, so that samples from different models can be told apart.
        nrow: number of images per grid row.
    """
    model.eval()
    z = torch.randn(n, latent_dim, device=device)
    samples = model.decoder(z).cpu()
    grid = utils.make_grid(samples, nrow=nrow)
    plt.figure(figsize=(5, 5))
    # make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C) arrays.
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.title(title)
    plt.show()

# Draw 16 prior samples from each model so that β=1 and β=5 can be compared,
# as required by the analysis section at the end of the notebook.
print("Baseline VAE (β=1):")
sample_images(baseline_model, n=16, latent_dim=128)
print("β-VAE (β=5):")
sample_images(beta_model, n=16, latent_dim=128)


## Latent Space Interpolation

In [None]:
@torch.no_grad()
def interpolate(model, steps=10, latent_dim=128, title="Latent Interpolation", z1=None, z2=None):
    """Decode a straight-line path between two latent codes and display it as one row.

    Switches ``model`` to eval mode and leaves it there.

    Args:
        model: VAE whose ``decoder`` maps (steps, latent_dim) codes to images.
        steps: number of evenly spaced points on the path, endpoints included.
        latent_dim: size of each latent code; must match the model's latent size.
        title: figure title.
        z1, z2: optional endpoint codes of shape (1, latent_dim). Each is drawn from
            N(0, I) when omitted. Passing the same endpoints to two models makes
            their interpolations directly comparable.
    """
    model.eval()
    # Draw z1 before z2 (when needed) so the default RNG consumption is unchanged.
    z1 = torch.randn(1, latent_dim, device=device) if z1 is None else z1.to(device)
    z2 = torch.randn(1, latent_dim, device=device) if z2 is None else z2.to(device)
    # A (steps, 1) column of mixing weights lets one broadcast build every code.
    alphas = torch.linspace(0, 1, steps, device=device).unsqueeze(1)
    zs = (1 - alphas) * z1 + alphas * z2
    samples = model.decoder(zs).cpu()
    grid = utils.make_grid(samples, nrow=steps)
    plt.figure(figsize=(steps * 1.2, 2))
    # make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C) arrays.
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.title(title)
    plt.show()

# Interpolate with both models so the analysis can compare how smoothly each
# latent space varies (β=1 vs β=5).
print("Baseline VAE (β=1):")
interpolate(baseline_model, steps=10, latent_dim=128)
print("β-VAE (β=5):")
interpolate(beta_model, steps=10, latent_dim=128)


## Brief Analysis (β=1 vs β=5)

Use the code cell below to summarize your observations (≤ 300 words). For example:
- β=1 emphasizes reconstruction quality.
- β=5 encourages (but does not guarantee) a more factorized, disentangled latent space, typically at the cost of blurrier reconstructions and samples.

In [None]:
# Placeholder for the written β=1 vs β=5 comparison; replace it with your own observations.
analysis_text = "".join((
    "Write your summary here (<= 300 words). Discuss how increasing β affects sample sharpness, ",
    "class consistency, and latent disentanglement (e.g., shape vs. color vs. orientation).",
))
print(analysis_text)
