In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import torch.nn.functional as F
import umap.umap_ as umap
from torch.utils.data import random_split

Download the MNIST dataset, convert the images to tensors, and hold out part of the training set for validation.

In [None]:
# Transform: ToTensor converts each image to a float tensor with pixel values
# scaled to [0, 1]. No mean/std normalization is applied, which keeps the
# reconstruction targets in [0, 1] — the range of the Sigmoid that ends the
# decoders below.
transform = transforms.Compose([
    transforms.ToTensor(),
])

os.makedirs("./data", exist_ok=True)

# Download training and test data
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Hold out 10% of the official training set for validation. The split uses a
# seeded generator so it is identical on every run (Restart & Run All).
SPLIT_SEED = 42
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_subset, val_subset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_subset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Let's take a look at our training, validation and test datasets.

In [None]:
import numpy as np
from torchvision.utils import make_grid

# Report how many images ended up in each split
n_train, n_val, n_test = len(train_subset), len(val_subset), len(test_dataset)
print(f"Training set size: {n_train} images")
print(f"Validation set size: {n_val} images")
print(f"Test set size:     {n_test} images")

# Inspect one training sample: its tensor shape and its label
img, label = train_subset[0]
print(f"\nSample image shape: {img.shape}")
print(f"Sample label: {label}")

# Display a grid of samples
def show_batch(dataloader, n=32):
    """Plot the first `n` images of one batch from `dataloader` as an 8-column grid."""
    batch_imgs, _ = next(iter(dataloader))
    # make_grid returns (C, H, W); imshow wants (H, W, C)
    grid_hwc = make_grid(batch_imgs[:n], nrow=8, padding=2).permute(1, 2, 0).numpy()

    plt.figure(figsize=(8, 8))
    plt.imshow(grid_hwc, cmap="gray")
    plt.title("Batch of MNIST digits")
    plt.axis("off")
    plt.show()

show_batch(train_loader)

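Here we define our first architecture: a fully connected (MLP) autoencoder. Although it is built only from `nn.Linear` layers, it is not a linear model — each hidden layer is followed by BatchNorm and a ReLU nonlinearity.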

In [None]:
class FC_Autoencoder(nn.Module):
    """Fully connected (MLP) autoencoder for 28x28 single-channel images.

    Encoder: 784 -> 256 -> 64 -> 12 -> latent_dim, with BatchNorm1d + ReLU after
    each hidden layer. The decoder mirrors it back to 784 and ends in a Sigmoid,
    so reconstructions lie in [0, 1]. Although every layer is `nn.Linear`, the
    ReLUs make the model non-linear.

    Note: BatchNorm1d cannot compute batch statistics from a single sample, so
    in train mode every batch must contain at least 2 images.

    Args:
        latent_dim: size of the bottleneck code (default 10).
    """

    def __init__(self, latent_dim=10):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.BatchNorm1d(12),
            nn.ReLU(),
            nn.Linear(12, latent_dim)
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 12),
            nn.BatchNorm1d(12),
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Flatten images to (N, 784) and map them to latent codes (N, latent_dim)."""
        # reshape (not view) so non-contiguous inputs are accepted as well
        return self.encoder(x.reshape(-1, 28 * 28))

    def decode(self, z):
        """Map latent codes (N, latent_dim) back to images of shape (N, 1, 28, 28)."""
        return self.decoder(z).view(-1, 1, 28, 28)

    def forward(self, x):
        """Reconstruct a batch of images: decode(encode(x))."""
        return self.decode(self.encode(x))

fc_autoencoder = FC_Autoencoder()

print('Number of parameters:', sum(p.numel() for p in fc_autoencoder.parameters() if p.requires_grad))

Here we define a helper that measures the mean cosine similarity between the input images and their reconstructions; we report it epoch by epoch alongside the loss during training. We also define a generic training function that works for any of the autoencoder architectures in this notebook.

In [None]:
def cosine_similarity_batch(x, y):
    """Compute mean cosine similarity between two image batches.

    Each sample is flattened to a vector, the cosine similarity is computed per
    sample pair, and the batch mean is returned as a Python float.

    Args:
        x, y: tensors with the same batch size along dim 0 and the same number
            of elements per sample.

    Returns:
        float: mean per-sample cosine similarity.
    """
    # flatten (unlike view) also works on non-contiguous tensors
    x = torch.flatten(x, start_dim=1)
    y = torch.flatten(y, start_dim=1)
    sim = F.cosine_similarity(x, y, dim=1)
    return sim.mean().item()

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader`; return (mean loss, mean cosine similarity).

    When `optimizer` is given the model is trained (train mode, gradients on,
    one optimizer step per batch); otherwise it is only evaluated (eval mode,
    no gradients). Both means are weighted by batch size, so a smaller final
    batch counts proportionally less than a full one. Returns NaNs if the
    loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_cos = 0.0
    n_seen = 0

    with torch.set_grad_enabled(training):
        for imgs, _ in loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            # Reconstruction objective: the target is the input itself.
            loss = criterion(outputs, imgs)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = imgs.size(0)
            total_loss += loss.item() * batch_size
            total_cos += cosine_similarity_batch(outputs.detach(), imgs) * batch_size
            n_seen += batch_size

    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, total_cos / n_seen


def _plot_training_history(train_loss, val_loss, train_cos, val_cos):
    """Plot per-epoch loss (left y-axis) and cosine similarity (right y-axis)."""
    epoch_axis = range(1, len(train_loss) + 1)  # 1-based, matching the printed log
    fig, ax1 = plt.subplots()

    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color="tab:blue")
    ax1.plot(epoch_axis, train_loss, label="Train Loss", color="tab:blue")
    ax1.plot(epoch_axis, val_loss,   label="Val Loss",   color="tab:cyan")
    ax1.tick_params(axis="y", labelcolor="tab:blue")

    ax2 = ax1.twinx()
    ax2.set_ylabel("Cosine Similarity", color="tab:red")
    ax2.plot(epoch_axis, train_cos, label="Train CosSim", color="tab:red", linestyle="--")
    ax2.plot(epoch_axis, val_cos,   label="Val CosSim",   color="tab:pink", linestyle="--")
    ax2.tick_params(axis="y", labelcolor="tab:red")

    # Title on an explicit axes, set before tight_layout so the layout accounts for it.
    ax1.set_title("Autoencoder Training & Validation")
    fig.legend(loc="upper right")
    fig.tight_layout()
    plt.show()


def train_autoencoder(
    model, train_loader, val_loader, epochs=5, lr=1e-3, device="cpu"
):
    """Train `model` to reconstruct its inputs using MSE loss and Adam.

    After every epoch, prints the sample-weighted mean train/val loss and
    cosine similarity; after the last epoch, plots their histories.

    Args:
        model: autoencoder whose output has the same shape as its input.
        train_loader: yields (images, labels) batches used for training.
        val_loader: yields (images, labels) batches used for evaluation only.
        epochs: number of passes over `train_loader`.
        lr: Adam learning rate.
        device: device the model and batches are moved to.
    """
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_history = []
    train_cos_sim_history = []
    val_loss_history = []
    val_cos_sim_history = []

    for epoch in range(epochs):
        epoch_train_loss, epoch_train_cos = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        epoch_val_loss, epoch_val_cos = _run_epoch(model, val_loader, criterion, device)

        train_loss_history.append(epoch_train_loss)
        train_cos_sim_history.append(epoch_train_cos)
        val_loss_history.append(epoch_val_loss)
        val_cos_sim_history.append(epoch_val_cos)

        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"- Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f} "
            f"- Train CosSim: {epoch_train_cos:.4f}, Val CosSim: {epoch_val_cos:.4f}"
        )

    _plot_training_history(
        train_loss_history, val_loss_history, train_cos_sim_history, val_cos_sim_history
    )

Let's train the autoencoder we previously built.

In [None]:
# Train the fully connected AE for 5 epochs (lr and device use their defaults)
train_autoencoder(
    fc_autoencoder,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=5,
)

A key step in evaluating an autoencoder is to visually compare the reconstructed images with the input images. This helps us understand where the AE fails and what to change in the architecture to improve its performance.

In [None]:
def visualize_reconstructions(model, dataloader, n=10):
    """Plot the first `n` images of one batch (top row) above their reconstructions (bottom row).

    `n` is capped at the size of the first batch. Inputs are moved to the
    device holding the model's parameters, and reconstructions are brought back
    to the CPU for plotting, so GPU-trained models work too.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    imgs, _ = next(iter(dataloader))
    n = min(n, imgs.size(0))  # avoid IndexError when n exceeds the batch size
    imgs = imgs[:n]
    with torch.no_grad():
        recon = model(imgs.to(device)).cpu()

    plt.figure(figsize=(15, 3))
    for i in range(n):
        # Original
        plt.subplot(2, n, i + 1)
        plt.imshow(imgs[i].squeeze(), cmap="gray")
        plt.axis("off")
        # Reconstructed
        plt.subplot(2, n, i + 1 + n)
        plt.imshow(recon[i].squeeze(), cmap="gray")
        plt.axis("off")
    plt.show()

visualize_reconstructions(fc_autoencoder, test_loader)

Here we define a function to visualize the distribution of the latent space. If the latent space is 2-dimensional, the latent codes are plotted directly; otherwise, UMAP first reduces them to two dimensions and we plot the result.

In [None]:
def plot_latent_space(model, dataloader, n_batches=50):
    """Scatter-plot the latent codes of the first `n_batches` batches, colored by label.

    Latent codes come from `model.encode` when the model defines it; otherwise
    from `model.encoder` applied to flattened images. 2-D latents are plotted
    directly; higher-dimensional ones are first reduced to 2-D with UMAP.

    Raises:
        ValueError: if no batches were collected (n_batches < 1 or empty loader).
    """
    model.eval()
    # Feed inputs to whichever device holds the model's parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    latents = []
    labels = []

    with torch.no_grad():
        for i, (imgs, lbls) in enumerate(dataloader):
            # Check *before* processing so exactly n_batches batches are used
            # (limit batches for speed).
            if i >= n_batches:
                break
            imgs = imgs.to(device)
            # Pass only through the encoder
            if hasattr(model, "encode"):
                z = model.encode(imgs)
            else:
                z = model.encoder(torch.flatten(imgs, 1))
            latents.append(z.cpu())
            labels.append(lbls)

    if not latents:
        raise ValueError("No batches collected; n_batches must be >= 1 and the dataloader non-empty.")

    latents = torch.cat(latents).numpy()
    labels = torch.cat(labels).numpy()

    if latents.shape[1] == 2:
        additional = ''
        latents_2d = latents
    else:
        reducer = umap.UMAP(n_neighbors=20, min_dist=0.05, n_components=2,
                            metric='euclidean', random_state=42)
        latents_2d = reducer.fit_transform(latents)
        additional = ', with UMAP reduction'

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(latents_2d[:, 0], latents_2d[:, 1], c=labels, cmap="tab10", s=10, alpha=0.7)
    plt.colorbar(scatter, ticks=range(10))
    plt.title(f"2D Latent Space (MNIST){additional}")
    plt.xlabel("z₁")
    plt.ylabel("z₂")
    plt.show()

# Example usage:
plot_latent_space(fc_autoencoder, test_loader)

Since we are dealing with images, let's try a new AE architecture based on convolutional neural networks (CNNs). Like the previous model, it compresses each image into a 10-dimensional latent vector.

In [None]:
class Conv_Autoencoder_2D(nn.Module):
    """Convolutional autoencoder for 1x28x28 images with a flat latent vector.

    Encoder: two stride-2 convolutions (1x28x28 -> 16x14x14 -> 32x7x7), then
    fc1 -> fc2 down to a `latent_dim`-dimensional code (10 by default).
    Decoder: fc3 -> fc4 back to 32x7x7, then two stride-2 transposed
    convolutions up to 1x28x28, ending in a Sigmoid so outputs lie in [0, 1].

    Args:
        latent_dim: size of the latent vector (default 10).
    """

    def __init__(self, latent_dim=10):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # -> (16, 14, 14)
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> (32, 7, 7)
            nn.ReLU(),
        )
        # Flatten -> latent vector of size latent_dim.
        # NOTE(review): there is no nonlinearity between fc1 and fc2 (nor
        # between fc3 and fc4), so each pair composes to a single affine map.
        self.fc1 = nn.Linear(32 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, latent_dim)

        # Decoder: latent vector -> (32, 7, 7)
        self.fc3 = nn.Linear(latent_dim, 256)
        self.fc4 = nn.Linear(256, 32 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # -> (16, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # -> (1, 28, 28)
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map images (N, 1, 28, 28) to latent codes (N, latent_dim)."""
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        return self.fc2(self.fc1(x))

    def decode(self, z):
        """Map latent codes (N, latent_dim) to images (N, 1, 28, 28) in [0, 1]."""
        x = self.fc4(self.fc3(z))
        x = x.view(-1, 32, 7, 7)
        return self.decoder(x)

    def forward(self, x):
        """Reconstruct a batch of images: decode(encode(x))."""
        return self.decode(self.encode(x))

conv_autoencoder_2d = Conv_Autoencoder_2D()

print('Number of parameters:', sum(p.numel() for p in conv_autoencoder_2d.parameters() if p.requires_grad))

Let's train the autoencoder with the function we defined before...

In [None]:
# Reuse the generic training loop for the convolutional AE (5 epochs)
train_autoencoder(
    conv_autoencoder_2d,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=5,
)

...and visualize the reconstruction performance.

In [None]:
# Compare 15 test images with their reconstructions
visualize_reconstructions(model=conv_autoencoder_2d, dataloader=test_loader, n=15)

Finally let's take a look at the latent space distribution.

In [None]:
# Latent codes of the test set, projected to 2-D when needed
plot_latent_space(model=conv_autoencoder_2d, dataloader=test_loader)

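Can we use the latent space to create new images? Let's decode random vectors drawn from a standard normal distribution. Note that a plain autoencoder trained only on reconstruction error is not encouraged to place its latent codes according to this distribution, so many samples may not look like digits.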

In [None]:

def show_random_latent_samples(model, n_samples=5, latent_dim=10, seed=None):
    """Decode `n_samples` random latent vectors z ~ N(0, I) and plot the images.

    The autoencoders in this notebook are trained only on reconstruction error
    (MSE in `train_autoencoder`), which does not push their codes toward
    N(0, I), so the decoded samples may not resemble digits.

    Args:
        model: an autoencoder exposing `decode(z)`; if it has none, the
            Conv_Autoencoder_2D decoder path (fc3 -> fc4 -> reshape to
            (32, 7, 7) -> decoder) is applied directly.
        n_samples: number of images to generate and plot.
        latent_dim: dimensionality of the sampled vectors; must match the model.
        seed: optional int for reproducible samples (default: use the global RNG).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Sample from a normal distribution on the CPU (optionally seeded), then
    # move to the model's device.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    z = torch.randn(n_samples, latent_dim, generator=generator).to(device)

    # Decode to images
    with torch.no_grad():
        if hasattr(model, "decode"):
            x_decoded = model.decode(z)
        else:
            x_decoded = model.decoder(
                model.fc4(
                    model.fc3(z)
                ).view(-1, 32, 7, 7)
            )
    x_decoded = x_decoded.cpu()

    # Plot results
    plt.figure(figsize=(n_samples*2, 2))
    for i in range(n_samples):
        plt.subplot(1, n_samples, i+1)
        plt.imshow(x_decoded[i].squeeze(), cmap="gray")
        plt.axis("off")
    plt.suptitle("Random samples from latent space")
    plt.show()

# use it
show_random_latent_samples(conv_autoencoder_2d, n_samples=10)