# PyTorch Transfer learning

We use PyTorch Lightning to demonstrate transfer learning.

We use both methods of using a pretrained feature extractor and also finetuning.

In [1]:
import os
from pathlib import Path
from typing import Optional
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchmetrics
import torchvision
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

from torch.utils.data import random_split, DataLoader
from torchvision.datasets.utils import download_and_extract_archive

# check if cuda is usable
if torch.cuda.is_available():
  device = 'cuda' 
else:
  device = 'cpu' 

print(f"Using device: {device}")

Using device: cpu


## Datamodule

In [2]:
class CatDogImageDataModule(pl.LightningDataModule):
    def __init__(self, dl_path='./tmp', batch_size = 32, num_workers=0, cache_dataset=True):
        super().__init__()
        self._dl_path = dl_path
        self.batch_size = batch_size
        self._num_workers = num_workers
        self._cache_dataset = cache_dataset

    @property
    def data_path(self):
        return Path(self._dl_path).joinpath("PetImages")

    @property
    def normalize_transform(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        
    @property
    def train_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            self.normalize_transform,
        ])

    @property
    def val_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)), 
            transforms.ToTensor(), 
            self.normalize_transform
        ])

    def prepare_data(self):
        """Download images and prepare images datasets."""
        url = 'https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip'  
        if not os.path.exists(self.data_path):
            os.makedirs(self.data_path)
        if len(os.listdir(self.data_path)) == 0:
            download_and_extract_archive(url=url, download_root=self._dl_path, remove_finished=not self._cache_dataset)
        else:
            print("Dataset already exists, skipping download and extraction...")

    def setup(self, stage: Optional[str] = None):
        # make assignments here (val/train/test split)
        dataset = self.create_dataset(self.data_path, self.train_transform)
        self.train_data, self.val_data = random_split(dataset, [0.8, 0.2])

        print("Dataset created, split:")
        print(f'training images: {len(self.train_data)}')
        print(f'validation images: {len(self.val_data)}')

    def create_dataset(self, root, transform):
        return ImageFolder(root=root, transform=transform, is_valid_file=self._is_image_valid)

    def train_dataloader(self):
        return DataLoader(dataset=self.train_data, batch_size=self.batch_size, num_workers=self._num_workers, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(dataset=self.val_data, batch_size=self.batch_size, num_workers=self._num_workers, shuffle=False)

    def _is_image_valid(self, image_path):
        try:
            image = Image.open(image_path)
            return True
        except:
            return False

In [3]:
dm = CatDogImageDataModule(num_workers=16, batch_size=128)

In [4]:
# The following methods will be called by the trainer automatically before training:
dm.prepare_data()
#dm.setup()

Using downloaded and verified file: ./tmp/kagglecatsanddogs_5340.zip
Extracting ./tmp/kagglecatsanddogs_5340.zip to ./tmp


## Build Lightning transfer learning model for feature extraction
We use resnet as a backbone here

In [5]:
from torchvision.models import resnet152, ResNet152_Weights

class ImagenetTransferLearnFeatureExtractModule(pl.LightningModule):
    def __init__(self, lr=0.02, num_of_target_classes=2):
        super().__init__()

        self.lr = lr
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes = num_of_target_classes)

        backbone = resnet152(weights=ResNet152_Weights.DEFAULT)
        num_of_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]

        self.feature_extractor = nn.Sequential(*layers)

        self.head = nn.Linear(num_of_filters, num_of_target_classes)

    def cross_entropy_loss(self, logits, labels):
      return F.nll_loss(logits, labels)

    def training_step(self, batch, batch_idx):
        data, label = batch
        output = self.forward(data)
        loss = nn.CrossEntropyLoss()(output,label)
        self.log('train_loss', loss)
        return {'loss': loss, 'log': self.log}

    def validation_step(self, batch, batch_idx):
        val_data, val_label = batch
        val_output = self.forward(val_data)
        val_loss = nn.CrossEntropyLoss()(val_output, val_label)
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.head(representations)
        return F.softmax(x, dim=1)

In [6]:
if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

In [7]:
model = ImagenetTransferLearnFeatureExtractModule()

Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /home/user/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth
100%|██████████| 230M/230M [00:27<00:00, 8.73MB/s] 


## Setup trainer and start training

In [8]:
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Dataset already exists, skipping download and extraction...




Dataset created, split:
training images: 19999
validation images: 4999


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type               | Params
---------------------------------------------------------
0 | accuracy          | MulticlassAccuracy | 0     
1 | feature_extractor | Sequential         | 58.1 M
2 | head              | Linear             | 4.1 K 
---------------------------------------------------------
58.1 M    Trainable params
0         Non-trainable params
58.1 M    Total params
232.592   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


## Build Lightning transfer learning model for fine tuning
We can also fine-tune a pretrained model with our own dataset.
In this case it is important, to only work with small learning rates to not change the weights too much.
Of course we could also freeze a variable number of layers to keep certain feature extractors untouched.

In [5]:
from torchvision.models import resnet152, ResNet152_Weights

class ImagenetTransferLearnFineTuneModule(pl.LightningModule):
    def __init__(self, lr=0.02, num_of_target_classes=2):
        super().__init__()

        self.lr = lr
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes = num_of_target_classes)

        self.backbone = resnet152(weights=ResNet152_Weights.DEFAULT)
        num_of_filters = self.backbone.fc.in_features
        
        self.backbone.fc = nn.Linear(num_of_filters, num_of_target_classes)

    def cross_entropy_loss(self, logits, labels):
      return F.nll_loss(logits, labels)

    def training_step(self, batch, batch_idx):
        data, label = batch
        output = self.forward(data)
        loss = nn.CrossEntropyLoss()(output,label)
        self.log('train_loss', loss)
        return {'loss': loss, 'log': self.log}

    def validation_step(self, batch, batch_idx):
        val_data, val_label = batch
        val_output = self.forward(val_data)
        val_loss = nn.CrossEntropyLoss()(val_output, val_label)
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        x = self.backbone.forward(x)
        return F.softmax(x, dim=1)

In [6]:
model = ImagenetTransferLearnFineTuneModule(lr=0.0001)

Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /home/user/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth
100%|██████████| 230M/230M [00:04<00:00, 51.3MB/s] 


In [8]:
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Dataset already exists, skipping download and extraction...



  | Name     | Type               | Params
------------------------------------------------
0 | accuracy | MulticlassAccuracy | 0     
1 | backbone | ResNet             | 58.1 M
------------------------------------------------
58.1 M    Trainable params
0         Non-trainable params
58.1 M    Total params
232.592   Total estimated model params size (MB)


Dataset created, split:
training images: 19999
validation images: 4999


Sanity Checking: 0it [00:00, ?it/s]

OSError: Caught OSError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user/mambaforge/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user/mambaforge/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user/mambaforge/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user/mambaforge/lib/python3.9/site-packages/torch/utils/data/dataset.py", line 298, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/user/mambaforge/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 229, in __getitem__
    sample = self.loader(path)
  File "/home/user/mambaforge/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 268, in default_loader
    return pil_loader(path)
  File "/home/user/mambaforge/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 246, in pil_loader
    with open(path, "rb") as f:
OSError: [Errno 5] Input/output error: 'tmp/PetImages/Dog/11689.jpg'


In [2]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs --bind_all