# Multi-Layer Perceptron on MNIST

In this unit practical, we will build and evaluate a Multi-Layer Perceptron (MLP) that classifies images from the [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digit dataset.

The notebook is organised into the following steps:
1. **Data Loading and Visualization**
2. **Define the Network Architecture**
3. **Specify Loss Function and Optimizer**
4. **Train the Network**
5. **Evaluate the Model**
6. **Visualize Predictions and Learned Features (t-SNE)**

Let's begin by importing the necessary libraries.

In [None]:
# Imports
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
%matplotlib inline

## Step 1: Data Loading and Visualization

We load the MNIST dataset (both training and test sets) and create data loaders. You can change the `batch_size` or `num_workers` as needed.

In [None]:
from torchvision import datasets
import torchvision.transforms as transforms

# Parameters for data loader
num_workers = 0
batch_size = 16

# Define transformation
transform = transforms.ToTensor()

# Load datasets
train_data = datasets.MNIST(root='DATA', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='DATA', train=False, download=True, transform=transform)

# Create dataloaders
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)  # shuffle training batches each epoch
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)

### Visualize a Batch of Training Data

It is always a good idea to inspect your data. The following cell visualizes a single batch of training images along with their corresponding labels.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Get one batch of training images
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.numpy()

# Plot the batch of training images with labels
fig = plt.figure(figsize=(25, 4))
for idx in np.arange(batch_size):
    ax = fig.add_subplot(2, batch_size//2, idx+1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(images[idx]), cmap='gray')
    ax.set_title(str(labels[idx].item()))

## Step 2: Define the Network Architecture

We create an MLP with one hidden layer. The network takes a 784-dimensional flattened image as input and outputs a 10-dimensional tensor, representing class scores for each digit. A sigmoid activation is used after the first fully connected layer.

In [None]:
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # First fully connected layer: from 784 (28x28) to 512 neurons
        self.fc1 = [TODO]
        # Second fully connected layer: from 512 to 10 neurons (one per class)
        self.fc2 = [TODO]

    def forward(self, x):
        # Flatten the image input
        x = x.view(-1, 28 * 28)
        # Apply the first FC layer and sigmoid activation
        x = [TODO]
        
        # Apply the second FC layer (output layer)
        x = [TODO]
        
        return x

# Initialize the network and print its architecture
model = Net()
print(model)
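
If you get stuck on the placeholders, the sketch below shows one possible way to fill them in, based only on the comments in the cell above (784 to 512, sigmoid, 512 to 10). The names `SketchNet`, `sketch_model` and `dummy_batch` are hypothetical and chosen so that they do not overwrite your own `Net` and `model`; treat this as an illustration rather than the reference solution.

```python
import torch
import torch.nn as nn

class SketchNet(nn.Module):
    """One possible completion of the exercise (hypothetical, for illustration)."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 512)  # 784 inputs -> 512 hidden units
        self.fc2 = nn.Linear(512, 10)       # 512 hidden units -> 10 class scores

    def forward(self, x):
        x = x.view(-1, 28 * 28)             # flatten each image to a 784-vector
        x = torch.sigmoid(self.fc1(x))      # hidden layer followed by sigmoid
        return self.fc2(x)                  # raw class scores (logits), no softmax

sketch_model = SketchNet()
dummy_batch = torch.rand(4, 1, 28, 28)      # synthetic stand-in for four MNIST images
logits = sketch_model(dummy_batch)
print(logits.shape)                         # torch.Size([4, 10])
```

The output layer deliberately returns raw scores: the loss function chosen in the next step expects logits rather than probabilities.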

## Step 3: Specify Loss Function and Optimizer

We use the cross-entropy loss for this classification task. The optimizer is set to SGD with a learning rate of 0.01.
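
As a small aside, `nn.CrossEntropyLoss` works directly on raw class scores: it applies a log-softmax internally and then takes the negative log-probability of the correct class, averaged over the batch. The sketch below checks this on two made-up examples; the tensors `logits` and `targets` are invented for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up scores for 2 samples and 3 classes, with their true class indices
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

# Built-in loss on raw logits
loss_builtin = nn.CrossEntropyLoss()(logits, targets)

# Manual version: log-softmax, pick the log-probability of the true class, average
log_probs = F.log_softmax(logits, dim=1)
loss_manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(loss_builtin.item(), loss_manual.item())
```

Both numbers should agree, which is why the network does not need a softmax layer of its own.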

In [None]:
# Specify loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

## Step 4: Train the Network

We now train the network. The following loop runs for a specified number of epochs (adjustable as needed). For each batch, we clear the gradients, perform a forward pass, compute the loss, backpropagate the error, and update the model parameters. The average training loss is printed at the end of each epoch.

In [None]:
# Number of epochs for training
n_epochs = 5  # For a faster demo; consider using between 20 and 50 epochs for real training

model.train()  # Set the model to training mode

for epoch in range(n_epochs):
    train_loss = 0.0
    
    # Train the model on each batch
    for data, target in train_loader:
        optimizer.zero_grad()         # Clear gradients
        output = model(data)            # Forward pass
        loss = criterion(output, target)  # Calculate loss
        loss.backward()                 # Backward pass
        optimizer.step()                # Update parameters
        train_loss += loss.item() * data.size(0)  # Accumulate loss
    
    # Calculate average loss over the epoch
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch+1, train_loss))
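
The loop multiplies each batch loss by `data.size(0)` before summing because `criterion` returns the mean loss over a batch. Weighting by batch size gives the true per-sample average even when the last batch is smaller. The numbers below are hypothetical and only illustrate the difference from a plain average over batches.

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller)
batch_losses = [0.90, 0.70, 0.40]
batch_sizes = [16, 16, 4]

# Weighted sum, as in the training loop: mean loss * number of samples in the batch
weighted_sum = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
per_sample_mean = weighted_sum / sum(batch_sizes)

# Naive alternative: average of the batch means, which over-weights the small batch
naive_mean = sum(batch_losses) / len(batch_losses)

print(per_sample_mean, naive_mean)
```

With equal batch sizes the two values coincide; they only differ when the dataset size is not a multiple of `batch_size`.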

## Step 5: Evaluate the Trained Network

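After training, we evaluate the network on the test data. The test loop calculates the average test loss and computes per-class as well as overall accuracy. `model.eval()` switches layers such as dropout and batch normalization to their inference behaviour; this network contains neither, so the call has no effect here, but it is a good habit. `torch.no_grad()` additionally turns off gradient tracking, which saves memory and time.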

In [None]:
# Initialize tracking variables for test loss and accuracy per class
test_loss = 0.0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

model.eval()  # Set model to evaluation mode
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)            # Forward pass
        loss = criterion(output, target)  # Calculate loss
        test_loss += loss.item() * data.size(0)
        
        _, pred = torch.max(output, 1)    # Get predictions
        correct = pred.eq(target)         # True where the prediction matches the label
        for i in range(target.size(0)):   # Actual batch size; the last batch may be smaller
            label = target[i].item()
            class_correct[label] += correct[i].item()
            class_total[label] += 1

# Calculate average loss
test_loss = test_loss / len(test_loader.dataset)
print('Test Loss: {:.6f}\n'.format(test_loss))

# Print test accuracy for each class
for i in range(10):
    if class_total[i] > 0:
        print('Test Accuracy of {:5s}: {:2d}% ({:2d}/{:2d})'.format(
            str(i),
            int(100 * class_correct[i] / class_total[i]),
            int(class_correct[i]), int(class_total[i])))
    else:
        print('Test Accuracy of {:5s}: N/A (no test examples)'.format(str(i)))

print('\nTest Accuracy (Overall): {:2d}% ({:2d}/{:2d})'.format(
    int(100. * np.sum(class_correct) / np.sum(class_total)),
    int(np.sum(class_correct)), int(np.sum(class_total))))
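
The per-sample Python loop above is easy to follow, but the same counts can also be obtained without a loop. The sketch below uses `torch.bincount` on a handful of made-up predictions; `targets`, `preds` and `num_classes` are illustrative values, not outputs of the trained model.

```python
import torch

# Made-up labels and predictions for six samples and three classes
targets = torch.tensor([0, 1, 1, 2, 2, 2])
preds = torch.tensor([0, 1, 0, 2, 2, 1])
num_classes = 3

# Number of samples per class
class_total = torch.bincount(targets, minlength=num_classes)

# Number of correctly classified samples per class: count labels where pred == target
class_correct = torch.bincount(targets[preds == targets], minlength=num_classes)

print(class_correct.tolist(), class_total.tolist())  # [1, 1, 2] [1, 2, 3]
```

In the evaluation loop, these per-batch counts could be accumulated by adding them to running tensors of length 10.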

### Visualize Sample Test Predictions

The following cell displays a batch of test images and shows the predicted label alongside the true label. Correct predictions are highlighted in green, whereas incorrect ones are shown in red.

In [None]:
# Get a batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)

# Get predictions
output = model(images)
_, preds = torch.max(output, 1)
images = images.cpu().numpy()

# Plot the batch of test images with predicted and true labels
fig = plt.figure(figsize=(25, 4))
for idx in np.arange(batch_size):
    ax = fig.add_subplot(2, batch_size//2, idx+1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(images[idx]), cmap='gray')
    ax.set_title("{} ({})".format(str(preds[idx].item()), str(labels[idx].item())),
                 color=("green" if preds[idx]==labels[idx] else "red"))

## Step 6: t-SNE Visualization

In this section, we perform t-SNE on:
1. **Raw Pixel Inputs:** We project the flattened pixel values into 2D space.
2. **Learned Features:** We capture the output of the last layer (`fc2`, i.e. the 10 class scores) using a forward hook and then perform t-SNE on these features.

In [None]:
from sklearn.manifold import TSNE
import seaborn as sns

# Prepare the raw pixel data
X = torch.stack([img for img, _ in test_data])
X_np = X.view(-1, 28*28).numpy()
Y = torch.tensor([label for _, label in test_data]).numpy()

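# Define and apply t-SNE on raw pixel inputs
# Note: scikit-learn 1.5 renamed `n_iter` to `max_iter`; on recent versions use max_iter=300 instead
tsne = TSNE(n_components=2, perplexity=30, learning_rate='auto', init='pca', n_iter=300)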
X_embeded = tsne.fit_transform(X_np)

# Plot the t-SNE results
plt.figure(figsize=(10,8))
sns.scatterplot(x=X_embeded[:, 0], y=X_embeded[:, 1], hue=Y, palette="deep")
plt.title("t-SNE on MNIST Raw Pixels")
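
Running t-SNE on all 10,000 test images can take a while. If that becomes a problem, one common shortcut is to embed a random subset instead. The helper `subsample` below is hypothetical and is demonstrated on random arrays standing in for `X_np` and `Y`.

```python
import numpy as np

def subsample(features, labels, n_samples, seed=0):
    """Return a random subset of rows (without replacement) and the matching labels."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(features), size=n_samples, replace=False)
    return features[idx], labels[idx]

# Random stand-ins for the flattened images and their labels
fake_features = np.random.rand(1000, 784).astype(np.float32)
fake_labels = np.random.randint(0, 10, size=1000)

X_small, Y_small = subsample(fake_features, fake_labels, n_samples=200)
print(X_small.shape, Y_small.shape)  # (200, 784) (200,)
```

The returned arrays can be passed to `tsne.fit_transform` and to the `hue` argument of the scatter plot in place of the full data.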

### t-SNE on Learned Features (from fc2 layer)

We now extract features from the `fc2` layer using a hook, and then visualize these features with t-SNE.
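
Before applying this to the trained network, here is a minimal sketch of how a forward hook behaves, using a tiny throwaway model. The names `tiny`, `captured` and `save_output` are illustrative only. The hook is called after every forward pass of the layer it is attached to, and the handle returned by `register_forward_hook` lets you detach it again.

```python
import torch
import torch.nn as nn

# Tiny throwaway model: 4 inputs -> 3 hidden units -> 2 outputs
tiny = nn.Sequential(nn.Linear(4, 3), nn.Sigmoid(), nn.Linear(3, 2))
captured = []

def save_output(module, inputs, output):
    # Store a detached copy of the layer output each time the layer runs
    captured.append(output.detach().clone())

# Attach the hook to the last linear layer and run one forward pass
handle = tiny[2].register_forward_hook(save_output)
with torch.no_grad():
    out = tiny(torch.rand(5, 4))
handle.remove()  # detach the hook once the features have been collected

# A second forward pass no longer triggers the hook
with torch.no_grad():
    tiny(torch.rand(5, 4))

print(len(captured), captured[0].shape)  # 1 torch.Size([5, 2])
```

Because the hooked layer is the last one, the captured tensor is identical to the model output; hooking an earlier layer would expose hidden activations instead.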

In [None]:
extracted_features = []

def hook_function(module, input, output):
    extracted_features.append(output.clone().detach())

# Register the hook on the fc2 layer
hook = model.fc2.register_forward_hook(hook_function)

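# Run a forward pass to trigger the hook (no gradients needed for feature extraction)
model.eval()
with torch.no_grad():
    model_output = model(X)
hook.remove()  # Remove the hook so later forward passes do not keep appending
X_fc2 = extracted_features[0].numpy()  # Extracted features as a NumPy array
print("The data after fc2 has the shape: ", X_fc2.shape)

# Apply t-SNE on the fc2 features
X_embeded_fc2 = tsne.fit_transform(X_fc2)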

# Plot the t-SNE results for fc2 features
plt.figure(figsize=(10,8))
sns.scatterplot(x=X_embeded_fc2[:, 0], y=X_embeded_fc2[:, 1], hue=Y, palette="deep")
plt.title("t-SNE on MNIST Features from fc2")

The learned features are clearly better separated than the raw pixel features, especially for digits 4 and 9, which is what allows the network to classify them more reliably.