## Autoencoder on MNIST dataset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import numpy as np

### prepare (download) the MNIST dataset

Torchvision provides many datasets. Here we work with MNIST, a dataset of images of digits 0-9. 
 
See:
 
 1. Datasets:  https://pytorch.org/vision/main/datasets.html
 2. Transform:    https://pytorch.org/vision/0.9/transforms.html

In [None]:
# Preprocessing pipeline: ToTensor converts each image to a float tensor
# (torchvision scales the pixels to [0, 1]); Normalize with mean 0 and
# std 1 is an identity operation, so the data is effectively NOT rescaled.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.,), std=(1,)),
    ]
)

# arguments shared by the training and the test split
mnist_kwargs = dict(root='./data', transform=transform, download=True)

# training split (downloaded to ./data if not present yet)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)

# test split
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# names of the classes
print("\nLabels:", train_dataset.classes)

# digit (target) of every training image
print("\nClasses:", train_dataset.targets)

# shape of the raw data tensor: training split first, then test split
for dataset in (train_dataset, test_dataset):
    print("\nData shape:", dataset.data.shape)

### visualize images according to labels

See:
1. torch.where: 
    https://docs.pytorch.org/docs/stable/generated/torch.where.html    

In [None]:
from torch.utils.data import Subset

# how many sample images to draw per digit
num_examples = 8

# one row of images per digit 0..9
for label in range(10):

    # positions of all training images whose target equals this digit
    label_mask = train_dataset.targets == label
    indices = torch.where(label_mask)[0]

    # restrict the training set to those images and shuffle them
    label_dataset = Subset(train_dataset, indices)
    label_loader = DataLoader(
        label_dataset, batch_size=num_examples, shuffle=True
    )

    # only the first (random) batch is needed
    img, labels = next(iter(label_loader))

    # sanity check: every printed label should equal the current digit
    print(labels)

    plt.figure(figsize=(5, 2))
    for position in range(1, num_examples + 1):
        plt.subplot(1, num_examples, position)
        picture = img[position - 1].reshape(28, 28).numpy()
        plt.imshow(picture, cmap='gray')
        plt.axis('off')

### Define the autoencoder model


In [None]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps a vector of length 28*28 through a 128-unit hidden
    layer to ``encoding_dim`` values; the decoder mirrors this path back
    to 28*28 values and ends in a sigmoid, so every output lies in (0, 1).

    Args:
        encoding_dim: size of the latent (bottleneck) representation.
    """

    def __init__(self, encoding_dim):
        super().__init__()
        input_dim = 28 * 28
        hidden_dim = 128

        # encoder: input -> hidden -> latent code
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, encoding_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # decoder: latent code -> hidden -> input-sized output
        decoder_layers = [
            nn.Linear(encoding_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim),
            # the sigmoid keeps the reconstruction in the same range as the data
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent space and decode it again."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
# Function for testing the autoencoder

def test_autoencoder(model, loader=None):
    """Plot reconstructions and latent embeddings for one batch of test data.

    Takes the first batch of the loader, flattens every image to a 1D
    vector, and draws two figures: the first (up to) 10 originals above
    their reconstructions, and a scatter plot of the first two latent
    coordinates of every image in the batch, one colour per digit.
    The shape of the embedding array is printed as well.

    Args:
        model: autoencoder exposing an ``encoder`` sub-module; it is called
            on the flattened batch.
        loader: iterable yielding ``(images, targets)`` batches. When
            omitted, the notebook-level ``test_loader`` is used.
    """
    if loader is None:
        # default: the loader defined at notebook level
        loader = test_loader

    # only the first batch is used
    img, target = next(iter(loader))

    # change images of size 28x28 to 1d vectors
    img = img.view(img.size(0), -1)

    # evaluation only: run both forward passes without building autograd graphs
    with torch.no_grad():
        output = model(img).cpu().numpy()
        z = model.encoder(img).cpu().numpy()

    # visualize some examples (never more than the batch contains)
    num_examples = min(10, img.size(0))
    plt.figure(figsize=(20, 4))
    for i in range(num_examples):
        # original image
        plt.subplot(2, num_examples, i + 1)
        plt.imshow(img[i].cpu().numpy().reshape(28, 28), cmap='gray')
        plt.title("Original")
        plt.axis('off')

        # reconstructed image
        plt.subplot(2, num_examples, i + 1 + num_examples)
        plt.imshow(output[i].reshape(28, 28), cmap='gray')
        plt.title("Reconstructed")
        plt.axis('off')

    # visualize embeddings: first two latent coordinates of the whole batch
    plt.figure(figsize=(8, 8))
    print(z.shape)
    target = target.cpu().numpy()
    for digit in range(10):
        mask = target == digit
        plt.scatter(z[mask, 0], z[mask, 1], s=0.8, label=str(digit))
    # build the legend once, after all ten digit groups have been drawn
    plt.legend()
    plt.show()

### define the model

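First, train the model with `encoding_dim=2`. Then set `encoding_dim=10`, re-create the model, and rerun the training to see whether the reconstruction improves. Note that the embedding scatter plot always shows only the first two latent dimensions.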

In [None]:
# size of the bottleneck (latent) representation produced by the encoder
encoding_dim = 2

# build a fresh, untrained autoencoder with that latent size
model = Autoencoder(encoding_dim=encoding_dim)

### Let's test the model before training

In [None]:

# Unshuffled loader over the test split, so the first batch is always the
# same images; batch_size=10000 is meant to cover the whole MNIST test
# split (10,000 images) in a single batch.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=10000,
    shuffle=False,
)

# reconstructions and embeddings of the model before any training
test_autoencoder(model)

### train autoencoder with reconstruction loss

Mean Squared Error (MSE) loss, see:
   https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html

In [None]:
# reconstruction loss: mean squared error between output and input pixels
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=0.005)

num_epochs = 10

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# make sure the model is in training mode
model.train()

for epoch in range(num_epochs):
    # accumulate the loss over the epoch; the last mini-batch alone is a
    # noisy estimate and may be smaller than the other batches
    running_loss = 0.0
    num_seen = 0

    # the labels are not needed for training an autoencoder
    for img, _ in train_loader:

        # flatten each image to a 1D vector of length 28*28
        img = img.view(img.size(0), -1)

        # Forward pass: the target of the reconstruction is the input itself
        output = model(img)
        loss = criterion(output, img)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # weight each batch by its size so the epoch mean is per sample
        batch_size = img.size(0)
        running_loss += loss.item() * batch_size
        num_seen += batch_size

    # mean reconstruction loss over all samples seen in this epoch
    epoch_loss = running_loss / num_seen
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

### test the model again after training

In [None]:
# re-run the same visual check, now with the trained model
test_autoencoder(model)