Run locally or <a target="_blank" href="https://colab.research.google.com/github/aalgahmi/dl_handouts/blob/main/10.transfer_learning.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
!pip install -q torchinfo torchviz lightning opendatasets
!mkdir checkpoints


# Transfer learning

Transfer learning is a technique in which a model developed for one task is reused as the starting point for a model on a second, related task. It involves taking a pre-trained model, which has already learned features from a large dataset, and adapting it to the new task. This is particularly useful when the second task has limited labeled data, because the features the model has already learned help it generalize from far fewer examples.

There are typically two main approaches to transfer learning:

* **Feature Extraction:** In this approach, the pre-trained model is used as a fixed feature extractor. The weights of the pre-trained layers are frozen, and only the final layers are modified and trained on the new task. In other words, the pre-trained model can be thought of as having both a base and a top. The base is frozen so that backpropagation cannot change its trained parameters, while the top is replaced with a new one, which is the only part of the model trained on the new data.

* **Fine-tuning:** In this approach, the pre-trained model is further trained on the new task, and the weights of some or all layers are updated during training. This allows the model to adapt to the specific characteristics of the new dataset.
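
In PyTorch, the difference between the two approaches comes down to which parameters have `requires_grad` set. Here is a minimal sketch using a small stand-in network (the `base`/`top` split and the layer sizes are made up for illustration; this is not a real pre-trained model):

```python
import torch.nn as nn

def count_trainable(model):
    """Number of parameters that backpropagation is allowed to update."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Hypothetical stand-in for a pre-trained network: a "base" and a "top".
base = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())
top = nn.Linear(16, 10)
model = nn.Sequential(base, top)

# Feature extraction: freeze the base, then swap in a new (trainable) top.
for param in base.parameters():
    param.requires_grad = False
model[1] = nn.Linear(16, 2)
feature_extraction_count = count_trainable(model)

# Fine-tuning: additionally unfreeze the last linear layer of the base.
for param in base[2].parameters():
    param.requires_grad = True
fine_tuning_count = count_trainable(model)

print(feature_extraction_count, fine_tuning_count)  # 34 306
```

With feature extraction only the 34 parameters of the new top are trainable; unfreezing one base layer adds its 272 parameters to that count.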

Transfer learning is widely used in computer vision and natural language processing, where large pre-trained models, such as ConvNets for images or pre-trained language models for text, are fine-tuned for specific applications.

This strategy is often employed when there isn't enough data, time, and/or resources to train a full-scale model from scratch. The `torchvision.models` library comes with many pre-trained computer vision models. You can find a list of these pre-trained models [here](https://pytorch.org/vision/0.9/models.html).

This notebook applies transfer learning to the two examples of the previous handout.

But first we get the data. You'll need your Kaggle username and key for this. You can skip this step if you have already downloaded the data:

In [2]:
import opendatasets as od

data_path = './datasets'

dataset_url = 'https://www.kaggle.com/c/dogs-vs-cats/data'
od.download(dataset_url, data_dir=data_path)


Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: aalgahmi4dl
Your Kaggle Key: ··········
Downloading dogs-vs-cats.zip to ./datasets/dogs-vs-cats


100%|██████████| 812M/812M [00:07<00:00, 114MB/s]



Extracting archive ./datasets/dogs-vs-cats/dogs-vs-cats.zip to ./datasets/dogs-vs-cats


In [3]:
!cd ./datasets/dogs-vs-cats/ && unzip -qq train.zip && cd -

/content


In [4]:
import os, shutil, pathlib

src_dir = pathlib.Path("./datasets/dogs-vs-cats/train")
dest_dir = pathlib.Path("./datasets/dogs-vs-cats/processed")

if not os.path.exists(dest_dir):
    def make_subset(subset_name, start_index, end_index):
        for category in ("cat", "dog"):
            dir = dest_dir / subset_name / category
            os.makedirs(dir)
            fnames = [ f"{category}.{i}.jpg" for i in range(start_index, end_index) ]
            for fname in fnames:
                shutil.copyfile(src=src_dir / fname,
                                dst=dir / fname)

    make_subset("train", start_index=0, end_index=8000)
    make_subset("validation", start_index=8000, end_index=9500)
    make_subset("test", start_index=9500, end_index=11000)
else:
    print("Skipping!", dest_dir, "already exists.")

Skipping, found downloaded files in "./datasets/dogs-vs-cats" (use force=True to force download)


In [7]:
import pandas as pd
import torch
dataset_url = 'https://www.kaggle.com/c/histopathologic-cancer-detection'
od.download(dataset_url, data_dir='./datasets')

def make_subset(subset, start_index, end_index, images, labels):
    categories = {0: "0_normal", 1: "1_abnormal" }
    for i in range(start_index, end_index):
        category = categories[labels[i]]
        dir = dest_dir / subset / category
        os.makedirs(dir, exist_ok=True)
        fname = f"{images[i]}.tif"
        shutil.copyfile(src=src_dir / fname, dst=dir / fname)

data_path = './datasets/histopathologic-cancer-detection'
labels = pd.read_csv(data_path + '/train_labels.csv')
labels = labels.set_index('id')

src_dir = pathlib.Path(data_path + "/train")
dest_dir = pathlib.Path(data_path + "/processed")

train_image_files = os.listdir(src_dir)
selected_images = [
    train_image_files[f].split('.')[0]
    for f in torch.randperm(len(train_image_files))[:20000]
]

selected_labels = labels.loc[selected_images]['label'].values

len(selected_images), len(selected_labels)

if not os.path.exists(dest_dir):
    make_subset("train", 0, 16000, selected_images, selected_labels)
    make_subset("validation", 16000, 18000, selected_images, selected_labels)
    make_subset("test", 18000, 20000, selected_images, selected_labels)
else:
    print("Skipping!", dest_dir, "already exists.")

Skipping, found downloaded files in "./datasets/histopathologic-cancer-detection" (use force=True to force download)


## Dogs vs cats

One of the pretrained models implemented by `torchvision.models` is the VGG16 ConvNet. VGG16 is a deep convolutional neural network architecture introduced by the Visual Geometry Group at the University of Oxford and presented in a paper from 2015 titled 'Very Deep Convolutional Networks for Large-Scale Image Recognition' by Karen Simonyan and Andrew Zisserman. It is known for its simplicity and uniform architecture, and consists of 16 weight layers, including 13 convolutional layers and 3 fully connected layers.

Let's use VGG16 with the dogs vs. cats example, starting by reusing the data module from the previous notebook.

In [8]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
import lightning as L
from torchmetrics import Accuracy
from torch.utils.data import random_split, DataLoader
from torchinfo import summary
import matplotlib.pyplot as plt
import opendatasets as od
import os, shutil, pathlib

torch.random.manual_seed(17);

class DogsVsCatsDataModule(L.LightningDataModule):
    def __init__(self, data_path='./datasets', transform = transforms.Compose([
            transforms.Resize(size=(128, 128)),
            transforms.ToTensor()
        ]), train_transform=None):
        super().__init__()

        self.data_path = data_path
        self.transform = transform
        self.train_transform = transform if train_transform is None else train_transform

    def make_subset(self, subset_name, start_index, end_index):
        for category in ("cat", "dog"):
            dir = self.dest_dir / subset_name / category
            os.makedirs(dir)
            fnames = [ f"{category}.{i}.jpg" for i in range(start_index, end_index) ]
            for fname in fnames:
                shutil.copyfile(src=self.src_dir / fname, dst=dir / fname)

    def prepare_data(self):
        dataset_url = 'https://www.kaggle.com/c/dogs-vs-cats/data'
        od.download(dataset_url, data_dir=self.data_path)

        self.src_dir = pathlib.Path(self.data_path + "/dogs-vs-cats/train")
        self.dest_dir = pathlib.Path(self.data_path + "/dogs-vs-cats/processed")

        if not os.path.exists(self.dest_dir):
            self.make_subset("train", start_index=0, end_index=8000)
            self.make_subset("validation", start_index=8000, end_index=9500)
            self.make_subset("test", start_index=9500, end_index=11000)
        else:
            print("Skipping!", self.dest_dir, "already exists.")

    def setup(self, stage=None):
        self.target_transform = transforms.Lambda(lambda y: torch.tensor([y]).float())

        self.ds_train = datasets.ImageFolder(f"{self.dest_dir}/train", transform=self.train_transform)
        self.ds_val = datasets.ImageFolder(f"{self.dest_dir}/validation", transform=self.transform)
        self.ds_test = datasets.ImageFolder(f"{self.dest_dir}/test", transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=128, num_workers=4, shuffle=True,
                          persistent_workers=True, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=128, num_workers=4, persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.ds_test, batch_size=128, num_workers=4, persistent_workers=True)


Next we download and instantiate a pre-trained vgg16 model.

In [9]:
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:09<00:00, 60.8MB/s]


Here is a summary of it:

In [10]:
summary(vgg16)

Layer (type:depth-idx)                   Param #
VGG                                      --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,792
│    └─ReLU: 2-2                         --
│    └─Conv2d: 2-3                       36,928
│    └─ReLU: 2-4                         --
│    └─MaxPool2d: 2-5                    --
│    └─Conv2d: 2-6                       73,856
│    └─ReLU: 2-7                         --
│    └─Conv2d: 2-8                       147,584
│    └─ReLU: 2-9                         --
│    └─MaxPool2d: 2-10                   --
│    └─Conv2d: 2-11                      295,168
│    └─ReLU: 2-12                        --
│    └─Conv2d: 2-13                      590,080
│    └─ReLU: 2-14                        --
│    └─Conv2d: 2-15                      590,080
│    └─ReLU: 2-16                        --
│    └─MaxPool2d: 2-17                   --
│    └─Conv2d: 2-18                      1,180,160
│    └─ReLU: 2-19                

As you can see, there are two main blocks separated by an `AdaptiveAvgPool2d` layer. You can think of the first sequential block as the base (or feature extractor) and of the second sequential block as the top (or classifier). Transfer learning involves changing or replacing the top block to fit the new task at hand. To do that, we freeze the network first.

In [11]:
for param in vgg16.parameters():
    param.requires_grad = False

This makes sure that none of the pre-trained weights and biases will be affected when the model is trained again on the new data. Next, we change or replace the classifier block, which looks like this:

In [12]:
vgg16.classifier

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)

To replace the last layer of the classifier with one that has two units instead of the original 1000, we do something like this:

In [13]:
vgg16.classifier[6] = nn.Linear(in_features=4096, out_features=2, bias=True)
vgg16.classifier

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=2, bias=True)
)

And here is how to replace the whole classifier block with a new untrained one:

In [14]:
vgg16.classifier = nn.Sequential(
    nn.Linear(in_features=25088, out_features=256, bias=True),
    nn.ReLU(),
    nn.Dropout(p=0.5, inplace=False),
    nn.Linear(in_features=256, out_features=2, bias=True)
)

vgg16

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1
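
The `in_features=25088` of the first linear layer, and the number of parameters in the new classifier, can be checked by hand. The sketch below assumes the standard VGG16 layout: five 2x2 max-pooling stages, 512 output channels in the last convolutional block, and an `AdaptiveAvgPool2d` layer that outputs 7x7 feature maps.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

# A 128x128 input is halved by each of the five max-pooling layers.
side = 128
for _ in range(5):
    side //= 2          # 128 -> 64 -> 32 -> 16 -> 8 -> 4

# AdaptiveAvgPool2d((7, 7)) then maps any spatial size to 7x7,
# so the flattened feature vector has the same length for every input size.
flattened = 512 * 7 * 7

new_head = linear_params(flattened, 256) + linear_params(256, 2)
print(side, flattened, new_head)   # 4 25088 6423298
```

The roughly 6.4 million parameters of the new classifier are the only ones that remain trainable after freezing the base.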

Now that we know how to do that, let's put it all inside a Lightning module.

In [15]:
class VGG16BasedClassifier(L.LightningModule):
    def __init__(self):
        super().__init__()

        self.train_accuracy = Accuracy(task="multiclass", num_classes=2)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=2)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=2)

        self.pretrained_model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

        self.pretrained_model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        self.pretrained_model.classifier = nn.Sequential(
            nn.Linear(in_features=25088, out_features=256, bias=True),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=256, out_features=2, bias=True)
        )

    def forward(self, x):
        return self.pretrained_model(x)

    def _common_step(self, batch, batch_idx, accuracy, loss_lbl, accuracy_lbl):
        X, y = batch
        logits = self(X)
        loss = nn.functional.cross_entropy(logits, y)
        y_hat = torch.argmax(logits, dim=1)
        self.log(loss_lbl, loss, prog_bar=True)
        self.log(accuracy_lbl, accuracy(y_hat, y), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, self.train_accuracy, "loss", "accuracy")

    def on_train_epoch_end(self):
        self.log("accuracy", self.train_accuracy.compute())
        self.train_accuracy.reset()

    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, self.val_accuracy, "val_loss", "val_accuracy")

    def on_validation_epoch_end(self):
        self.log("val_accuracy", self.val_accuracy.compute())
        self.val_accuracy.reset()

    def test_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, self.test_accuracy, "test_loss", "test_accuracy")

    def on_test_epoch_end(self):
        self.log("test_accuracy", self.test_accuracy.compute())
        self.test_accuracy.reset()

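    def configure_optimizers(self):
        # Use the `lr` attribute if one has been set on the model; otherwise fall back to 1e-4.
        return torch.optim.RMSprop(self.parameters(), lr=getattr(self, "lr", 1e-4))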
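
`_common_step` treats the problem as two-class classification: the network outputs two logits per image, `cross_entropy` compares them with the integer class labels, and `argmax` picks the predicted class. A small sketch with made-up logits:

```python
import torch
import torch.nn.functional as F

# Made-up logits for a batch of 4 images; column 0 = cat, column 1 = dog
# (ImageFolder assigns class indices in alphabetical order of folder names).
logits = torch.tensor([[2.0, -1.0],
                       [0.5,  1.5],
                       [-0.3, 0.2],
                       [1.0,  3.0]])
y = torch.tensor([0, 1, 0, 1])            # integer labels, as ImageFolder returns

loss = F.cross_entropy(logits, y)         # log-softmax + negative log-likelihood
y_hat = torch.argmax(logits, dim=1)       # predicted class per image
accuracy = (y_hat == y).float().mean()

print(y_hat.tolist(), accuracy.item())    # [0, 1, 1, 1] 0.75
```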


Using this class, here is a new model:

In [16]:
vgg16_based_model = VGG16BasedClassifier()

Let's summarize it:

In [17]:
summary(vgg16_based_model)

Layer (type:depth-idx)                   Param #
VGG16BasedClassifier                     --
├─MulticlassAccuracy: 1-1                --
├─MulticlassAccuracy: 1-2                --
├─MulticlassAccuracy: 1-3                --
├─VGG: 1-4                               --
│    └─Sequential: 2-1                   --
│    │    └─Conv2d: 3-1                  (1,792)
│    │    └─ReLU: 3-2                    --
│    │    └─Conv2d: 3-3                  (36,928)
│    │    └─ReLU: 3-4                    --
│    │    └─MaxPool2d: 3-5               --
│    │    └─Conv2d: 3-6                  (73,856)
│    │    └─ReLU: 3-7                    --
│    │    └─Conv2d: 3-8                  (147,584)
│    │    └─ReLU: 3-9                    --
│    │    └─MaxPool2d: 3-10              --
│    │    └─Conv2d: 3-11                 (295,168)
│    │    └─ReLU: 3-12                   --
│    │    └─Conv2d: 3-13                 (590,080)
│    │    └─ReLU: 3-14                   --
│    │    └─Conv2d: 3-15         

We can now train this model. When we do so, only the replaced classifier block will be affected. The rest is frozen. We'll use the early stopping callback.

In [None]:
from lightning.pytorch.callbacks import EarlyStopping

dogs_vs_cats_dm = DogsVsCatsDataModule()

trainer = L.Trainer(max_epochs=5, callbacks=[
    EarlyStopping(monitor='val_loss', patience=3, mode='min')
])
trainer.fit(vgg16_based_model, datamodule=dogs_vs_cats_dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Skipping, found downloaded files in "./datasets/dogs-vs-cats" (use force=True to force download)
Skipping! datasets/dogs-vs-cats/processed already exists.



  | Name             | Type               | Params
--------------------------------------------------------
0 | train_accuracy   | MulticlassAccuracy | 0     
1 | val_accuracy     | MulticlassAccuracy | 0     
2 | test_accuracy    | MulticlassAccuracy | 0     
3 | pretrained_model | VGG                | 21.1 M
--------------------------------------------------------
6.4 M     Trainable params
14.7 M    Non-trainable params
21.1 M    Total params
84.552    Total estimated model params size (MB)


Sanity Checking: |                                                                                   | 0/? [00…

Training: |                                                                                          | 0/? [00…

Validation: |                                                                                        | 0/? [00…

Validation: |                                                                                        | 0/? [00…

Validation: |                                                                                        | 0/? [00…

Validation: |                                                                                        | 0/? [00…

Validation: |                                                                                        | 0/? [00…

`Trainer.fit` stopped: `max_epochs=5` reached.
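
Here the run ended because `max_epochs=5` was reached, not because early stopping fired. With `patience=3` and `mode='min'`, training would stop once `val_loss` fails to improve for three validation checks in a row. The following is a simplified illustration of that rule with made-up loss values; it is not Lightning's actual implementation, which also supports options such as `min_delta`.

```python
def epochs_until_stop(val_losses, patience=3):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss          # improvement: remember it and reset the counter
            bad_epochs = 0
        else:
            bad_epochs += 1      # no improvement at this check
            if bad_epochs >= patience:
                return epoch
    return None

improving = [0.40, 0.30, 0.25, 0.22, 0.20]       # made-up values
plateau = [0.40, 0.30, 0.31, 0.32, 0.33, 0.10]   # made-up values
print(epochs_until_stop(improving))   # None
print(epochs_until_stop(plateau))     # 5
```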


Let's evaluate this model:

In [None]:
trainer.test(vgg16_based_model, datamodule=dogs_vs_cats_dm)

Skipping, found downloaded files in "./datasets/dogs-vs-cats" (use force=True to force download)
Skipping! datasets/dogs-vs-cats/processed already exists.


Testing: |                                                                                           | 0/? [00…

[{'test_loss': 0.14621175825595856, 'test_accuracy': 0.9393333196640015}]

As you can see, this model significantly outperforms our previous from-scratch models by capitalizing on a pre-trained VGG16 model. The VGG16 model has been previously trained on the extensive ImageNet dataset, a widely used dataset for training and evaluating computer vision models, especially for image classification tasks. Given the richness of cat and dog images in the ImageNet dataset, this pretrained model proves highly effective in our specific task.



### Fine-tuning a pretrained model
As an optional but widely used step, we can improve this pre-trained model by fine-tuning it. This is done by unfreezing all or part of the pre-trained base block and re-training it on the new data with a very low learning rate. The whole process translates to the following steps:
* Instantiate the base model
* Freeze it
* Add a new top to it
* Train the part we added
* Unfreeze some layers of the base model. Don't unfreeze any batch normalization layers.
* Jointly train the newly unfrozen part of the base and the top of the model using a very low learning rate.

We have already done the first four steps. Let's do the last two. First, we unfreeze the last two convolutional layers of the base block.

In [None]:
for i in [26, 28]:
    for param in vgg16_based_model.pretrained_model.features[i].parameters():
        param.requires_grad = True
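
Indices 26 and 28 of `features` are the last two `Conv2d` layers of VGG16, each mapping 512 channels to 512 channels with 3x3 kernels. This is an assumption based on the standard VGG16 layout, since the printed model above is cut off before these layers. The number of parameters this unfreezes can be worked out by hand and compared with the trainable count Lightning reports for the next run:

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a Conv2d layer with a square kernel."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels

unfrozen = 2 * conv2d_params(512, 512, 3)        # features[26] and features[28]
new_head = (25088 * 256 + 256) + (256 * 2 + 2)   # the replaced classifier
trainable = unfrozen + new_head

print(unfrozen, trainable)   # 4719616 11142914
```

That is about 11.1 million trainable parameters, up from about 6.4 million when only the classifier was trainable.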

Having unfrozen these two layers, we train the model again with a low learning rate.

In [None]:
vgg16_based_model.lr = 0.0001  # 1e-4, the same rate as in the first training phase

trainer = L.Trainer(max_epochs=2, callbacks=[
    EarlyStopping(monitor='val_loss', patience=3, mode='min')
])
trainer.fit(vgg16_based_model, datamodule=dogs_vs_cats_dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name             | Type               | Params
--------------------------------------------------------
0 | train_accuracy   | MulticlassAccuracy | 0     
1 | val_accuracy     | MulticlassAccuracy | 0     
2 | test_accuracy    | MulticlassAccuracy | 0     
3 | pretrained_model | VGG                | 21.1 M
--------------------------------------------------------
11.1 M    Trainable params
10.0 M    Non-trainable params
21.1 M    Total params
84.552    Total estimated model params size (MB)


Skipping, found downloaded files in "./datasets/dogs-vs-cats" (use force=True to force download)
Skipping! datasets/dogs-vs-cats/processed already exists.


Sanity Checking: |                                                                                   | 0/? [00…

Training: |                                                                                          | 0/? [00…

Validation: |                                                                                        | 0/? [00…

Validation: |                                                                                        | 0/? [00…

`Trainer.fit` stopped: `max_epochs=2` reached.


Finally, let's evaluate this fine-tuned model:

In [None]:
trainer.test(vgg16_based_model, datamodule=dogs_vs_cats_dm)

Skipping, found downloaded files in "./datasets/dogs-vs-cats" (use force=True to force download)
Skipping! datasets/dogs-vs-cats/processed already exists.


Testing: |                                                                                           | 0/? [00…

[{'test_loss': 0.1642094999551773, 'test_accuracy': 0.9380000233650208}]

## Cancer tissue detection
As a second example, let's use another popular pre-trained network, ResNet50, to detect cancer tissue in the PCam dataset. ResNet50 is a variant of the ResNet (Residual Network) architecture, which introduced residual blocks to address the difficulty of training very deep neural networks. With 50 layers, it is a relatively deep network. Owing to its success and efficiency, it is often used as a benchmark model in deep learning.

Let's download and summarize a pre-trained ResNet50 model from `torchvision.models`:

In [None]:
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
summary(resnet50)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─Bottleneck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─Bottleneck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

As you can see, this is a deeper neural network. We can also print it to get the indexes and names of the layers and blocks that make it up. This will be useful when we make changes to its top layer(s).

In [None]:
print(resnet50)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

Looking at the last layer, its name is `fc` and it consists of 1000 units because of the 1000 classes of the ImageNet dataset it was trained on. This is the layer we need to replace to make this network work for our PCam dataset. Let's get started.
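
Rather than hard-coding the input width of the new layer, it can be read from the layer being replaced. Here is a sketch on a tiny stand-in module (the `TinyNet` class is hypothetical; only the `fc` attribute name mirrors ResNet50):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in with a ResNet-style `fc` head."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
        self.fc = nn.Linear(32, 1000)

    def forward(self, x):
        return self.fc(self.backbone(x))

net = TinyNet()

# Freeze everything that is currently in the model.
for param in net.parameters():
    param.requires_grad = False

# Replace the head; a freshly created layer is trainable by default.
net.fc = nn.Linear(net.fc.in_features, 2)

trainable_names = [name for name, p in net.named_parameters() if p.requires_grad]
out = net(torch.randn(4, 3, 8, 8))
print(trainable_names, tuple(out.shape))   # ['fc.weight', 'fc.bias'] (4, 2)
```

The order matters: freezing first and replacing the head afterwards leaves only the new layer trainable.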

### Fetching and preparing the data

We'll start by fetching and preparing the data, leveraging some of the code we had in the previous notebook.

In [None]:
def make_subset(subset, start_index, end_index, images, labels):
    categories = {0: "0_normal", 1: "1_abnormal" }
    for i in range(start_index, end_index):
        category = categories[labels[i]]
        dir = dest_dir / subset / category
        os.makedirs(dir, exist_ok=True)
        fname = f"{images[i]}.tif"
        shutil.copyfile(src=src_dir / fname, dst=dir / fname)

# Downloading the data from kaggle if needed
dataset_url = 'https://www.kaggle.com/c/histopathologic-cancer-detection'
od.download(dataset_url, data_dir='./datasets')

# Getting the actual labels
data_path = './datasets/histopathologic-cancer-detection'
labels = pd.read_csv(data_path + '/train_labels.csv')
labels = labels.set_index('id')

# Selecting random 20,000 images and splitting them into three sets
src_dir = pathlib.Path(data_path + "/train")
dest_dir = pathlib.Path(data_path + "/processed")

train_image_files = os.listdir(src_dir)
selected_images = [
    train_image_files[f].split('.')[0]
    for f in torch.randperm(len(train_image_files))[:20000]
]

selected_labels = labels.loc[selected_images]['label'].values

if not os.path.exists(dest_dir):
    make_subset("train", 0, 16000, selected_images, selected_labels)
    make_subset("validation", 16000, 18000, selected_images, selected_labels)
    make_subset("test", 18000, 20000, selected_images, selected_labels)
else:
    print("Skipping!", dest_dir, "already exists.")

# Creating the datasets
transform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.ToTensor()
])

train_transform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])

ds_train = datasets.ImageFolder(f"{dest_dir}/train", transform=train_transform)
ds_val = datasets.ImageFolder(f"{dest_dir}/validation", transform=transform)
ds_test = datasets.ImageFolder(f"{dest_dir}/test", transform=transform)

# Creating the corresponding data loaders
dl_train = DataLoader(ds_train, batch_size=256, num_workers=4,
                      shuffle=True, persistent_workers=True,
                      drop_last=True, pin_memory=True)
dl_val = DataLoader(ds_val, batch_size=256, num_workers=4,
                    persistent_workers=True, pin_memory=True)
dl_test = DataLoader(ds_test, batch_size=256, num_workers=4,
                     persistent_workers=True, pin_memory=True)

Skipping, found downloaded files in "./datasets/histopathologic-cancer-detection" (use force=True to force download)
Skipping! datasets/histopathologic-cancer-detection/processed already exists.
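
`CenterCrop(32)` keeps only the central 32x32 window of each image. Assuming the PCam patches on Kaggle are 96x96 pixels (the size is not stated in this notebook), the crop window can be computed as follows. The helper `center_crop_box` is written here for illustration and is not part of torchvision.

```python
def center_crop_box(height, width, crop):
    """Return (top, left, bottom, right) of a centered square crop."""
    top = (height - crop) // 2
    left = (width - crop) // 2
    return top, left, top + crop, left + crop

print(center_crop_box(96, 96, 32))   # (32, 32, 64, 64)
```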


### Using a ResNet50 model

Next, we construct a Lightning model based on ResNet50. Similar to our previous approach, once we download and instantiate the pretrained ResNet50, we freeze all its parameters. This prevents the pretrained weights from being updated during the training on the PCam dataset. Finally, we replace the last (top) linear layer, named `fc`, with a new one containing only 2 units instead of the original 1000.

In [None]:
class ResNet50BasedClassifier(L.LightningModule):

    def __init__(self, lr=0.001):
        super().__init__()
        self.lr = lr

        self.train_accuracy = Accuracy(task="multiclass", num_classes=2)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=2)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=2)

        self.pretrained_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.pretrained_model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        self.pretrained_model.fc = nn.Linear(2048, 2)

    def forward(self, x):
        return self.pretrained_model(x)

    def _common_step(self, batch, batch_idx, accuracy, loss_lbl, accuracy_lbl):
        X, y = batch
        logits = self(X)
        loss = nn.functional.cross_entropy(logits, y)
        y_hat = torch.argmax(logits, dim=1)
        self.log(loss_lbl, loss, prog_bar=True)
        self.log(accuracy_lbl, accuracy(y_hat, y), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, self.train_accuracy, "loss", "accuracy")

    def on_train_epoch_end(self):
        self.log("accuracy", self.train_accuracy.compute())
        self.train_accuracy.reset()

    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, self.val_accuracy, "val_loss", "val_accuracy")

    def on_validation_epoch_end(self):
        self.log("val_accuracy", self.val_accuracy.compute())
        self.val_accuracy.reset()

    def test_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, self.test_accuracy, "test_loss", "test_accuracy")

    def on_test_epoch_end(self):
        self.log("test_accuracy", self.test_accuracy.compute())
        self.test_accuracy.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
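
Freezing parameters with `requires_grad = False` stops gradient updates, but it does not by itself stop BatchNorm layers from updating their running statistics. That behavior is controlled by the module's train/eval mode, which is why the class above also calls `.eval()` on the pretrained model. The following is a minimal sketch of the difference using a standalone `nn.BatchNorm2d` layer on random input (an illustrative stand-in, not the ResNet50 itself):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A standalone BatchNorm layer with its learnable parameters frozen
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False

# Random batch whose mean is far from the initial running mean of 0
x = torch.randn(8, 3, 4, 4) + 5.0

# Train mode: the running mean is updated even though the parameters are frozen
bn.train()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# Eval mode: the running statistics are left untouched
bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train, changed_in_eval)  # True False
```

Whether a mode set in `__init__` is still in effect during `trainer.fit` may depend on the Lightning version in use, so it can be worth checking `self.pretrained_model.training` inside `training_step` if the running statistics matter for your experiment.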


Now we can train this model on the PCam dataset. We should not expect the same level of performance as in the previous example, because the PCam data differs significantly from the ImageNet dataset on which the base ResNet50 model was trained. However, with enough epochs, we still anticipate good results. The trainer below runs for at most 15 epochs and stops early if `val_loss` fails to improve for 3 consecutive validation runs.

In [None]:
resnet50_based_model = ResNet50BasedClassifier()

trainer = L.Trainer(max_epochs=15, callbacks=[
    EarlyStopping(monitor='val_loss', patience=3, mode='min')
])

trainer.fit(resnet50_based_model, train_dataloaders=dl_train, val_dataloaders=dl_val)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name             | Type               | Params
--------------------------------------------------------
0 | train_accuracy   | MulticlassAccuracy | 0     
1 | val_accuracy     | MulticlassAccuracy | 0     
2 | test_accuracy    | MulticlassAccuracy | 0     
3 | pretrained_model | ResNet             | 23.5 M
--------------------------------------------------------
4.1 K     Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.049    Total estimated model params size (MB)


Finally, we evaluate this new model.

In [None]:
trainer.test(resnet50_based_model, dataloaders=dl_test)

[{'test_loss': 0.47510087490081787, 'test_accuracy': 0.7854999899864197}]

With a test accuracy of about 78.5%, these results are comparable to, if not slightly better than, those obtained from models trained from scratch. This suggests that even with a substantially different dataset, such as PCam, transfer learning remains a powerful and effective technique.

Happy learning!