# 👖 Autoencoders on Fashion MNIST


In this notebook, we'll walk through the steps required to train your own autoencoder on the Fashion-MNIST dataset.


In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets
from rich.progress import Progress
from torchvision.transforms import v2
from mpl_toolkits.axes_grid1 import make_axes_locatable

from notebooks.utils import display

## 0. Parameters


In [None]:
IMAGE_SIZE = 32  # side length after padding each 28x28 image by 2 px per side
CHANNELS = 1  # single-channel (greyscale) images
BATCH_SIZE = 100
BUFFER_SIZE = 1000  # NOTE(review): appears unused in this notebook
VALIDATION_SPLIT = 0.2  # NOTE(review): appears unused in this notebook
EMBEDDING_DIM = 2  # 2-D latent space so embeddings can be plotted directly
EPOCHS = 5
SEED = 42

# Seed the global RNGs that drive weight initialisation, DataLoader shuffling and
# latent-space sampling so that Restart & Run All gives repeatable results
# (GPU kernels may still introduce small nondeterminism).
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"Let's use CUDA ({gpu_name})")

## 1. Prepare the data


In [None]:
# Load the data.
# Pipeline: PIL image -> tensor image -> float32 scaled to [0, 1] -> 2 px of zero
# padding on every side (presumably 28x28 -> IMAGE_SIZE x IMAGE_SIZE).
transform_steps = [
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Pad([2, 2, 2, 2], padding_mode='constant', fill=0.0),
]
transforms = v2.Compose(transform_steps)
train_set, test_set = (
    datasets.FashionMNIST(root='data', train=is_train, download=True, transform=transforms)
    for is_train in (True, False)
)

In [None]:
# Shuffled training mini-batches; persistent workers are kept alive between epochs.
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=2, persistent_workers=True)
train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)

In [None]:
# Show some items of clothing from the training set (the first ten images, labels dropped)
first_items = [train_set[idx] for idx in range(10)]
x_train = np.array([image for image, _label in first_items])
display(x_train)

## 2. Build the autoencoder


In [None]:
# Encoder
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a low-dimensional embedding.

    Three 3x3 stride-2 convolutions (32 -> 64 -> 128 channels, ReLU after each)
    halve the spatial size three times. A linear layer then projects the
    flattened feature map to ``embedding_dim`` values.

    Args:
        channels: Number of input image channels. Defaults to ``CHANNELS``.
        embedding_dim: Size of the output embedding. Defaults to ``EMBEDDING_DIM``.
        image_size: Height/width of the square input images. Defaults to
            ``IMAGE_SIZE``. Must be a positive multiple of 8 so that the three
            stride-2 stages divide it exactly.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, channels=None, embedding_dim=None, image_size=None):
        super().__init__()
        # Fall back to the notebook-level parameters so ``Encoder()`` keeps working.
        channels = CHANNELS if channels is None else channels
        embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim
        image_size = IMAGE_SIZE if image_size is None else image_size
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(f'image_size must be a positive multiple of 8, got {image_size}')
        # Spatial size left after three stride-2 convolutions (4 for 32x32 inputs);
        # previously hard-coded as 4, which silently tied the model to IMAGE_SIZE=32.
        feature_size = image_size // 8

        self.net = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(128 * feature_size * feature_size, embedding_dim),
        )

    def forward(self, x):
        """Encode ``x`` of shape (batch, channels, H, W) into (batch, embedding_dim)."""
        return self.net(x)

In [None]:
# Decoder
class Decoder(nn.Module):
    """Convolutional decoder mapping an embedding back to an image batch.

    A linear layer expands the embedding to a 128-channel feature map of side
    ``image_size // 8``. Three 3x3 stride-2 transposed convolutions (with
    ``output_padding=1``) each double the spatial size. A final Sigmoid keeps
    pixel values in [0, 1].

    Args:
        channels: Number of output image channels. Defaults to ``CHANNELS``.
        embedding_dim: Size of the input embedding. Defaults to ``EMBEDDING_DIM``.
        image_size: Height/width of the square output images. Defaults to
            ``IMAGE_SIZE``. Must be a positive multiple of 8.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, channels=None, embedding_dim=None, image_size=None):
        super().__init__()
        # Fall back to the notebook-level parameters so ``Decoder()`` keeps working.
        channels = CHANNELS if channels is None else channels
        embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim
        image_size = IMAGE_SIZE if image_size is None else image_size
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(f'image_size must be a positive multiple of 8, got {image_size}')
        # Starting spatial size (4 for 32x32 outputs); previously hard-coded as 4.
        feature_size = image_size // 8

        self.net = nn.Sequential(
            nn.Linear(embedding_dim, 128 * feature_size * feature_size),
            nn.Unflatten(1, (128, feature_size, feature_size)),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            # Squash to [0, 1] so the output can be used with binary cross-entropy.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode ``x`` of shape (batch, embedding_dim) into (batch, channels, image_size, image_size)."""
        return self.net(x)

In [None]:
# Build both halves on the target device (encoder first) and optimise them jointly
# with a single Adam instance.
encoder, decoder = Encoder().to(device), Decoder().to(device)
train_parameters = [*encoder.parameters(), *decoder.parameters()]
opt = torch.optim.Adam(train_parameters, lr=1e-3)

## 3. Train the autoencoder


In [None]:
with Progress() as progress:
    # Exponential moving average of the batch loss, carried across epochs for a
    # smoother readout. None until the first batch has been processed (a falsy
    # 0.0 sentinel would be indistinguishable from a real zero loss).
    ema_loss = None
    for epoch in range(EPOCHS):
        # One progress bar per epoch.
        train_task = progress.add_task('Training...', total=len(train_loader))
        encoder.train()
        decoder.train()
        for X, _ in train_loader:
            X = X.to(device)
            z = encoder(X)
            X_hat = decoder(z)
            # Pixel-wise BCE: targets are intensities scaled to [0, 1] by the
            # transforms, predictions come out of the decoder's Sigmoid.
            loss = F.binary_cross_entropy(X_hat, X)

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_loss = loss.item()
            ema_loss = batch_loss if ema_loss is None else 0.9 * ema_loss + 0.1 * batch_loss
            progress.update(train_task, advance=1, description=f'[{epoch + 1}/{EPOCHS}] loss: {ema_loss:.4f}')

        # Force a final redraw of this epoch's bar (previously only the last epoch's
        # bar was refreshed, and EPOCHS == 0 raised NameError).
        progress.update(train_task, refresh=True)

## 4. Reconstruct using the autoencoder


In [None]:
# Evaluate on a fixed, unshuffled prefix of the test set so that every per-sample
# output below stays aligned with its input and label.
n_to_predict = 5000
subset_indices = np.arange(n_to_predict)
test_subset = torch.utils.data.Subset(test_set, subset_indices)
test_loader = torch.utils.data.DataLoader(
    test_subset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    persistent_workers=True,
)

In [None]:
# Reconstruct the test subset, tracking the reconstruction loss and keeping the
# inputs, labels and reconstructions (in loader order) for the plots below.
avg_loss = 0.0
count = 0
example_images = []
example_labels = []
predictions = []

with Progress() as progress:
    test_task = progress.add_task('Testing...', total=len(test_loader))
    encoder.eval()
    decoder.eval()
    # Pure inference: no autograd graph is needed for the forward pass or the loss.
    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            z = encoder(X)
            X_hat = decoder(z)

            loss = F.binary_cross_entropy(X_hat, X)
            # Weight each batch's mean loss by its size so the final value is a
            # per-sample average even if the last batch is smaller.
            avg_loss += loss.item() * X.size(0)
            count += X.size(0)

            example_images.append(X.cpu().numpy())
            example_labels.append(y.cpu().numpy())
            predictions.append(X_hat.cpu().numpy())

            progress.update(test_task, advance=1)

    avg_loss /= count
    # Flatten the per-batch lists into single arrays. Labels were previously left
    # as a list of batches, unlike images and predictions.
    example_images = np.concatenate(example_images, axis=0)
    example_labels = np.concatenate(example_labels, axis=0)
    predictions = np.concatenate(predictions, axis=0)

    progress.update(test_task, refresh=True)
    progress.console.log(f'[Test] loss: {avg_loss:.4f}')

In [None]:
print("Example real clothing items")
display(example_images)
print("Reconstructions")
display(predictions)

## 5. Embed using the encoder


In [None]:
# Encode the example images. The same unshuffled test_loader is used, so row i of
# `embeddings` corresponds to row i of example_images / example_labels.
embeddings = []

with Progress() as progress:
    embed_task = progress.add_task('Embedding...', total=len(test_loader))
    # Only the encoder is needed here.
    encoder.eval()
    with torch.no_grad():
        for X, _ in test_loader:
            z = encoder(X.to(device))
            embeddings.append(z.cpu().numpy())
            progress.update(embed_task, advance=1)

    embeddings = np.concatenate(embeddings, axis=0)
    progress.update(embed_task, refresh=True)
    # Report what this cell produced instead of re-logging the (stale) test loss
    # left over from the previous cell.
    progress.console.log(f'[Embed] {embeddings.shape[0]} points, dim {embeddings.shape[1]}')

In [None]:
# Some examples of the embeddings: the first ten latent vectors
head_embeddings = embeddings[:10]
print(head_embeddings)

In [None]:
# Show the encoded points in 2D space
figsize = 8

fig, ax = plt.subplots(figsize=(figsize, figsize))
ax.scatter(embeddings[:, 0], embeddings[:, 1], c='black', alpha=0.5, s=3)
# Label the figure so it stands on its own when the notebook is skimmed.
ax.set(title='Test-set embeddings in latent space', xlabel='z[0]', ylabel='z[1]')
plt.show()

In [None]:
# Colour the embeddings by their label (clothing type - see table)
figsize = 8

# np.hstack flattens per-batch label arrays into one vector aligned with
# `embeddings`. It also accepts labels that are already a single 1-D array.
labels = np.hstack(example_labels)[:n_to_predict]

fig, ax = plt.subplots(figsize=(figsize, figsize))
sc = ax.scatter(
    embeddings[:, 0],
    embeddings[:, 1],
    cmap='rainbow',
    c=labels,
    alpha=0.8,
    s=3,
)
fig.colorbar(sc, ax=ax, label='clothing class index')
ax.set(title='Embeddings coloured by label', xlabel='z[0]', ylabel='z[1]')
plt.show()

## 6. Generate using the decoder


In [None]:
# Per-dimension bounding box of the existing embeddings.
mins = embeddings.min(axis=0)
maxs = embeddings.max(axis=0)

# Draw grid_width * grid_height points uniformly at random inside that box.
grid_width, grid_height = 6, 3
n_samples = grid_width * grid_height
sample = np.random.uniform(low=mins, high=maxs, size=(n_samples, EMBEDDING_DIM))

In [None]:
# Decode the sampled points; generation needs no autograd graph.
sample_torch = torch.tensor(sample, dtype=torch.float32, device=device)
with torch.no_grad():
    decoded = decoder(sample_torch)
reconstructions = decoded.cpu().detach().numpy()

In [None]:
# Draw the original embeddings (small black dots) together with the newly
# generated points in the latent space (large blue markers).
figsize = 8
fig, ax = plt.subplots(figsize=(figsize, figsize))
ax.scatter(embeddings[:, 0], embeddings[:, 1], c='black', alpha=0.5, s=2)
ax.scatter(sample[:, 0], sample[:, 1], c='#00B0F0', alpha=1, s=40)
plt.show()

In [None]:
# Add underneath a grid of the decoded images, each captioned with the latent
# point it was decoded from.
figsize = 8  # set here so the cell doesn't depend on whichever cell ran last
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    # Draw each tile once. 'Greys' maps 0 (the zero-padded background) to white
    # and high intensities to black. A previous 'gray' imshow was fully overdrawn.
    ax.imshow(reconstructions[i].squeeze(), cmap='Greys')
    ax.axis('off')
    ax.text(
        0.5,
        -0.35,
        str(np.round(sample[i, :], 1)),
        fontsize=10,
        ha='center',
        transform=ax.transAxes,
    )

plt.show()

In [None]:
# Colour the embeddings by their label (clothing type - see table) and overlay
# decoded images for a regular grid of points spanning the latent space.
figsize = 12
grid_size = 15

# Flatten per-batch label arrays (also accepts an already-flat 1-D array).
labels = np.hstack(example_labels)

fig, ax = plt.subplots(figsize=(figsize, figsize))
sc = ax.scatter(
    embeddings[:, 0],
    embeddings[:, 1],
    cmap='rainbow',
    c=labels,
    alpha=0.8,
    s=30,
    zorder=1,
)

# Add a colorbar in a new axis beside the main plot
cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
fig.colorbar(sc, cax=cax)

# Regular grid over the embeddings' bounding box (computed once and used both for
# decoding and for placing the tiles). y runs from max to min so that grid row 0
# sits at the top of the plot.
xs = np.linspace(embeddings[:, 0].min(), embeddings[:, 0].max(), grid_size)
ys = np.linspace(embeddings[:, 1].max(), embeddings[:, 1].min(), grid_size)
dx = xs[1] - xs[0]
dy = ys[0] - ys[1]
xx, yy = np.meshgrid(xs, ys, indexing='xy')
grid = np.stack([xx.ravel(), yy.ravel()], axis=-1)

# Decode every grid point in a single batch.
with torch.no_grad():
    grid_decoded = decoder(torch.tensor(grid, dtype=torch.float32, device=device))

# Convert each greyscale image to RGBA. Colour is the inverted intensity and alpha
# is the intensity, so the garment is dark and its background is transparent.
rgb = 1.0 - grid_decoded.repeat(1, 3, 1, 1)
tiles = torch.cat([rgb, grid_decoded], dim=1).permute(0, 2, 3, 1).cpu().numpy()

for (cx, cy), tile in zip(grid, tiles):
    ax.imshow(
        tile,
        extent=[cx - dx / 2, cx + dx / 2, cy - dy / 2, cy + dy / 2],
        origin='upper',
        zorder=2,
    )

ax.axis('off')
ax.set_xlim(xs[0] - dx / 2, xs[-1] + dx / 2)
ax.set_ylim(ys[-1] - dy / 2, ys[0] + dy / 2)
ax.set_aspect('equal')
fig.tight_layout()
plt.show()