# Exercise: Transfer learning using MobileNetV3
Transfer learning is a powerful technique that leverages pre-trained models and applies them to new tasks. This approach allows us to save time and computational resources by using the knowledge gained from training on large datasets.

This exercise will use MobileNetV3, a convolutional neural network architecture for mobile devices, to train a classifier for the Fashion-MNIST dataset using the PyTorch framework.

Fashion-MNIST is a drop-in replacement for MNIST (images of size 28x28 with 10 classes), but instead of digits it contains tiny images of clothes!

Task:
- Load the Fashion-MNIST dataset using the torchvision package.
- Define a PyTorch model using the MobileNetV3 architecture.
- Train the model on the Fashion-MNIST dataset.
- Evaluate the model on the test set.

# Step 1: Load the Fashion-MNIST dataset

In [None]:
# Load the Fashion-MNIST dataset

import torch
# may require pip install torchvision
import torchvision.datasets as datasets 
import torchvision.transforms as transforms

def load_data(batch_size, data_dir="data"):
    """Load the Fashion-MNIST train and test splits as DataLoaders.

    The dataset is downloaded into ``data_dir`` when it is not already there.

    Args:
        batch_size: Number of images per batch, used for both loaders.
        data_dir: Directory in which the dataset is stored/downloaded.

    Returns:
        A ``(trainloader, testloader)`` tuple of
        ``torch.utils.data.DataLoader`` objects.
    """
    # ToTensor converts the PIL image to a tensor with values in [0, 1];
    # Normalize with mean 0.5 and std 0.5 then maps those values to [-1, 1].
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    # Download and load the training data (shuffled every epoch)
    trainset = datasets.FashionMNIST(
        data_dir, download=True, train=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True
    )

    # Download and load the test data. Evaluation does not benefit from
    # shuffling, and a fixed order keeps the results reproducible.
    testset = datasets.FashionMNIST(
        data_dir, download=True, train=False, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False
    )

    return trainloader, testloader

# Build both loaders once with a batch size of 64; the cells below reuse them.
trainloader, testloader = load_data(64)

Sometimes it's useful to create functions that will help us work with the labels when they're more complicated than the handwritten digits 0-9. Let's write those now.

In [None]:
# Define some helper functions to help with the labels
def get_class_names():
    """Return the Fashion-MNIST class names as a list ordered by class index."""
    names = (
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    )
    # Hand back a fresh list on every call, as callers may index or search it
    return list(names)

def get_class_name(class_index):
    """Look up the human-readable class name that belongs to ``class_index``."""
    names = get_class_names()
    return names[class_index]

def get_class_index_for_name(class_name):
    """Look up the numeric class index that belongs to ``class_name``.

    Raises ``ValueError`` (from ``list.index``) when the name is unknown.
    """
    names = get_class_names()
    position = names.index(class_name)
    return position

# Print every class index next to its name. Enumerating the list itself
# avoids hard-coding the number of classes and rebuilding the list per lookup.
for class_index, class_name in enumerate(get_class_names()):
    print(f"class_index: {class_index}, class_name: {class_name}")

In [None]:
# show 10 images from the training set with their labels

import matplotlib.pyplot as plt
import numpy as np

# Function to show an image
def show_img(img):
    """Draw a normalized image tensor on the current matplotlib axes.

    Args:
        img: CPU image tensor laid out as (channels, height, width) whose
            values were normalized with mean 0.5 and std 0.5.
    """
    # Invert the normalization: x = x_norm * std + mean = x_norm / 2 + 0.5
    img = img / 2 + 0.5
    npimg = img.numpy()  # Convert from tensor to numpy array
    # matplotlib expects (height, width, channels)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Pull a single batch out of the training loader
images, labels = next(iter(trainloader))

# Lay the first `plot_size` images out on a two-row grid and title each one
# with the name of its class
plot_size = 10
fig = plt.figure(figsize=(15, 4))

for idx in range(plot_size):
    ax = fig.add_subplot(2, plot_size // 2, idx + 1, xticks=[], yticks=[])
    show_img(images[idx])
    ax.set_title(get_class_name(int(labels[idx])))

# Render the finished figure
plt.show()

# Step 2: Define a PyTorch model using the MobileNetV3 architecture

The `torchvision.models` module provides access to pretrained MobileNetV3 models (for example `mobilenet_v3_small` and `mobilenet_v3_large`). We can use such a model and replace the final layer with a fully-connected layer with 10 outputs, since we have 10 classes. We can then freeze the weights of the convolutional layers and train only the new fully-connected layer.

Start with inspecting the original MobileNetV3 (small version) first:

In [None]:
# Load a pre-trained MobileNetV3 and inspect its structure
import torchvision.models as models

# Load the MobileNetV3 (small) model with its ImageNet-pretrained weights.
# The `weights=` argument replaces the deprecated `pretrained=True` flag;
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to select.
mobilenet_v3_small = models.mobilenet_v3_small(
    weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
)
print(mobilenet_v3_small)

Note the `classifier` section of the model
(classifier): Sequential(
    (0): Linear(in_features=576, out_features=1024, bias=True)
    (1): Hardswish()
    (2): Dropout(p=0.2, inplace=True)
    (3): Linear(in_features=1024, **out_features=1000**, bias=True)
  )

  There are 1000 output features, but our dataset has only 10 classes. We need to replace the final layer so the model has the correct number of output nodes for our dataset.

In [None]:
import torch.nn.functional as F
import torchvision.models as models
from torch import nn

# Define a model class that extends the nn.Module class
class MobileNetV3(nn.Module):
    """Pre-trained MobileNetV3 (small) with a new classification head.

    The final fully-connected layer of the pre-trained network is replaced by
    a freshly initialized one with ``num_classes`` outputs. Every other
    weight starts frozen, so only the new layer is trained until
    ``unfreeze()`` is called.

    Args:
        num_classes: Number of outputs of the new final layer.
    """

    # Input resolution the network is fed with (the standard ImageNet size).
    INPUT_SIZE = (224, 224)

    def __init__(self, num_classes=10):
        super().__init__()

        # Load the pre-trained MobileNetV3 (small) architecture.
        # IMAGENET1K_V1 is the checkpoint formerly selected by `pretrained=True`.
        self.model = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
        )

        # Replace the last fully-connected layer with a new one of the right
        # size, reading the input width from the layer being replaced.
        old_head = self.model.classifier[3]
        self.model.classifier[3] = nn.Linear(old_head.in_features, num_classes)

        # Freeze all the weights of the network except for the new layer
        self.freeze()

    def forward(self, x):
        """Run a batch of images through the network.

        Args:
            x: Batch laid out as (batch, channels, height, width) with either
                1 (grayscale) or 3 channels.

        Returns:
            The raw class scores produced by the wrapped model.
        """
        # Replicate a single grayscale channel three times to get the
        # 3-channel input the network expects; 3-channel input passes through.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        # Resize the input to the resolution the network is fed with
        if tuple(x.shape[2:]) != self.INPUT_SIZE:
            x = F.interpolate(
                x, size=self.INPUT_SIZE, mode="bilinear", align_corners=False
            )

        # Forward pass
        return self.model(x)

    def freeze(self):
        """Make only the final fully-connected layer trainable."""
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the final layer
        for param in self.model.classifier[3].parameters():
            param.requires_grad = True

    def unfreeze(self):
        """Make every weight of the network trainable."""
        for param in self.model.parameters():
            param.requires_grad = True

# Create an instance of the MobileNetV3 model
model = MobileNetV3()
print(model)

# Step 3: Train the model on the Fashion-MNIST dataset

We can train the model using the standard PyTorch training loop. For the loss function, we'll use CrossEntropyLoss. We also use the Adam optimizer with a learning rate of 0.002. We train the model for 1 epoch to see how the model performs after just one pass over the training data.

In [None]:
import torch
import torch.nn as nn

# Adam (learning rate 0.002) drives the weight updates and cross-entropy is
# the classification objective being minimized
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.002)
loss_fn = nn.CrossEntropyLoss()

Now we choose our device automatically (CPU, GPU, or MPS) and write the training loop.
The MPS backend is for Apple-silicon Macs (M1, M2, etc.).

If you're having trouble running the code locally, you can try using the `cpu` mode manually, ie: `device = torch.device("cpu")`

In [None]:
# Set the device as GPU, MPS, or CPU according to availability

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
# `torch.backends.mps` does not exist in older PyTorch releases, so it is
# looked up defensively instead of being accessed directly.
mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

In [None]:
# Create a PyTorch training loop
model = model.to(device)  # Move the model weights to the device

# Explicitly enable training mode (dropout, batch-norm statistics) so this
# cell behaves the same when it is re-run after the model was put in
# evaluation mode.
model.train()

epochs = 1

for epoch in range(epochs):
    for batch_num, (images, labels) in enumerate(trainloader):
        # Move the tensors to the device
        images = images.to(device)
        labels = labels.to(device)

        # Zero out the optimizer's gradient buffer
        optimizer.zero_grad()

        # Forward pass: raw class scores for every image in the batch
        logits = model(images)

        # Calculate the loss and perform back propagation
        loss = loss_fn(logits, labels)
        loss.backward()

        # Update the weights
        optimizer.step()

        # Print the loss for every 100th iteration
        if batch_num % 100 == 0:
            print(
                "Epoch [{}, {}], Batch [{},{}], Loss: {:.4f}".format(
                    epoch + 1, epochs, batch_num + 1, len(trainloader), loss.item()
                )
            )

# Step 4: Evaluate the model on the test set

We evaluate the model by printing the accuracy and plotting a few examples of correct and incorrect predictions

In [None]:
correct = 0
total = 0
loss = 0.0

# Evaluation mode turns off dropout and makes batch norm use its running
# statistics, so the predictions are deterministic.
model.eval()

# Gradients are not needed for evaluation; disabling them saves memory and
# keeps the accumulated loss from holding on to autograd graphs.
with torch.no_grad():
    for images, labels in testloader:
        # Move the tensors to the device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)

        # loss_fn returns the mean over the batch; weight it by the batch
        # size so the final figure is a true per-image average.
        num_samples = labels.size(0)
        loss += loss_fn(outputs, labels).item() * num_samples

        # The predicted class is the index of the highest score
        predicted = outputs.argmax(dim=1)

        # Compute the accuracy
        total += num_samples
        correct += (predicted == labels).sum().item()

# Average loss per test image
loss /= total

print(
    "Test Accuracy of the model on the test images: {} %".format(100 * correct / total)
)
print("Test Loss of the model on the test images: {}".format(loss))


In [None]:
# Plot a few examples of correct and incorrect predictions

import matplotlib.pyplot as plt
import numpy as np

# Get the first batch of images and labels
images, labels = next(iter(testloader))

# Move the tensors to the configured device
images = images.to(device)
labels = labels.to(device)

# Predict in evaluation mode and without gradient tracking, so dropout does
# not randomize the predictions shown below.
model.eval()
with torch.no_grad():
    outputs = model(images)
predicted = outputs.argmax(dim=1)

# Copy the images back to the CPU for plotting
cpu_images = images.cpu()

# Plot the images with labels, at most 10
fig = plt.figure(figsize=(15, 4))

for idx in range(min(10, len(cpu_images))):
    ax = fig.add_subplot(2, 10 // 2, idx + 1, xticks=[], yticks=[])
    # Drop the channel dimension so the image is a plain 2-D array
    ax.imshow(cpu_images[idx].squeeze().numpy())

    # Plain Python ints are used to index the list of class names
    predicted_index = int(predicted[idx])
    true_index = int(labels[idx])

    # Title reads "predicted (actual)": green when correct, red otherwise
    ax.set_title(
        "{} ({})".format(
            get_class_name(predicted_index), get_class_name(true_index)
        ),
        color=("green" if predicted_index == true_index else "red"),
    )

plt.show()  # show the figure