In [21]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models

# Pre-processing

Load AlexNet as pretrained model

In [22]:
# Transformation steps for input data 
AlexTransform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3), # images need to be RGB, MNIST is in greyscale. Therefore needs to be converted to RGB
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(), # Also brings the tensor values in the range [0, 1] instead of [0, 255]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Path to mnist data
mnist_data_path = os.path.join("..", "..", "data", 'mnist-data')

# Create a loader for training data and testing data
batch_size = 500
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(mnist_data_path, train=True, download=True, transform=AlexTransform),
        batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(mnist_data_path, train=False, download=True, transform=AlexTransform),
        batch_size=1, shuffle=False)

val_loader = torch.utils.data.DataLoader(
        datasets.MNIST(mnist_data_path, train=False, download=True, transform=AlexTransform),
        batch_size=batch_size, shuffle=False)

In [23]:
# Check Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [24]:
net = models.alexnet(weights=models.AlexNet_Weights.DEFAULT).to(device)

print(net)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [25]:
# Change last layer to reflect MNIST categories
net_fc = net.classifier[-1] 
num_ftrs = net_fc.in_features # Get the amount of input features for the last layer

# Update to 10 possible outputs
net.classifier[-1] = nn.Linear(num_ftrs, 10)

# Finetuning AlexNet to MNIST

## Freezing all layers except fully connected

In [26]:
# Freeze all layers first
for param in net.parameters():
    param.requires_grad=False

# Unfreeze all fully connected layers
for layer in net.classifier:
    if isinstance(layer, torch.nn.Linear):  # Check if the layer is a fully connected layer
        for param in layer.parameters():
            param.requires_grad = True  # Unfreeze the fully connected layer

# Check the requires_grad status of each layer
for name, param in net.named_parameters():
    print(f"{name}: requires_grad = {param.requires_grad}")

features.0.weight: requires_grad = False
features.0.bias: requires_grad = False
features.3.weight: requires_grad = False
features.3.bias: requires_grad = False
features.6.weight: requires_grad = False
features.6.bias: requires_grad = False
features.8.weight: requires_grad = False
features.8.bias: requires_grad = False
features.10.weight: requires_grad = False
features.10.bias: requires_grad = False
classifier.1.weight: requires_grad = True
classifier.1.bias: requires_grad = True
classifier.4.weight: requires_grad = True
classifier.4.bias: requires_grad = True
classifier.6.weight: requires_grad = True
classifier.6.bias: requires_grad = True


## Training AlexNet on MNIST

In [27]:
# Define loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-3)


In [28]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 10 == 0:
            loss, current = loss.item(), batch * batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [29]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, net.to(device), loss_fn, optimizer)
    test_loop(val_loader, net.to(device), loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.533980  [  500/60000]
loss: 2.057190  [ 5500/60000]
loss: 1.804513  [10500/60000]
loss: 1.598496  [15500/60000]
loss: 1.446943  [20500/60000]
loss: 1.383569  [25500/60000]
loss: 1.258295  [30500/60000]
loss: 1.241185  [35500/60000]
loss: 1.121786  [40500/60000]
loss: 1.019857  [45500/60000]
loss: 1.072538  [50500/60000]
loss: 0.974331  [55500/60000]
Test Error: 
 Accuracy: 83.8%, Avg loss: 0.816800 

Epoch 2
-------------------------------
loss: 0.984362  [  500/60000]
loss: 0.937802  [ 5500/60000]
loss: 0.951671  [10500/60000]


KeyboardInterrupt: 

# Save the model

In [23]:
torch.save(net.state_dict(), os.path.join("data", "AlexNet",'AlexNet_finetuned_MNIST.pth'))