<a href="https://colab.research.google.com/github/isa-ulisboa/greends-pml/blob/main/notebooks/T10b_MNIST_resnet18_adapt_freeze_fine_tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load pre-trained model (`resnet18`) and fine-tune. Freeze layers for speed-up.

The following script shows how to load a pre-trained `resnet18` model in PyTorch and view its architecture and the names of its layers.

In [None]:
import torchvision.models as models
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
print(model)

## Changes to be made on the model's design

Note that the input layer is described as `(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)`. For the MNIST input data, there is a single channel (instead of 3). Therefore, the input layer will need to be replaced by `(conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)`.

In addition, since there are only 10 classes (0 to 9) in the MNIST classification problem, the last layer `(fc): Linear(in_features=512, out_features=1000, bias=True)` needs to be adapted to `(fc): Linear(in_features=512, out_features=10, bias=True)`.

## Pipeline: data preparation, loading and adapting the pre-trained model, and fine-tuning

Note: running the code below takes a long time, because it trains on the full MNIST training set (60,000 images), with every image resized to 224x224.

In [None]:
import torch
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim

# 1. Data preprocessing: resize to 224x224 and normalize
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 2. Load ResNet and modify layers
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 1 input channel
model.fc = nn.Linear(model.fc.in_features, 10)  # 10 output classes

# 3. Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. Training loop (simplified)
model.train()
for epoch in range(3):
    print('epoch',epoch)
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
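
To shorten experiments, one option (not part of the original script) is to train on a random subset and to move the model and the batches to a GPU when one is available. The sketch below illustrates both with a hypothetical helper `train_one_epoch`. It uses a tiny stand-in model and random tensors in place of ResNet18 and MNIST, so that it runs in seconds; in the real pipeline, `Subset` would wrap `train_dataset` instead.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over `loader` and return the mean loss per sample."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        n_samples += images.size(0)
    return total_loss / n_samples

# Stand-in data (assumption): 256 random 1x28x28 "images" with labels 0..9
torch.manual_seed(0)
full_dataset = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))

# Keep a random subset of 100 samples
indices = torch.randperm(len(full_dataset))[:100].tolist()
small_loader = DataLoader(Subset(full_dataset, indices), batch_size=64, shuffle=True)

# Stand-in model (assumption): a single linear layer instead of ResNet18
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    mean_loss = train_one_epoch(model, small_loader, criterion, optimizer, device)
    print("epoch", epoch, "mean loss", round(mean_loss, 4))
```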


## Freezing layers to speed-up fine-tuning

Freezing part of the adapted ResNet18 can speed up fine-tuning and is a common strategy in transfer learning.

- **Why Freeze Layers?**
Early layers of ResNet18 typically learn generic features (edges, textures) that are useful for many vision tasks, including MNIST digit classification. By freezing these layers (setting `requires_grad=False`), you reduce the number of parameters that need updating, which speeds up training and lowers memory usage.
- **Which Layers to Freeze?**
For a simple dataset like MNIST, you can freeze most of the early layers and only fine-tune the later layers and the final classification head. This is because the later layers capture more task-specific features, which are more likely to need adaptation for your new dataset.
- **How to Freeze in PyTorch?**
In practice, we start by freezing all parameters, and then we unfreeze the parameters in the later layers.

In [3]:
for param in model.parameters():
    param.requires_grad = False  # Freeze all layers

# Unfreeze the last block and the classifier
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

In short, freezing part of ResNet18 (especially the early layers) is recommended to speed up fine-tuning on MNIST and similar tasks. Fine-tune only the last few layers and the classifier for efficient and effective adaptation.

In practice, one needs to add the code above after loading and adapting the model but **before** setting the **optimizer** since it should only include **trainable parameters** (see script below).

Passing frozen parameters to the optimizer is not an error in PyTorch, since `optimizer.step()` skips any parameter whose `.grad` is `None`. Filtering them out still keeps the optimizer's parameter list small and makes explicit which parameters are being trained. More critically, if you change which parameters are trainable after creating the optimizer, its parameter list no longer matches your intent: parameters that were never passed to it will not be updated even after being unfrozen, and parameters frozen later can still be updated if their `.grad` tensors from earlier steps have not been cleared.
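
The ordering can be checked on a small example. The sketch below uses a two-layer stand-in model (an assumption, chosen so that it runs instantly) rather than ResNet18: it freezes the first layer, builds the optimizer from the trainable parameters only, takes one step, and verifies that only the unfrozen layer changed.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Stand-in model (assumption): two linear layers instead of ResNet18
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first layer, then build the optimizer from trainable parameters only
for param in model[0].parameters():
    param.requires_grad = False
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)

frozen_before = model[0].weight.detach().clone()
head_before = model[2].weight.detach().clone()

# One training step on random data
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
optimizer.zero_grad()
loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()
optimizer.step()

print("parameters in optimizer:", len(optimizer.param_groups[0]["params"]))  # 2: weight and bias of the head
print("frozen layer unchanged:", torch.equal(model[0].weight, frozen_before))
print("head updated:", not torch.equal(model[2].weight, head_before))
```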

### Full pipeline: data preparation, loading and adapting the pre-trained model, freezing layers, and fine-tuning

In [None]:
import torch
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim

# 1. Data preprocessing: resize to 224x224 and normalize
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 2. Load ResNet and modify layers
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 1 input channel
model.fc = nn.Linear(model.fc.in_features, 10)  # 10 output classes

# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False  # Freeze all layers <<<<<<<<<<< Try True (to not freeze any parameters) and compare the time needed for training

# Unfreeze the last block and the classifier
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

# 3. Training setup
criterion = nn.CrossEntropyLoss()
# create optimizer using only trainable parameters
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=0.001
)

# 4. Training loop
model.train()
for epoch in range(3):
    print('epoch',epoch)
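    count = 0  # number of images seen so far in this epoch
    for images, labels in train_loader:
        count += images.size(0)  # the last batch may hold fewer than 64 images
        print(count)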
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
