# Unsupervised learning using the MNIST dataset

This notebook explores the use of an __autoencoder__ neural network.

An autoencoder features two modules in sequence: an __encoder__ and a __decoder__. The encoder maps an input layer to a smaller layer, called the __bottleneck__. Then, the decoder maps the bottleneck back to an output layer of the same size as the input.

Training adjusts both the encoder and decoder weights so that, for the training set, the output reproduces the input. This might look useless, but the point is that the model learns to encode each input into the bottleneck in a way that is sufficient to faithfully reconstruct the input values at the output. The bottleneck thus holds a compact representation of the data, learned without any labels.
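As a minimal sketch of this idea (not part of the original notebook), the code below trains a tiny linear autoencoder on synthetic data. The data, the layer sizes and the names `basis`, `encoder` and `decoder` are assumptions chosen for illustration: 20-dimensional points that lie on a 3-dimensional subspace, so a 3-value bottleneck is enough to reconstruct them. Note that the loss compares the output with the input itself.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic data: 512 points in 20 dimensions that actually lie on a
# 3-dimensional linear subspace, so a 3-value bottleneck can hold them exactly
basis = torch.randn(3, 20)
data = torch.randn(512, 3) @ basis

# The encoder squeezes 20 values into a 3-value bottleneck; the decoder expands them back
encoder = nn.Linear(20, 3)
decoder = nn.Linear(3, 20)
model = nn.Sequential(encoder, decoder)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# The training target is the input itself: no labels are involved
initial_loss = loss_fn(model(data), data).item()
for step in range(2000):
    loss = loss_fn(model(data), data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
final_loss = loss_fn(model(data), data).item()

codes = encoder(data)
print(f"bottleneck shape: {tuple(codes.shape)}")   # bottleneck shape: (512, 3)
print(f"reconstruction MSE: {initial_loss:.4f} -> {final_loss:.6f}")
```

With real images the bottleneck is generally not lossless, so the reconstruction error stays above zero; the MNIST models below also use non-linear layers.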

## Import packages

We'll use
* time -- a standard Python module providing time-related functions
* matplotlib for showing images -- more info in https://matplotlib.org/
* PyTorch is a machine learning library -- more info in https://pytorch.org/
  * torch.nn for using neural networks
  * torch.utils.data for loading datasets
  * torchvision.datasets for access to the MNIST dataset
  * torchvision.transforms for data transformation among images and tensors

In [None]:
%matplotlib widget

import time
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

Determine which hardware accelerator (e.g., a CUDA or Apple MPS GPU) is available, falling back to the CPU otherwise.

In [None]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
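`torch.accelerator` is only available in recent PyTorch releases (2.6 and later, to our knowledge). On older versions, a fallback along these lines should work. `pick_device` is a hypothetical helper name, and it only checks for the CUDA and Apple MPS back ends.

```python
import torch

def pick_device() -> str:
    """Fallback device selection for PyTorch versions without torch.accelerator."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():   # Apple silicon GPUs
        return "mps"
    return "cpu"

device = pick_device()
print(f"Using {device} device")
```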

## Dataset loading

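Next we load the MNIST dataset. We'll use its training split, which holds 60,000 of the 70,000 images (`train=True` is the default for `datasets.MNIST`). The labels are loaded too, but only the images are used for training.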

In [None]:
full_dataset = datasets.MNIST(
    root="data",
    download=True,
    transform=ToTensor(),
)

In [None]:
# Define batch size
batch_size = 64

# Create a data loader; shuffle=True reshuffles the images at every epoch
dataloader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True)

# Show the shape of one batch -- we'll only use the images X for training, not the labels y
for X, y in dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
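Since the MNIST training split holds 60,000 images, a `DataLoader` with `batch_size=64` yields ceil(60000/64) = 938 batches per epoch, the last one holding only 32 images. The sketch below checks this arithmetic with a stand-in `TensorDataset` of zeros (an assumption, so that nothing needs to be downloaded).

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

n_images, batch_size = 60_000, 64   # size of the MNIST training split

# Stand-in dataset of the same length (no download needed)
stand_in = TensorDataset(torch.zeros(n_images, 1))
loader = DataLoader(stand_in, batch_size=batch_size, shuffle=True)

n_batches = len(loader)                              # batches per epoch
last_batch = n_images - (n_batches - 1) * batch_size  # images left for the final batch
print(n_batches, math.ceil(n_images / batch_size), last_batch)   # 938 938 32
```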

Next we show the first ten images of the dataset, each titled with its label.

In [None]:
(fig,ax) = plt.subplots(1, 10, constrained_layout=True)
ds = full_dataset
for j in range(10):
    ax[j].imshow(ds[j][0][0], cmap="gray")
    ax[j].set_axis_off()
    ax[j].set_title(ds[j][1])
plt.show()

## Create a neural network model

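Here we define two models to play with: a fully connected autoencoder and a convolutional one. Both compress each 28×28 image into a bottleneck of 10 values.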

In [None]:
class ClassicAutoencoderMNIST(nn.Module):
    bottleneck = 10
    
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.unflatten = nn.Unflatten(1, (1,28,28))
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, self.bottleneck)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.bottleneck, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Sigmoid()
        )

    def encode(self, x):
        x = self.flatten(x)
        coded = self.encoder(x)
        return coded

    def decode(self, x):
        x = self.decoder(x)
        decoded = self.unflatten(x)
        return decoded

    def forward(self, x):
        return self.decode( self.encode(x) )
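To get a feel for the size of `ClassicAutoencoderMNIST`, we can count its parameters. Each `nn.Linear(n_in, n_out)` holds `n_in*n_out` weights plus `n_out` biases, while the ReLU and Sigmoid activations have none. The sketch below rebuilds only the linear layers (`linear_params` is a hypothetical helper) and arrives at about 1.34 million parameters, most of them in the two layers that connect the 784 pixels with the 512 hidden units.

```python
from torch import nn

def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of an nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out

# Same layer sizes as ClassicAutoencoderMNIST (bottleneck = 10)
sizes = [28 * 28, 512, 512, 10, 512, 512, 28 * 28]
by_hand = sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))

# Cross-check against PyTorch with the same stack of Linear layers
# (activations have no parameters, so they are left out here)
stack = nn.Sequential(*[nn.Linear(a, b) for a, b in zip(sizes, sizes[1:])])
by_torch = sum(p.numel() for p in stack.parameters())

print(by_hand, by_torch)   # 1340186 1340186
```

With the notebook's model, `sum(p.numel() for p in model.parameters())` should give the same count.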

In [None]:
class ConvAutoencoderMNIST(nn.Module):
    bottleneck = 10
    
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(32*7*7, 128),
            nn.ReLU(),
            nn.Linear(128, self.bottleneck)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.bottleneck, 128),
            nn.ReLU(),
            nn.Linear(128, 32*7*7),
            nn.ReLU(),
            nn.Unflatten(1, (32,7,7)),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2),
            nn.Sigmoid() # Output should be in the range [0, 1]
        )

    def encode(self, x):
        coded = self.encoder(x)
        return coded

    def decode(self, x):
        decoded = self.decoder(x)
        return decoded

    def forward(self, x):
        return self.decode( self.encode(x) )
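The layer sizes in `ConvAutoencoderMNIST` follow from the standard output-size formulas (dilation 1): for `Conv2d` and `MaxPool2d`, out = floor((in + 2*padding - kernel)/stride) + 1; for `ConvTranspose2d` without output padding, out = (in - 1)*stride - 2*padding + kernel. The sketch below applies them to show where `32*7*7` comes from and why the decoder ends at 28×28, then cross-checks with real layers. `conv_out` and `conv_transpose_out` are hypothetical helper names.

```python
import torch
from torch import nn

def conv_out(size, kernel, stride=1, padding=0):
    """Output size of Conv2d / MaxPool2d along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output size of ConvTranspose2d (no output_padding, dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel

# Encoder: conv (k=3, p=1) keeps the size, max-pooling (k=2, s=2) halves it
s = 28
s = conv_out(s, 3, padding=1)
s = conv_out(s, 2, stride=2)   # 28 -> 14
s = conv_out(s, 3, padding=1)
s = conv_out(s, 2, stride=2)   # 14 -> 7
print("encoder feature map:", s, "x", s, "-> flattened", 32 * s * s)

# Decoder: each ConvTranspose2d(k=2, s=2) doubles the size
d = conv_transpose_out(conv_transpose_out(7, 2, stride=2), 2, stride=2)
print("decoder output:", d, "x", d)

# Cross-check with actual layers on a dummy batch of one image
x = torch.zeros(1, 1, 28, 28)
features = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, 3, padding=1), nn.MaxPool2d(2, 2),
)(x)
up = nn.Sequential(
    nn.ConvTranspose2d(32, 16, 2, stride=2),
    nn.ConvTranspose2d(16, 1, 2, stride=2),
)(features)
print(tuple(features.shape), tuple(up.shape))   # (1, 32, 7, 7) (1, 1, 28, 28)
```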

Selection of the model, the loss and the optimizer.

In [None]:
# Uncomment the model you want to try out, leaving all others commented out
model = ClassicAutoencoderMNIST().to(device)
#model = ConvAutoencoderMNIST().to(device)

print(model)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
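`nn.MSELoss()` with its default `reduction="mean"` averages the squared differences over every element of the batch (images, channels and pixels alike), so the loss value does not grow with the batch size. A quick check on random tensors, used here as stand-ins for an image batch and its reconstruction:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.rand(8, 1, 28, 28)      # a fake batch of images in [0, 1]
pred = torch.rand(8, 1, 28, 28)   # a fake reconstruction

# Default reduction="mean": average squared error over all elements
builtin = nn.MSELoss()(pred, x)
by_hand = ((pred - x) ** 2).mean()
print(builtin.item(), by_hand.item())
```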

## Training the model

First we define a train function. Unlike a supervised classifier, which is trained to match labels, here the model is trained to minimize the difference (mean squared error) between its output and its input, over the training set. No labels are used: a compact encoding automatically emerges at the bottleneck.

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    t = 0
    for batch, (X, _) in enumerate(dataloader):  # labels are not needed
        X = X.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, X)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Report progress at most once per second
        if time.monotonic()-t > 1:
            t = time.monotonic()
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
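The notebook only trains; to monitor progress one could also compute the average reconstruction loss over a whole dataset. Below is a sketch of such a helper. `evaluate` is a hypothetical name, and weighting each batch by its size is our choice, so that a smaller last batch does not skew the average. It runs under `torch.no_grad()` because no gradients are needed.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(dataloader, model, loss_fn, device="cpu"):
    """Average reconstruction loss over a dataset (hypothetical helper)."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for X, _ in dataloader:            # labels are ignored, as in train()
            X = X.to(device)
            loss = loss_fn(model(X), X)    # the target is the input itself
            total += loss.item() * len(X)  # undo the per-batch mean
            count += len(X)
    return total / count

# Quick check with stand-in data: a model that copies its input has zero loss
data = TensorDataset(torch.rand(100, 1, 28, 28), torch.zeros(100))
loader = DataLoader(data, batch_size=64)
print(evaluate(loader, nn.Identity(), nn.MSELoss()))   # 0.0
```

In the notebook, `evaluate(dataloader, model, loss_fn, device)` could be called after each epoch.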

The next cell performs the actual training.

In [None]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(dataloader, model, loss_fn, optimizer)
print("Done!")

The following two cells can be used to save the model to disk or load it from disk. __Skip__ these two cells, unless you wish to save and/or load models to/from the disk.

In [None]:
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

In [None]:
model = ClassicAutoencoderMNIST().to(device)  # must be the same class as the saved model
model.load_state_dict(torch.load("model.pth", weights_only=True))
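Loading a `state_dict` only works into an instance of the same architecture: with the default `strict=True`, `load_state_dict` raises an error if parameter names or shapes do not match. The sketch below round-trips weights through a temporary file. `make_model` is a hypothetical stand-in for whichever autoencoder class was saved.

```python
import os
import tempfile
import torch
from torch import nn

torch.manual_seed(0)

def make_model():
    # Hypothetical stand-in; in the notebook this would be the same
    # autoencoder class that was used when the weights were saved
    return nn.Sequential(nn.Linear(784, 10), nn.Linear(10, 784), nn.Sigmoid())

trained = make_model()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(trained.state_dict(), path)

    restored = make_model()   # fresh instance with different random weights
    restored.load_state_dict(torch.load(path, weights_only=True))

x = torch.rand(4, 784)
print(torch.equal(trained(x), restored(x)))   # True
```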

## Testing the model

The next cell shows the encoding and decoding for several instances of the dataset.
* first row: input
* second row: encoded input at the bottleneck, its 10 values shown as a strip of gray levels
* third row: output, reconstructing the input

Note how well the input is reconstructed, even though the bottleneck holds only 10 values, compared with the 784 (28×28) pixels of the input.

In [None]:
model.eval()
(fig,ax) = plt.subplots(3, 10, constrained_layout=True)
ds = full_dataset
for j in range(10):
    x, _ = ds[j]
    x = x[None,:].to(device)
    with torch.no_grad():
        z = model.encode(x)
        y = model.decode(z)
    
    ax[0,j].imshow(ds[j][0][0], cmap="gray")
    ax[0,j].set_axis_off()
    ax[1,j].imshow(z[None,0].cpu(), cmap="gray")
    ax[1,j].set_axis_off()
    ax[2,j].imshow(y[0][0].cpu(), cmap="gray")
    ax[2,j].set_axis_off()
    
plt.axis('off')
plt.show()
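To put a number on how well each image is reconstructed, one can compute a per-image mean squared error instead of a single batch average, for example to find the worst-reconstructed images. The sketch below uses random stand-in tensors for `x` (inputs) and `y` (reconstructions); with the notebook's model, `y` would be `model.decode(model.encode(x))`. It also prints the input-to-bottleneck size ratio for a 10-value bottleneck.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(10, 1, 28, 28)                      # stand-in for 10 input images
y = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)   # stand-in for their reconstructions

# Per-image mean squared error: average over channel, height and width only
per_image = F.mse_loss(y, x, reduction="none").mean(dim=(1, 2, 3))
worst = per_image.argmax().item()
print(per_image.shape, worst)

# How much smaller the bottleneck is than the input
print(f"compression ratio: {28 * 28 / 10:.1f}x")   # compression ratio: 78.4x
```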

An autoencoder can be seen as a rudimentary __generative model__, in the sense that the decoder generates instances similar to the ones in the training set.

To demonstrate this feature, we next show the outputs for randomly generated values at the bottleneck. Although the output images are not quite handwritten digits, they somewhat resemble them.

The first row shows the values at the bottleneck, while the second row shows the corresponding generated outputs.

In [None]:
model.eval()
(fig,ax) = plt.subplots(2, 10, constrained_layout=True)
for j in range(10):
    z = torch.randn(model.bottleneck)  # random bottleneck values from a standard normal distribution
    z = z[None,:].to(device)

    with torch.no_grad():
        y = model.decode(z)
    
    ax[0,j].imshow(z[None,0].cpu(), cmap="gray")
    ax[0,j].set_axis_off()
    ax[1,j].imshow(y[0][0].cpu(), cmap="gray")
    ax[1,j].set_axis_off()
    
plt.axis('off')
plt.show()
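Note that `torch.randn` draws bottleneck values from a standard normal distribution, centred at zero with unit spread in every direction, and nothing in the training forces the encoded values to follow that distribution. A possible alternative, not used in the original notebook, is to sample from a Gaussian fitted to the encoded training set, using the same kind of mean `u` and covariance `S` that are computed further below. The sketch uses synthetic stand-in codes and `torch.distributions.MultivariateNormal`.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)

# Stand-in for encoded training images: 5000 codes of size 10 with a non-zero
# mean and correlated, unequal spreads (synthetic, not real MNIST codes)
mix = torch.randn(10, 10, dtype=torch.float64)
codes = torch.randn(5000, 10, dtype=torch.float64) @ mix + 3.0

u = codes.mean(0)
S = torch.cov(codes.T)
S = (S + S.T) / 2   # remove rounding asymmetries, which the argument check can reject

# Gaussian fitted to the codes, instead of the standard normal N(0, I)
gaussian = MultivariateNormal(loc=u, covariance_matrix=S)
z = gaussian.sample((10,))   # 10 codes; in the notebook: model.decode(z.float().to(device))
print(z.shape)               # torch.Size([10, 10])
```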

Now let's play a little bit with the decoder.

First we compute the mean and covariance of the encoded values at the bottleneck, over the dataset. Intuitively, the mean is the "center of mass" of the encoded values, while the covariance represents the dispersion of those values.

Then we perform the spectral decomposition (i.e., compute the eigenvalues and eigenvectors) of the covariance matrix. This is effectively the same as the Principal Component Analysis (PCA) of the encoded values. The principal components are a set of orthogonal directions, and the corresponding eigenvalues give the variance of the data along each of them. This identifies the main directions along which the data varies around its "center of mass".

In [None]:
model.eval()
with torch.no_grad():
    # Encode the whole training set batch by batch, to limit memory use
    out = torch.cat([model.encode(X.to(device)).cpu() for X, _ in dataloader])

# determine mean and covariance of encoded data
u = torch.mean(out,0)
S = torch.cov(out.T)

# perform spectral decomposition of the covariance matrix, i.e., Principal Component Analysis (PCA)
(L,Q)=torch.linalg.eigh(S)
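The sketch below checks, on synthetic stand-in codes in double precision, the properties of `torch.linalg.eigh` that the next cells rely on: eigenvalues come out in ascending order, S = Q diag(L) Q^T with orthonormal eigenvectors, and projecting the centred codes onto the columns of `Q` gives coordinates whose variances are the eigenvalues.

```python
import torch

torch.manual_seed(0)
# Stand-in for the encoded data: 2000 correlated codes of size 10, in double precision
codes = torch.randn(2000, 10, dtype=torch.float64) @ torch.randn(10, 10, dtype=torch.float64)

u = codes.mean(0)
S = torch.cov(codes.T)   # rows of codes.T are variables, columns are observations
L, Q = torch.linalg.eigh(S)

# 1. Eigenvalues come out in ascending order, so L[-1] is the largest variance
print(bool(torch.all(L[1:] >= L[:-1])))   # True

# 2. Spectral decomposition: S = Q diag(L) Q^T, with orthonormal columns in Q
print(torch.allclose(Q @ torch.diag(L) @ Q.T, S))
print(torch.allclose(Q.T @ Q, torch.eye(10, dtype=torch.float64)))

# 3. Projecting the centred codes onto the eigenvectors gives coordinates
#    whose variances are the eigenvalues
proj = (codes - u) @ Q
print(torch.allclose(proj.var(0), L))
```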

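Here we take the two principal components with the largest dispersion (`torch.linalg.eigh` returns eigenvalues in ascending order, so these are the last two) and decode points on an 11×11 grid, stepping from -5 to +5 standard deviations along each: the second component varies down the rows and the first across the columns.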

In [None]:
model.eval()
(fig,ax) = plt.subplots(11, 11, constrained_layout=True)
for i in range(11):
    for j in range(11):
        z = u + (i-5)*torch.sqrt(L[-2])*Q[:,-2] \
              + (j-5)*torch.sqrt(L[-1])*Q[:,-1]
        z = z[None,:].to(device)
    
        with torch.no_grad():
            y = model.decode(z)
        
        ax[i,j].imshow(y[0][0].cpu(), cmap="gray")
        ax[i,j].set_axis_off()
    
plt.axis('off')
plt.show()
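The double loop above decodes one code at a time. Since the grid points are z = u + a·sqrt(L[-2])·Q[:,-2] + b·sqrt(L[-1])·Q[:,-1] for a, b in -5..5, they can also be built in one go with broadcasting and decoded as a single batch. A sketch, with `pc_grid` as a hypothetical helper and random stand-in statistics:

```python
import torch

def pc_grid(u, L, Q, steps=range(-5, 6)):
    """Grid of codes u + a*sqrt(L[-2])*Q[:, -2] + b*sqrt(L[-1])*Q[:, -1].

    Returns shape (len(steps), len(steps), d): the first index a runs down
    the rows, the second index b across the columns, as in the loop above.
    """
    k = torch.tensor(list(steps), dtype=u.dtype)
    row_dir = torch.sqrt(L[-2]) * Q[:, -2]   # one standard deviation along the 2nd component
    col_dir = torch.sqrt(L[-1]) * Q[:, -1]   # one standard deviation along the 1st component
    return (u
            + k[:, None, None] * row_dir[None, None, :]
            + k[None, :, None] * col_dir[None, None, :])

# Stand-in statistics for a 10-value bottleneck
torch.manual_seed(0)
codes = torch.randn(1000, 10) @ torch.randn(10, 10)
u = codes.mean(0)
L, Q = torch.linalg.eigh(torch.cov(codes.T))

grid = pc_grid(u, L, Q)
batch = grid.reshape(-1, grid.shape[-1])   # 121 codes, in row-major grid order
print(grid.shape, batch.shape)             # torch.Size([11, 11, 10]) torch.Size([121, 10])
```

With the notebook's model, `model.decode(batch.to(device))` should return 121 images of shape (1, 28, 28), in the same row-major order as the subplot grid.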

Now let's move along each of the 10 principal components, one at a time (one per row, from the largest variance to the smallest), and see what happens.

In [None]:
# generate some images along each of the 10 principal components, largest variance first

model.eval()
(fig,ax) = plt.subplots(10, 11, constrained_layout=True)
for i in range(10):
    for j in range(11):
        z = u + (j-5)*torch.sqrt(L[-1-i])*Q[:,-1-i]
        z = z[None,:].to(device)
    
        with torch.no_grad():
            y = model.decode(z)
        
        ax[i,j].imshow(y[0][0].cpu(), cmap="gray")
        ax[i,j].set_axis_off()
    
plt.axis('off')
plt.show()
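Stepping by `sqrt(L[i])` along the eigenvector `Q[:, i]` means every row of the figure covers the same range measured in standard deviations, whatever the component. In terms of the Mahalanobis distance sqrt((z - u)^T S^-1 (z - u)), the code u + k·sqrt(λᵢ)·qᵢ lies exactly k standard deviations from the mean, because S^-1 qᵢ = qᵢ/λᵢ. The sketch below verifies this numerically on stand-in data; `mahalanobis` is a hypothetical helper.

```python
import torch

torch.manual_seed(0)
# Stand-in statistics for a 10-value bottleneck (double precision for a tight check)
codes = torch.randn(1000, 10, dtype=torch.float64) @ torch.randn(10, 10, dtype=torch.float64)
u = codes.mean(0)
S = torch.cov(codes.T)
L, Q = torch.linalg.eigh(S)

def mahalanobis(z, u, S):
    """Mahalanobis distance sqrt((z - u)^T S^-1 (z - u)) of a single code z."""
    d = z - u
    return torch.sqrt(d @ torch.linalg.solve(S, d))

# Codes as in the traversal above: k standard deviations along component i
for i in range(3):
    for k in (1, 3, 5):
        z = u + k * torch.sqrt(L[-1 - i]) * Q[:, -1 - i]
        print(f"component {i}, step {k}: distance {mahalanobis(z, u, S).item():.6f}")
```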