## Variational Autoencoder (VAE) on MNIST (PyTorch)

Train a simple VAE to reconstruct and sample MNIST digits. Configure hyperparameters below, then run cells in order.


In [16]:
# Setup and imports
import os
import math
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils

In [17]:
# Setup and imports
import os
import math
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils

# Configuration
@dataclass
class Config:
    data_dir: str = os.path.expanduser("~/.data/mnist")
    batch_size: int = 128
    epochs: int = 5
    lr: float = 1e-3
    latent_dim: int = 20
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2
    seed: int = 42
    output_dir: str = "./vae_outputs"

cfg = Config()
os.makedirs(cfg.output_dir, exist_ok=True)

torch.manual_seed(cfg.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.seed)

print(f"Using device: {cfg.device}")
print(cfg)


Using device: cuda
Config(data_dir='/home/ubuntu/.data/mnist', batch_size=128, epochs=5, lr=0.001, latent_dim=20, device='cuda', num_workers=2, seed=42, output_dir='./vae_outputs')


In [18]:
# Data: MNIST loaders
transform = transforms.Compose([
    transforms.ToTensor(),  # converts to [0,1]
])

train_ds = datasets.MNIST(root=cfg.data_dir, train=True, transform=transform, download=True)
val_ds = datasets.MNIST(root=cfg.data_dir, train=False, transform=transform, download=True)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)

len_train = len(train_loader)
len_val = len(val_loader)
print(f"Train batches: {len_train}, Val batches: {len_val}")


Train batches: 469, Val batches: 79


In [19]:
# VAE model
class Encoder(nn.Module):
    def __init__(self, latent_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 400),
            nn.ReLU(True),
        )
        self.mu = nn.Linear(400, latent_dim)
        self.logvar = nn.Linear(400, latent_dim)

    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.logvar(h)

class Decoder(nn.Module):
    def __init__(self, latent_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 400),
            nn.ReLU(True),
            nn.Linear(400, 28 * 28),
            nn.Sigmoid(),  # output in [0,1]
        )

    def forward(self, z):
        x_hat = self.net(z)
        return x_hat.view(-1, 1, 28, 28)

class VAE(nn.Module):
    def __init__(self, latent_dim: int):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decoder(z)
        return x_hat, mu, logvar

model = VAE(cfg.latent_dim).to(cfg.device)
print(model)


VAE(
  (encoder): Encoder(
    (net): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=784, out_features=400, bias=True)
      (2): ReLU(inplace=True)
    )
    (mu): Linear(in_features=400, out_features=20, bias=True)
    (logvar): Linear(in_features=400, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (net): Sequential(
      (0): Linear(in_features=20, out_features=400, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=400, out_features=784, bias=True)
      (3): Sigmoid()
    )
  )
)


In [20]:
# Loss and utilities

def vae_loss(x, x_hat, mu, logvar):
    # Reconstruction loss: binary cross entropy over pixels
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # KL divergence between q(z|x) and p(z) ~ N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld, recon, kld

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total_loss = 0.0
    total_recon = 0.0
    total_kld = 0.0
    total_elems = 0
    for x, _ in loader:
        x = x.to(device)
        x_hat, mu, logvar = model(x)
        loss, recon, kld = vae_loss(x, x_hat, mu, logvar)
        total_loss += loss.item()
        total_recon += recon.item()
        total_kld += kld.item()
        total_elems += x.size(0)
    return {
        'loss_per_img': total_loss / total_elems,
        'recon_per_img': total_recon / total_elems,
        'kld_per_img': total_kld / total_elems,
    }

def save_image_grid(tensor, path, nrow=8):
    vutils.save_image(tensor, path, nrow=nrow, padding=2, normalize=True)
    print(f"Saved: {path}")


In [21]:
# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

fixed_noise = torch.randn(64, cfg.latent_dim, device=cfg.device)

best_val = math.inf
for epoch in range(1, cfg.epochs + 1):
    model.train()
    running_loss = 0.0
    running_recon = 0.0
    running_kld = 0.0
    count = 0

    for x, _ in train_loader:
        x = x.to(cfg.device)
        x_hat, mu, logvar = model(x)
        loss, recon, kld = vae_loss(x, x_hat, mu, logvar)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_recon += recon.item()
        running_kld += kld.item()
        count += x.size(0)

    train_loss = running_loss / count
    train_recon = running_recon / count
    train_kld = running_kld / count

    val_metrics = evaluate(model, val_loader, cfg.device)

    print(f"Epoch {epoch:03d} | train loss {train_loss:.2f} (recon {train_recon:.2f}, kld {train_kld:.2f}) | "
          f"val loss {val_metrics['loss_per_img']:.2f}")

    # Reconstructions
    model.eval()
    with torch.no_grad():
        x_sample, _ = next(iter(val_loader))
        x_sample = x_sample[:64].to(cfg.device)
        x_hat, _, _ = model(x_sample)
        grid = torch.cat([x_sample, x_hat], dim=0)
        save_image_grid(grid, os.path.join(cfg.output_dir, f"recon_epoch_{epoch:03d}.png"), nrow=8)

        # Random samples from prior
        samples = model.decoder(fixed_noise)
        save_image_grid(samples, os.path.join(cfg.output_dir, f"samples_epoch_{epoch:03d}.png"), nrow=8)

    # Track best
    if val_metrics['loss_per_img'] < best_val:
        best_val = val_metrics['loss_per_img']
        torch.save(model.state_dict(), os.path.join(cfg.output_dir, "vae_best.pt"))
        print("Saved best model.")


Epoch 001 | train loss 163.83 (recon 148.29, kld 15.54) | val loss 126.62
Saved: ./vae_outputs/recon_epoch_001.png
Saved: ./vae_outputs/samples_epoch_001.png
Saved best model.
Epoch 002 | train loss 120.86 (recon 98.54, kld 22.33) | val loss 115.44
Saved: ./vae_outputs/recon_epoch_002.png
Saved: ./vae_outputs/samples_epoch_002.png
Saved best model.
Epoch 003 | train loss 114.12 (recon 90.16, kld 23.96) | val loss 111.60
Saved: ./vae_outputs/recon_epoch_003.png
Saved: ./vae_outputs/samples_epoch_003.png
Saved best model.
Epoch 004 | train loss 111.33 (recon 86.86, kld 24.48) | val loss 109.45
Saved: ./vae_outputs/recon_epoch_004.png
Saved: ./vae_outputs/samples_epoch_004.png
Saved best model.
Epoch 005 | train loss 109.67 (recon 84.90, kld 24.77) | val loss 108.33
Saved: ./vae_outputs/recon_epoch_005.png
Saved: ./vae_outputs/samples_epoch_005.png
Saved best model.


In [22]:
# Convergence settings (KL warmup, early stopping, scheduler)
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Increase epochs if needed
cfg.epochs = max(cfg.epochs, 20)
cfg.patience = 5
cfg.min_delta = 1e-3
cfg.kl_warmup_epochs = 10
cfg.beta_start = 0.0
cfg.beta_end = 1.0
cfg.min_lr = 1e-5

print("Updated config for convergence:", cfg)


Updated config for convergence: Config(data_dir='/home/ubuntu/.data/mnist', batch_size=128, epochs=20, lr=0.001, latent_dim=20, device='cuda', num_workers=2, seed=42, output_dir='./vae_outputs')


In [23]:
# Override loss with KL warmup and helpers

def compute_beta(epoch: int, step: int, steps_per_epoch: int) -> float:
    total_warmup_steps = max(1, cfg.kl_warmup_epochs * steps_per_epoch)
    current_step = (epoch - 1) * steps_per_epoch + step
    t = min(1.0, current_step / total_warmup_steps)
    return float(cfg.beta_start + t * (cfg.beta_end - cfg.beta_start))


def vae_loss(x, x_hat, mu, logvar, beta: float = 1.0):
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kld, recon, kld


@torch.no_grad()
def evaluate(model, loader, device, beta: float = 1.0):
    model.eval()
    total_loss = 0.0
    total_recon = 0.0
    total_kld = 0.0
    total_elems = 0
    for x, _ in loader:
        x = x.to(device)
        x_hat, mu, logvar = model(x)
        loss, recon, kld = vae_loss(x, x_hat, mu, logvar, beta=beta)
        total_loss += loss.item()
        total_recon += recon.item()
        total_kld += kld.item()
        total_elems += x.size(0)
    return {
        'loss_per_img': total_loss / total_elems,
        'recon_per_img': total_recon / total_elems,
        'kld_per_img': total_kld / total_elems,
    }


In [None]:
# Training loop with KL warmup, early stopping, LR scheduling, and robust checkpointing
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    threshold=cfg.min_delta,
    min_lr=cfg.min_lr,
    # verbose argument removed to fix TypeError
)

fixed_noise = torch.randn(64, cfg.latent_dim, device=cfg.device)

best_val = math.inf
epochs_no_improve = 0


def save_checkpoint(path: str):
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'cfg': cfg.__dict__}, path)
    print(f"Saved: {path}")


try:
    for epoch in range(1, cfg.epochs + 1):
        model.train()
        running_loss = 0.0
        running_recon = 0.0
        running_kld = 0.0
        count = 0

        for step, (x, _) in enumerate(train_loader, start=1):
            x = x.to(cfg.device)
            x_hat, mu, logvar = model(x)
            beta = compute_beta(epoch, step, len_train)
            loss, recon, kld = vae_loss(x, x_hat, mu, logvar, beta=beta)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += loss.item()
            running_recon += recon.item()
            running_kld += kld.item()
            count += x.size(0)

        train_loss = running_loss / count
        train_recon = running_recon / count
        train_kld = running_kld / count

        val_metrics = evaluate(model, val_loader, cfg.device, beta=cfg.beta_end)
        scheduler.step(val_metrics['loss_per_img'])

        print(
            f"Epoch {epoch:03d} | beta ~ {compute_beta(epoch, 1, len_train):.2f}->{compute_beta(epoch, len_train, len_train):.2f} | "
            f"train {train_loss:.2f} (recon {train_recon:.2f}, kld {train_kld:.2f}) | "
            f"val {val_metrics['loss_per_img']:.2f} | lr {optimizer.param_groups[0]['lr']:.2e}"
        )

        # Reconstructions and samples
        model.eval()
        with torch.no_grad():
            x_sample, _ = next(iter(val_loader))
            x_sample = x_sample[:64].to(cfg.device)
            x_hat, _, _ = model(x_sample)
            grid = torch.cat([x_sample, x_hat], dim=0)
            save_image_grid(grid, os.path.join(cfg.output_dir, f"recon_epoch_{epoch:03d}.png"), nrow=8)

            samples = model.decoder(fixed_noise)
            save_image_grid(samples, os.path.join(cfg.output_dir, f"samples_epoch_{epoch:03d}.png"), nrow=8)

        # Save last checkpoint each epoch
        save_checkpoint(os.path.join(cfg.output_dir, "vae_last.pt"))

        # Early stopping and best model
        if val_metrics['loss_per_img'] + cfg.min_delta < best_val:
            best_val = val_metrics['loss_per_img']
            epochs_no_improve = 0
            torch.save(model.state_dict(), os.path.join(cfg.output_dir, "vae_best.pt"))
            print("Saved best model.")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= cfg.patience:
                print(f"Early stopping triggered after no improvement for {cfg.patience} epochs.")
                break

except RuntimeError as e:
    if "interrupted by user" in str(e).lower():
        print("Training interrupted by Jupyter kernel interrupt.")
    else:
        raise
finally:
    # Ensure last weights are saved even if interrupted
    torch.save(model.state_dict(), os.path.join(cfg.output_dir, "vae_last_weights_only.pt"))
    save_checkpoint(os.path.join(cfg.output_dir, "vae_last.pt"))


Epoch 001 | beta ~ 0.00->0.10 | train 71.95 (recon 69.11, kld 60.45) | val 120.71 | lr 1.00e-03
Saved: ./vae_outputs/recon_epoch_001.png
Saved: ./vae_outputs/samples_epoch_001.png
Saved: ./vae_outputs/vae_last.pt
Saved best model.
Epoch 002 | beta ~ 0.10->0.20 | train 76.61 (recon 69.48, kld 47.92) | val 114.45 | lr 1.00e-03
Saved: ./vae_outputs/recon_epoch_002.png
Saved: ./vae_outputs/samples_epoch_002.png
Saved: ./vae_outputs/vae_last.pt
Saved best model.
Epoch 003 | beta ~ 0.20->0.30 | train 80.78 (recon 70.23, kld 42.31) | val 110.91 | lr 1.00e-03
Saved: ./vae_outputs/recon_epoch_003.png
Saved: ./vae_outputs/samples_epoch_003.png
Saved: ./vae_outputs/vae_last.pt
Saved best model.
Epoch 004 | beta ~ 0.30->0.40 | train 84.60 (recon 71.12, kld 38.58) | val 108.12 | lr 1.00e-03
Saved: ./vae_outputs/recon_epoch_004.png
Saved: ./vae_outputs/samples_epoch_004.png
Saved: ./vae_outputs/vae_last.pt
Saved best model.
Epoch 005 | beta ~ 0.40->0.50 | train 88.13 (recon 72.07, kld 35.74) | val 1

In [None]:
print("Done")

In [None]:
# Visualization
import matplotlib.pyplot as plt

@torch.no_grad()
def show_grid(img_path, title=None):
    img = plt.imread(img_path)
    plt.figure(figsize=(6,6))
    plt.imshow(img)
    plt.axis('off')
    if title:
        plt.title(title)
    plt.show()

# Display last saved grids if they exist
recon_path = os.path.join(cfg.output_dir, f"recon_epoch_{cfg.epochs:03d}.png")
samples_path = os.path.join(cfg.output_dir, f"samples_epoch_{cfg.epochs:03d}.png")
if os.path.exists(recon_path):
    show_grid(recon_path, title="Reconstructions (top: input, bottom: reconstruction)")
if os.path.exists(samples_path):
    show_grid(samples_path, title="Random samples from N(0, I)")


In [25]:
import torch

if torch.cuda.is_available():
    print(f"CUDA is available. Number of devices: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA is not available.")


CUDA is available. Number of devices: 1
Current device: 0
Device name: NVIDIA A100-SXM4-40GB
