# Fine-Tuning with PyTorch

Throughout this notebook, we use the <a href='https://www.kaggle.com/c/cifar-10/overview'>CIFAR-10</a> dataset from Kaggle, a popular computer-vision dataset of 60,000 32x32 color images, each belonging to one of ten classes (6,000 images per class). This dataset is complex enough to show the benefits of fine-tuning, and why it can reach high accuracy without spending hours (and money) re-training complex models from scratch. With **default hyperparameters** and a medium-sized model (ResNet34), we achieve around **96% accuracy** on Kaggle, about 3 percentage points below the state of the art for this dataset.

# Table of Contents
1. [Extracting Data](#extraction)
2. [Datasets and DataLoaders](#data)  
3. [Training with Validation](#validation)  
4. [Full Training](#training)
5. [Generating Predictions](#testing) 

## Extracting Data <a name="extraction"></a>
The original CIFAR-10 dataset provided by Kaggle consists of two *.7z* archives (*train.7z*, *test.7z*), the labels (*trainLabels.csv*) and an example submission (*sampleSubmission.csv*). The code below reads these files from the Kaggle input directory (*../input/cifar-10*) and extracts them into a local *data* folder.  

First, we extract the two archives into two folders called *original_train* and *original_test*. Then, we copy the images into the structure required by PyTorch's `torchvision.datasets.ImageFolder` class, where every image is stored in a folder named after its label (e.g., image *n.png*, labelled as airplane, is stored as *airplane/n.png*). We create four subfolders within *data*:
* *train*: Training data (excluding the Validation set), used to train the model during hyperparameter search
* *valid*: Validation data, used to evaluate the model during hyperparameter search
* *train_valid*: Training and Validation data together, used for the final re-training of the model
* *test*: Test data, i.e., all the unlabelled images for which we must submit a predicted label to Kaggle (stored under a dummy *undefined* class folder, since `ImageFolder` expects one subfolder per class)

In [None]:
# py7zr is required to extract the .7z files

!pip install -q py7zr

In [None]:
import py7zr
import shutil
from pathlib import Path

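root = Path('./data')
input_path = Path('../input/cifar-10')
root.mkdir(parents=True, exist_ok=True)  # make sure the output folder exists before extracting and copying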

with py7zr.SevenZipFile(input_path/'train.7z', mode='r') as z:
    z.extractall(root)

with py7zr.SevenZipFile(input_path/'test.7z', mode='r') as z:
    z.extractall(root)

shutil.copy(input_path/'trainLabels.csv', root/'trainLabels.csv')

(root/'train').rename(root/'original_train')
(root/'test').rename(root/'original_test')

In [None]:
from random import random
from pathlib import Path
import shutil
import os

def copy_file(source_directory, destination_directory, filename):
    """
    Utility function used to copy a file from source_directory to destination_directory,
    creating the destination directory if needed
    """
    destination_directory.mkdir(parents=True, exist_ok=True)
    shutil.copy(source_directory/filename, destination_directory/filename)
    
def organize_train_valid_dataset(root, labels, valid_probability=0.1):
    """
    Creates the train, train_valid and valid folders following the layout expected by torchvision's
    ImageFolder, assigning each image to the valid folder with probability valid_probability
    """
    source_directory = root/'original_train'
    # Build the id -> class lookup once, instead of filtering the whole DataFrame for every image
    id_to_label = dict(zip(labels.id, labels.label))
    
    with os.scandir(source_directory) as it:
        for entry in it:
            if entry.is_file():
                img_index = int(Path(entry.name).stem)  # The id is the file name without its extension
                img_class = id_to_label[img_index]
                
                # Randomly assign the image to the valid dataset with probability 'valid_probability'
                channel = Path('train') if random() > valid_probability else Path('valid')
                destination_directory = root/channel/img_class
                
                # Copy the image to either the train or valid folder, and also to the train_valid folder
                copy_file(source_directory, destination_directory, entry.name)
                copy_file(source_directory, root/'train_valid'/img_class, entry.name)

def organize_test_dataset(root):
    """
    Creates the test folder following ImageFolder's layout, using a dummy 'undefined' class
    since the test images are unlabelled
    """
    source_directory = root/'original_test'
    destination_directory = root/'test'/'undefined'
        
    with os.scandir(source_directory) as it:
        for entry in it:
            if entry.is_file():
                copy_file(source_directory, destination_directory, entry.name)

In [None]:
import pandas as pd

# Read in the labels DataFrame with a label for each image
labels = pd.read_csv(root/'trainLabels.csv')

# Create the train/train_valid/valid folder structure
valid_probability = 0.1
organize_train_valid_dataset(root, labels, valid_probability)

# Create the test folder structure
organize_test_dataset(root)
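The split above relies on the unseeded `random()` and iterates over `os.scandir`, whose order is arbitrary, so every run produces a different train/valid split. A minimal sketch of a reproducible alternative, assuming a hypothetical helper `assign_splits` that is not part of the original notebook, sorts the ids first and draws from a seeded `random.Random`:

```python
import random

def assign_splits(image_ids, valid_probability=0.1, seed=42):
    """Hypothetical helper: map each image id to 'train' or 'valid' reproducibly.

    Ids are sorted first so the result does not depend on the directory listing
    order; then one draw from a seeded generator decides the split of each id.
    """
    rng = random.Random(seed)
    return {img_id: ("valid" if rng.random() < valid_probability else "train")
            for img_id in sorted(image_ids)}
```

Inside `organize_train_valid_dataset`, one could compute `splits = assign_splits(labels.id, valid_probability)` once and choose the folder with `splits[int(Path(entry.name).stem)]` instead of calling `random()`.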

## Datasets and DataLoaders <a name='data'></a>

As mentioned above, we rely on PyTorch's `ImageFolder` class to create the datasets for training, validation, training+validation and testing. From each dataset, we then create a DataLoader, used in the training and evaluation loops to efficiently fetch images from disk in batches.

We first load the training data and compute the mean and standard deviation of each channel (R, G, B) across all images and pixels, so that the inputs can later be normalized. To avoid loading the entire dataset into memory, we compute the mean and standard deviation batch by batch and then average the per-batch values. This is an approximation: averaging per-batch means is exact only when all batches have the same size (the last batch is usually smaller), and the average of per-batch standard deviations is not the dataset-wide standard deviation.  
**NOTE**: if you have enough RAM (or GPU memory), you can set `batch_size` to the length of the entire `train_dataset`, which yields the exact per-channel means and standard deviations.

In [None]:
!pip install -q --upgrade torchvision

In [None]:
import torchvision
import torch

train_dataset = torchvision.datasets.ImageFolder(
    root/'train', 
    transform=torchvision.transforms.Compose([
        # ResNet models pretrained on ImageNet were trained on 224x224 images, so we upsample the 32x32 images to that size
        torchvision.transforms.Resize((224,224)),  
        torchvision.transforms.ToTensor(),
    ])
)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=False, num_workers=2, pin_memory=True)

means = []
stdevs = []
for X, _ in train_dataloader:
    # Dimensions 0,2,3 are respectively the batch, height and width dimensions
    means.append(X.mean(dim=(0,2,3)))
    stdevs.append(X.std(dim=(0,2,3)))

mean = torch.stack(means, dim=0).mean(dim=0)
stdev = torch.stack(stdevs, dim=0).mean(dim=0)
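To get exact statistics without loading everything at once, one can accumulate per-channel pixel counts, sums and sums of squares across batches and derive the mean and standard deviation at the end. The sketch below shows that arithmetic in plain Python, with nested lists standing in for an `N x C x H x W` tensor; with tensors, the per-batch accumulation would use `X.sum(dim=(0,2,3))` and `(X**2).sum(dim=(0,2,3))`, ideally in float64. It assumes the sample standard deviation (`n - 1` denominator), which matches `torch.std`'s default.

```python
import math

def channel_stats(batches, num_channels=3):
    """Exact per-channel mean and sample standard deviation over all batches.

    Each batch is a list of images; each image is a list of `num_channels`
    flat lists of pixel values (a plain-Python stand-in for N x C x H x W).
    """
    count = [0] * num_channels
    total = [0.0] * num_channels
    total_sq = [0.0] * num_channels
    for batch in batches:
        for image in batch:
            for c, pixels in enumerate(image):
                count[c] += len(pixels)
                total[c] += sum(pixels)
                total_sq[c] += sum(p * p for p in pixels)
    means = [total[c] / count[c] for c in range(num_channels)]
    # Sample variance with an (n - 1) denominator, like torch.std's default
    stds = [math.sqrt((total_sq[c] - count[c] * means[c] ** 2) / (count[c] - 1))
            for c in range(num_channels)]
    return means, stds
```

Because batches of any size contribute through their raw sums, a smaller final batch is weighted correctly, unlike averaging per-batch means.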

The transforms used for the training and training+validation datasets resize the images to the resolution our ResNet model expects (224x224), apply the `AutoAugment` policy learned on CIFAR-10, convert the PIL image to a Tensor and finally normalize each channel with the mean and standard deviation computed above. For the validation and test sets, we only resize, convert to Tensor and normalize, without any augmentation.

In [None]:
train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224)),
        torchvision.transforms.AutoAugment(policy=torchvision.transforms.AutoAugmentPolicy.CIFAR10),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, stdev)
    ])

train_dataset, train_valid_dataset = [torchvision.datasets.ImageFolder(folder, transform=train_transforms) for folder in [root/'train', root/'train_valid']]


valid_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, stdev)
    ])

valid_dataset, test_dataset = [torchvision.datasets.ImageFolder(folder, transform=valid_transforms) for folder in [root/'valid', root/'test']]

The train and train+validation DataLoaders use a smaller `batch_size`, since during training the activations and gradients needed for backpropagation must also be kept in memory. Furthermore, they shuffle the dataset at every epoch so that batches are not drawn in the same order.
The valid and test DataLoaders use a larger `batch_size` and do not shuffle the dataset: we want deterministic results, and the test predictions must follow the file order so they can be matched back to the image ids.

As a rule of thumb on Kaggle, we set the number of workers to `2 * num_gpus` (which becomes `0`, i.e., loading in the main process, when no GPU is available), and use `pin_memory=True` to speed up data transfers to the GPU.

In [None]:
num_gpus = torch.cuda.device_count()

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2*num_gpus, pin_memory=True)
train_valid_dataloader = torch.utils.data.DataLoader(train_valid_dataset, batch_size=128, shuffle=True, num_workers=2*num_gpus, pin_memory=True)

valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=256, shuffle=False, num_workers=2*num_gpus, pin_memory=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2*num_gpus, pin_memory=True)

## Training with Validation <a name='validation'></a>

The first step is to evaluate the model on our own Validation set, which contains roughly 10% of the labelled data from Kaggle (the split is random). Ideally, this step would be repeated while searching for the best model and hyperparameters; here, we perform it only once, to estimate the model accuracy before submitting the results to Kaggle.  
**NOTE**: when doing proper hyperparameter tuning, depending on the size of the labelled data, a k-fold approach might be more appropriate to estimate the generalization capability of the model.
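For reference, a k-fold split only needs a list of indices partitioned into k folds, each used once for validation. The sketch below, with a hypothetical helper `kfold_indices` that is not part of the notebook, shuffles the indices once with a fixed seed and cuts them into k folds whose sizes differ by at most one:

```python
import random

def kfold_indices(n_samples, k=5, seed=0):
    """Hypothetical helper: yield (train_indices, valid_indices) for each of k folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # The first (n_samples % k) folds get one extra sample
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        valid = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, valid
        start += size
```

Each fold could then be wrapped with `torch.utils.data.Subset(dataset, indices)`. Since `train_valid_dataset` applies the training augmentations, the validation fold would need a second `ImageFolder(root/'train_valid', transform=valid_transforms)` instance (an assumption, not in the notebook); both instances list the files in the same order, so the indices line up.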

We fine-tune a ResNet34 model pretrained on ImageNet. Other models could be used, but for the purpose of this notebook ResNet34 is a good trade-off between training time and model accuracy.  
The pretrained model has a 1000-dimensional output layer (one output per ImageNet class), while our dataset has only 10 classes, so we replace the output layer with a new Fully-Connected layer with 10 neurons, one for each CIFAR-10 class. The weights of this new layer are initialized with Xavier (Glorot) uniform initialization, while its bias keeps PyTorch's default initialization.
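As a quick check of what Xavier uniform initialization does here: `torch.nn.init.xavier_uniform_` samples weights from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)), and the default gain is 1. For ResNet34, `resnet.fc.in_features` is 512, so the new 512 -> 10 layer gets a small, symmetric range. The helper name below is illustrative only.

```python
import math

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width a of the U(-a, a) distribution used by Xavier/Glorot uniform init."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

# New head: 512 input features (ResNet34) -> 10 CIFAR-10 classes
a = xavier_uniform_bound(512, 10)
print(f"{a:.4f}")  # prints 0.1072
```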

In [None]:
def get_net():
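    # ImageNet-pretrained weights; the `pretrained=True` argument is deprecated in recent torchvision versions
    resnet = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.DEFAULT)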
    
    # Substitute the FC output layer
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)
    torch.nn.init.xavier_uniform_(resnet.fc.weight)
    return resnet

The training loop is a standard PyTorch loop where for every epoch we perform the following macro steps:
1. Iterate over the Train DataLoader by making predictions, calculating loss, backpropagating gradients and updating parameters
2. Iterate over the Valid DataLoader (if present) to compute the validation loss and accuracy
3. Update the learning rate using the scheduler (if present)
4. Optionally, store the model checkpoint after a given number of `checkpoint_epochs`

In [None]:
import time

def train(net, train_dataloader, valid_dataloader, criterion, optimizer, scheduler=None, epochs=10, device='cpu', checkpoint_epochs=10):
    start = time.time()
    print(f'Training for {epochs} epochs on {device}')
    
    for epoch in range(1,epochs+1):
        print(f"Epoch {epoch}/{epochs}")
        
        net.train()  # put network in train mode for Dropout and Batch Normalization
        train_loss = torch.tensor(0., device=device)  # loss and accuracy tensors are on the GPU to avoid data transfers
        train_accuracy = torch.tensor(0., device=device)
        for X, y in train_dataloader:
            X = X.to(device)
            y = y.to(device)
            preds = net(X)
            loss = criterion(preds, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            with torch.no_grad():
                train_loss += loss * X.shape[0]  # loss is the batch mean; weight it by the actual batch size (the last batch may be smaller)
                train_accuracy += (torch.argmax(preds, dim=1) == y).sum()
        
        if valid_dataloader is not None:
            net.eval()  # put network in eval mode for Dropout and Batch Normalization
            valid_loss = torch.tensor(0., device=device)
            valid_accuracy = torch.tensor(0., device=device)
            with torch.no_grad():
                for X, y in valid_dataloader:
                    X = X.to(device)
                    y = y.to(device)
                    preds = net(X)
                    loss = criterion(preds, y)

                    valid_loss += loss * X.shape[0]  # weight by the actual batch size
                    valid_accuracy += (torch.argmax(preds, dim=1) == y).sum()
        
        if scheduler is not None: 
            scheduler.step()
            
        print(f'Training loss: {train_loss/len(train_dataloader.dataset):.2f}')
        print(f'Training accuracy: {100*train_accuracy/len(train_dataloader.dataset):.2f}')
        
        if valid_dataloader is not None:
            print(f'Valid loss: {valid_loss/len(valid_dataloader.dataset):.2f}')
            print(f'Valid accuracy: {100*valid_accuracy/len(valid_dataloader.dataset):.2f}')
        
        if epoch%checkpoint_epochs==0:
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, './checkpoint.pth.tar')
        
        print()
    
    end = time.time()
    print(f'Total training time: {end-start:.1f} seconds')
    return net
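Because `CrossEntropyLoss` averages over the batch by default, turning per-batch losses into an epoch-level mean requires weighting each batch by the number of samples it actually contains: the DataLoader keeps a smaller final batch unless `drop_last=True`. The illustrative numbers below (300 samples, `batch_size=128`, made-up per-batch losses) show how much weighting by the nominal `batch_size` can distort the result:

```python
def epoch_mean_loss(batch_means, batch_sizes):
    """Dataset-level mean loss from per-batch mean losses, weighted by real batch sizes."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

sizes = [128, 128, 44]   # 300 samples split with batch_size=128
means = [1.0, 1.0, 2.0]  # illustrative per-batch mean losses
exact = epoch_mean_loss(means, sizes)        # (128 + 128 + 88) / 300
nominal = sum(m * 128 for m in means) / 300  # treats the last batch as if it had 128 samples
print(round(exact, 4), round(nominal, 4))    # prints 1.1467 1.7067
```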

In this notebook we use at most one GPU; if you have multiple GPUs or machines, you can refactor the code to use `DistributedDataParallel`.  

When fine-tuning, the pretrained parameters of the network body are trained with a lower learning rate than the newly added head, whose parameters are trained from scratch. We rely on PyTorch parameter groups to assign the two learning rates (`lr` for the body and `10 * lr` for the head), and use the Adam optimizer with `weight_decay = 5e-4` (found via hyperparameter search).
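The parameter-group logic can be factored into a small function. The sketch below is a hypothetical helper, not part of the notebook; it returns the list-of-dicts format that `torch.optim` optimizers accept, where a group without an explicit `"lr"` falls back to the optimizer's default. It matches the head by the `"fc."` name prefix rather than a substring check, which avoids misassigning body layers whose names happen to contain "fc" in other architectures (for ResNet34 both approaches select the same parameters).

```python
def make_param_groups(named_params, head_prefix="fc.", base_lr=1e-5, head_lr_mult=10):
    """Hypothetical helper: split (name, param) pairs into body/head optimizer groups."""
    body, head = [], []
    for name, param in named_params:
        (head if name.startswith(head_prefix) else body).append(param)
    return [
        {"params": body},                                # body uses the optimizer's default lr
        {"params": head, "lr": base_lr * head_lr_mult},  # larger lr for the newly initialized head
    ]
```

Usage in the notebook would look like `torch.optim.Adam(make_param_groups(net.named_parameters(), base_lr=lr), lr=lr, weight_decay=weight_decay)`.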

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr, weight_decay, epochs = 1e-5, 5e-4, 20

net = get_net().to(device)

# Standard CrossEntropy Loss for multi-class classification problems
criterion = torch.nn.CrossEntropyLoss()

# params_1x are the parameters of the network body, i.e., of all layers except the final FC layer ('fc')
params_1x = [param for name, param in net.named_parameters() if 'fc' not in str(name)]
optimizer = torch.optim.Adam([{'params':params_1x}, {'params': net.fc.parameters(), 'lr': lr*10}], lr=lr, weight_decay=weight_decay)

net = train(net, train_dataloader, valid_dataloader, criterion, optimizer, None, epochs, device)

## Full Training <a name='training'></a>
After assessing the model performance on the Validation set, we re-train the model from scratch on the full Training + Validation data, with the same hyperparameters, to squeeze out the remaining performance before submitting our results to Kaggle. As a general rule, training on more data tends to improve the results.

In [None]:
lr, weight_decay, epochs = 1e-5, 5e-4, 20

net = get_net().to(device)

criterion = torch.nn.CrossEntropyLoss()

params_1x = [param for name, param in net.named_parameters() if 'fc' not in str(name)]
optimizer = torch.optim.Adam([{'params':params_1x}, {'params': net.fc.parameters(), 'lr': lr*10}], lr=lr, weight_decay=weight_decay)

net = train(net, train_valid_dataloader, None, criterion, optimizer, None, epochs, device)

## Generating Predictions <a name='testing'></a>
After re-training the network on the full labelled dataset, we are ready to predict the Test set and submit our results to Kaggle. We iterate over the test_dataloader to get a predicted label for each Test image, and then create a final DataFrame in the same format as *sampleSubmission.csv*.

In [None]:
preds = []

net.eval()
with torch.no_grad():
    for X, _ in test_dataloader:
        X = X.to(device)
        preds.extend(net(X).argmax(dim=1).type(torch.int32).cpu().numpy())

In [None]:
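# ImageFolder lists files sorted by name as strings (1.png, 10.png, 100.png, ...), so the ids
# are sorted the same way to line up with the order of the predictions
ids = list(range(1, len(test_dataset)+1))
ids.sort(key=lambda x: str(x))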
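The string sort of the ids works because `ImageFolder` sorts file names as strings and every digit sorts after the `.` of the extension, so sorting `"N.png"` names gives the same order as sorting `str(N)`. The illustrative helper below (not part of the notebook) reproduces that order for files named `1.png` to `n.png`:

```python
def ids_in_imagefolder_order(n_images, ext=".png"):
    """Illustrative helper: ids of files 1..n_images in plain string-sorted filename order."""
    filenames = sorted(f"{i}{ext}" for i in range(1, n_images + 1))
    return [int(name[:-len(ext)]) for name in filenames]

print(ids_in_imagefolder_order(12))  # prints [1, 10, 11, 12, 2, 3, 4, 5, 6, 7, 8, 9]
```

A more direct option is to read the ids from the dataset itself, e.g. `ids = [int(Path(path).stem) for path, _ in test_dataset.samples]`, since `ImageFolder.samples` holds `(path, class_index)` pairs in the order a non-shuffled DataLoader yields them.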

In [None]:
df = pd.DataFrame({'id': ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_dataset.classes[x])
df.to_csv('submission.csv', index=False)