In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import torch.nn.functional as F
import umap.umap_ as umap
from torch.utils.data import random_split

Load the MNIST dataset and preprocess it.

In [None]:
# Transform: convert images to tensors (no additional normalization is applied here)
transform = transforms.Compose([
    transforms.ToTensor(),
])

os.makedirs("./data", exist_ok=True)

# Download training and test data (each call fetches the files only if they are missing,
# so this cell does not depend on the training set having been downloaded first)
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Hold out 10% of the training images for validation
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size

# Seed the split so the train/validation partition is the same on every run
train_subset, val_subset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_subset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Let's take a look at our training, validation and test datasets.

In [None]:
import numpy as np
from torchvision.utils import make_grid

# Report how many images each split contains
print(
    f"Training set size: {len(train_subset)} images",
    f"Validation set size: {len(val_subset)} images",
    f"Test set size:     {len(test_dataset)} images",
    sep="\n",
)

# Inspect one (image, label) pair taken from the training split
img, label = train_subset[0]
print(f"\nSample image shape: {img.shape}", f"Sample label: {label}", sep="\n")

# Display a grid of samples
def show_batch(dataloader, n=32):
    """Draw up to `n` images from the next batch of `dataloader` in an 8-column grid."""
    batch, _ = next(iter(dataloader))
    tiles = make_grid(batch[:n], nrow=8, padding=2)
    # imshow wants the channel axis last, so reorder the grid's axes to (1, 2, 0)
    canvas = tiles.numpy().transpose(1, 2, 0)

    plt.figure(figsize=(8, 8))
    plt.imshow(canvas, cmap="gray")
    plt.axis("off")
    plt.title("Batch of MNIST digits")
    plt.show()

show_batch(train_loader, n=32)

Here we define a new model, based on a Variational Autoencoder architecture.

In [None]:
class Conv_VAE_2D(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over a `latent_dim`-dimensional code; the decoder maps a code
    back to an image with values in (0, 1) (final Sigmoid).

    Args:
        latent_dim: Size of the latent code (default 2, convenient for plotting).
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # -> (16, 14, 14)
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> (32, 7, 7)
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(32*7*7, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.fc2 = nn.Linear(latent_dim, 256)
        self.fc_dec = nn.Linear(256, 32*7*7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # -> (16, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # -> (1, 28, 28)
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return `(mu, logvar)` of the approximate posterior for a batch of images."""
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        # Non-linearity between fc1 and the two heads: without it, fc1 followed by
        # fc_mu / fc_logvar is just a single affine map and the 256-unit layer adds nothing.
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample `z = mu + eps * std` with `eps ~ N(0, I)` so gradients flow to mu/logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (batch, latent_dim) to images of shape (batch, 1, 28, 28)."""
        # ReLU after each fully connected layer; otherwise fc2, fc_dec and the first
        # transposed convolution compose into one affine map.
        x = F.relu(self.fc2(z))
        x = F.relu(self.fc_dec(x))
        x = x.view(-1, 32, 7, 7)
        x = self.decoder(x)
        return x

    def forward(self, x):
        """Encode, sample a latent code, and decode. Returns `(recon, mu, logvar)`."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

    def vae_loss(self, recon_x, x, mu, logvar, beta=1):
        """Compute the beta-weighted VAE objective, averaged over the batch.

        Args:
            recon_x: Reconstructed images.
            x: Target images, same shape as `recon_x`.
            mu: Posterior means.
            logvar: Posterior log-variances.
            beta: Weight applied to the KL term.

        Returns:
            Tuple `(total, recon_loss, kl)` where `total = recon_loss + beta * kl`;
            each term is summed over elements and divided by the batch size.
        """
        batch_size = x.size(0)
        # Reconstruction (sum and normalize by batch)
        recon_loss = F.mse_loss(recon_x, x, reduction='sum') / batch_size
        # KL divergence to the standard normal prior (analytical), normalized per batch
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size

        return recon_loss + beta * kl, recon_loss, kl

conv_vae_2d = Conv_VAE_2D()

This function is similar to the one we defined for the standard Autoencoder, with some modifications to account for the variational structure.

In [None]:
def cosine_similarity_batch(x, y):
    """Compute mean cosine similarity between two image batches.

    Each sample is flattened to a vector, the cosine similarity of matching
    samples in `x` and `y` is computed, and the batch mean is returned.

    Args:
        x: Tensor whose first dimension indexes samples.
        y: Tensor with the same number of samples and per-sample elements as `x`.

    Returns:
        The mean similarity as a Python float.
    """
    # flatten (unlike view) also accepts non-contiguous inputs such as
    # transposed or permuted batches.
    x = x.flatten(start_dim=1)
    y = y.flatten(start_dim=1)
    sim = F.cosine_similarity(x, y, dim=1)
    return sim.mean().item()

def _run_vae_epoch(model, loader, beta, device, optimizer=None):
    """Run one pass over `loader`: a training pass if `optimizer` is given, else evaluation.

    Returns:
        `(mean_loss, mean_cosine)` averaged per sample over the whole pass.

    Raises:
        ValueError: If `loader` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    cos_sum = 0.0
    n_samples = 0

    with torch.set_grad_enabled(training):
        for imgs, _ in loader:
            imgs = imgs.to(device)
            recon, mu, logvar = model(imgs)
            loss, _, _ = model.vae_loss(recon, imgs, mu, logvar, beta)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # loss and cosine are per-batch means; weight them by the batch size so
            # a short final batch does not count as much as a full one.
            batch_size = imgs.size(0)
            loss_sum += loss.item() * batch_size
            # detach so the metric does not hold on to the autograd graph
            cos_sum += cosine_similarity_batch(recon.detach(), imgs) * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("DataLoader yielded no samples; cannot compute epoch metrics")
    return loss_sum / n_samples, cos_sum / n_samples


def train_vae(
    model,
    train_loader,
    val_loader,
    epochs=10,
    lr=1e-3,
    beta=1.0,
    device="cpu"
):
    """Train a VAE with Adam, validate after every epoch, and plot the learning curves.

    Args:
        model: Module whose forward pass returns `(recon, mu, logvar)` and which
            provides `vae_loss(recon, x, mu, logvar, beta)` returning the total
            loss first (expected to be a per-batch mean).
        train_loader: Yields `(images, labels)` batches used for optimization.
        val_loader: Yields `(images, labels)` batches used for evaluation only.
        epochs: Number of passes over `train_loader`.
        lr: Adam learning rate.
        beta: Weight of the KL term, forwarded to `model.vae_loss`.
        device: Device the model and batches are moved to.

    Raises:
        ValueError: If either loader yields no samples.

    Side effects:
        Moves `model` to `device` in place, prints one metrics line per epoch,
        and shows a figure with loss and cosine similarity per epoch.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_history = []
    train_cos_history = []
    val_loss_history = []
    val_cos_history = []

    for epoch in range(1, epochs + 1):
        # ---- TRAIN ----
        epoch_train_loss, epoch_train_cos = _run_vae_epoch(
            model, train_loader, beta, device, optimizer=optimizer
        )
        train_loss_history.append(epoch_train_loss)
        train_cos_history.append(epoch_train_cos)

        # ---- VALIDATION ----
        epoch_val_loss, epoch_val_cos = _run_vae_epoch(model, val_loader, beta, device)
        val_loss_history.append(epoch_val_loss)
        val_cos_history.append(epoch_val_cos)

        print(
            f"Epoch {epoch}/{epochs}: "
            f"TrainLoss={epoch_train_loss:.4f}, TrainCos={epoch_train_cos:.4f} | "
            f"ValLoss={epoch_val_loss:.4f}, ValCos={epoch_val_cos:.4f}"
        )

    # ---- Plot results ----
    epochs_range = range(1, epochs + 1)
    fig, ax1 = plt.subplots(figsize=(8,4))

    # Left axis: losses
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color="tab:blue")
    ax1.plot(epochs_range, train_loss_history, label="Train Loss", color="tab:blue")
    ax1.plot(epochs_range, val_loss_history,   label="Val Loss",   color="tab:cyan")
    ax1.tick_params(axis='y', labelcolor='tab:blue')

    # Right axis (shared x): cosine similarities
    ax2 = ax1.twinx()
    ax2.set_ylabel("Cosine Similarity", color="tab:red")
    ax2.plot(epochs_range, train_cos_history, label="Train CosSim", color="tab:red", linestyle='--')
    ax2.plot(epochs_range, val_cos_history,   label="Val CosSim",   color="tab:pink", linestyle='--')
    ax2.tick_params(axis='y', labelcolor='tab:red')

    # Put a combined legend
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines + lines2, labels + labels2, loc="upper right")

    plt.title("VAE: Loss & Cosine Similarity per Epoch")
    plt.tight_layout()
    plt.show()

Let's train the variational autoencoder.

In [None]:
# Down-weight the KL term (beta < 1) and train for a handful of epochs
train_vae(
    conv_vae_2d,
    train_loader,
    val_loader,
    epochs=5,
    beta=0.1,
)

Let's visualize the reconstructed images.

In [None]:
def visualize_reconstructions_vae(model, dataloader, n=10):
    """Show input images (top row) above their VAE reconstructions (bottom row).

    Args:
        model: VAE whose forward pass returns `(recon, mu, logvar)`.
        dataloader: Yields `(images, labels)` batches; only the first batch is used.
        n: Maximum number of image pairs to draw; clipped to the batch size.
    """
    model.eval()
    # Run on whatever device the model currently lives on (train_vae may have moved it).
    device = next(model.parameters()).device

    imgs, _ = next(iter(dataloader))
    n = min(n, imgs.size(0))  # the batch may hold fewer than n images
    imgs = imgs[:n]
    with torch.no_grad():
        recon, _, _ = model(imgs.to(device))
    # Bring everything back to the CPU for matplotlib
    imgs = imgs.cpu()
    recon = recon.cpu()

    plt.figure(figsize=(15, 3))
    for i in range(n):
        # Original
        plt.subplot(2, n, i + 1)
        plt.imshow(imgs[i].squeeze(), cmap="gray")
        plt.axis("off")
        # Reconstructed
        plt.subplot(2, n, i + 1 + n)
        plt.imshow(recon[i].squeeze(), cmap="gray")
        plt.axis("off")
    plt.show()

visualize_reconstructions_vae(conv_vae_2d, test_loader)

Here we plot the final latent space distribution.

In [None]:
def plot_latent_space_vae(model, dataloader, n_batches=10):
    """Scatter-plot the encoder means of the first `n_batches` batches, colored by label.

    If the latent space is not 2-dimensional, the means are first reduced to
    two dimensions with UMAP.

    Args:
        model: VAE exposing `encode(images) -> (mu, logvar)`.
        dataloader: Yields `(images, labels)` batches.
        n_batches: Number of batches to encode (limits runtime).

    Raises:
        ValueError: If no batch was encoded (empty loader or `n_batches` < 1).
    """
    # Local import: islice stops after exactly n_batches batches without
    # fetching an extra one from the loader.
    from itertools import islice

    model.eval()
    # Encode on whatever device the model currently lives on.
    device = next(model.parameters()).device
    latents = []
    labels = []

    with torch.no_grad():
        for imgs, lbls in islice(dataloader, n_batches):
            # Pass only through the encoder and keep the posterior means
            mu, _ = model.encode(imgs.to(device))
            latents.append(mu.cpu())
            labels.append(lbls.cpu())

    if not latents:
        raise ValueError("No batches were encoded; check the dataloader and n_batches")

    latents = torch.cat(latents).numpy()
    labels = torch.cat(labels).numpy()

    if latents.shape[1] == 2:
        additional = ''
        latents_ = latents
    else:
        reducer = umap.UMAP(n_neighbors=20, min_dist=0.05, n_components=2,
                    metric='euclidean', random_state=42)
        latents_ = reducer.fit_transform(latents)
        additional = ', with UMAP reduction'

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(latents_[:, 0], latents_[:, 1], c=labels, cmap="tab10", s=10, alpha=0.7)
    plt.colorbar(scatter, ticks=range(10))
    # The suffix carries its own leading comma, so no space is inserted before it
    plt.title(f"2D Latent Space (MNIST){additional}")
    plt.xlabel("z₁")
    plt.ylabel("z₂")
    plt.show()

# Example usage:
plot_latent_space_vae(conv_vae_2d, test_loader)

In [None]:
def sample_from_posterior_estimate(model, dataloader, n_samples=5):
    """Decode latent codes drawn from a Gaussian fitted to one batch of encoder means.

    The first batch of `dataloader` is encoded, the per-dimension mean and
    standard deviation of the resulting posterior means are computed, and
    `n_samples` codes are drawn from that diagonal Gaussian and decoded.
    The encoder's predicted variances are not used.

    Args:
        model: VAE exposing `encode`, `decode` and `latent_dim`.
        dataloader: Yields `(images, labels)` batches; only the first batch is used.
        n_samples: Number of images to generate and display.
    """
    model.eval()
    # Work on whatever device the model currently lives on.
    device = next(model.parameters()).device

    imgs, _ = next(iter(dataloader))
    with torch.no_grad():
        mu, _ = model.encode(imgs.to(device))

        # Moments of the encoded means across the batch
        mean = mu.mean(dim=0)
        # std of a single sample is NaN; fall back to zero spread in that case
        std = mu.std(dim=0) if mu.size(0) > 1 else torch.zeros_like(mean)

        z = torch.randn(n_samples, model.latent_dim, device=device) * std + mean
        decoded = model.decode(z).cpu()

    plt.figure(figsize=(n_samples * 2, 2))
    for i in range(n_samples):
        plt.subplot(1, n_samples, i+1)
        plt.imshow(decoded[i].squeeze(), cmap="gray")
        plt.axis("off")
    plt.suptitle("Samples near posterior latent distribution")
    plt.show()
    
sample_from_posterior_estimate(conv_vae_2d, test_loader, n_samples=10)