# Transfer Learning with skorch

In this tutorial, you will learn how to train a neural network using transfer learning with the `skorch` API. Transfer learning uses a pretrained model to initialize a network. This tutorial converts the pure PyTorch approach in [PyTorch's Transfer Learning tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#sphx-glr-beginner-transfer-learning-tutorial-py) into one that uses `skorch`.

We will be using `torchvision` for this tutorial. Instructions on how to install `torchvision` for your platform can be found at https://pytorch.org.

In [1]:
import os
from urllib import request
from zipfile import ZipFile

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms

from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, Checkpoint
from skorch.helper import filtered_optimizer
from skorch.helper import filter_requires_grad

torch.manual_seed(42);

## Preparations

Before we begin, let's download the data needed for this tutorial:

In [3]:
def download_and_extract_data():
    url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
    data_path = 'datasets/hymenoptera_data.zip'
    
    print("Starting to download and extracting data...")
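    if not os.path.exists(data_path):
        # Download the archive, creating the target directory if needed
        os.makedirs(os.path.dirname(data_path), exist_ok=True)
        data = request.urlopen(url, timeout=15).read()
        with open(data_path, 'wb') as f:
            f.write(data)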
        
    if not os.path.exists('datasets/hymenoptera_data'):
        with ZipFile(data_path, 'r') as zip_f:
            zip_f.extractall('datasets')
        
    print("Data has been downloaded and extracted!")
        
download_and_extract_data()

Starting to download and extracting data...
Data has been downloaded and extracted!


## The Problem

We are going to train a neural network to classify **ants** and **bees**. The dataset consists of 120 training images and 75 validation images for each class. First we create the training and validation datasets:

In [18]:
data_dir = 'datasets/hymenoptera_data'  # directory the archive was extracted into
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

train_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'train'), train_transforms)
val_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'val'), val_transforms)

The training dataset includes data augmentation techniques such as random cropping to size 224 and random horizontal flips. The training and validation datasets are normalized with mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`. These values are the per-channel means and standard deviations of the ImageNet images. We use them because the pretrained model was trained on ImageNet.
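To make the normalization concrete, here is a small sketch in plain Python of the per-channel arithmetic, `(value - mean) / std`, applied to a single RGB pixel. The helper name `normalize_pixel` and the sample pixels are illustrative assumptions and not part of `torchvision`.

```python
# ImageNet channel statistics used in the transforms above (R, G, B).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(pixel, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel.

    `pixel` holds channel values already scaled to [0, 1], which is
    the range `transforms.ToTensor()` produces from an 8-bit image.
    """
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


# A pixel equal to the channel means maps to zero in every channel.
mean_pixel = normalize_pixel(IMAGENET_MEAN)

# A pure white pixel lands a bit more than two standard deviations
# above the mean in every channel.
white_pixel = normalize_pixel([1.0, 1.0, 1.0])

print(mean_pixel)
print([round(v, 3) for v in white_pixel])
```

After this step, inputs are on the same scale the pretrained weights saw during ImageNet training.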

## Loading pretrained model

We use a pretrained `ResNet18` neural network model with its final layer replaced by a new fully connected layer:

In [54]:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

Since we are training a binary classifier, the output of the final fully connected layer has size 2. Next, we freeze all layers except the final layer by setting `requires_grad` to False:

In [55]:
for name, param in model_ft.named_parameters():
    if not name.startswith('fc'):
        param.requires_grad_(False)
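One way to check that freezing worked is to count the parameters that still require gradients. The sketch below runs the same loop on a hypothetical two-layer stand-in (`toy_model`, with layers named `body` and `fc`) so that it executes quickly without downloading weights; the names are assumptions made for illustration.

```python
import torch.nn as nn

# Hypothetical stand-in for ResNet18: a "body" to freeze and an "fc" head.
toy_model = nn.Sequential()
toy_model.add_module('body', nn.Linear(4, 3))
toy_model.add_module('fc', nn.Linear(3, 2))

# Same freezing loop as above: parameters outside `fc` stop tracking gradients.
for name, param in toy_model.named_parameters():
    if not name.startswith('fc'):
        param.requires_grad_(False)

# Count parameter elements that will still receive gradients.
trainable = sum(p.numel() for p in toy_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in toy_model.parameters())
print(f"trainable parameters: {trainable} of {total}")
```

Running the same two counting lines on `model_ft` should report only the parameters of the new `fc` layer as trainable.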

## Using skorch's API

In this section, we will create a `skorch.NeuralNetClassifier` to solve our classification problem. 

### Callbacks

First, we create two callbacks:

In [64]:
lrscheduler = LRScheduler(
    policy='StepLR', step_size=7, gamma=0.1)

checkpoint = Checkpoint(
    target='best_model.pt', monitor='valid_acc_best')

callbacks = [lrscheduler, checkpoint]

The `LRScheduler` callback defines a learning rate scheduler that uses `torch.optim.lr_scheduler.StepLR` to multiply the learning rate by `gamma=0.1` every 7 epochs (`step_size=7`). The `Checkpoint` callback saves the best model by monitoring the validation accuracy.
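The step decay can be sketched in plain Python to see which learning rates it produces. The helper `step_lr` below is a hypothetical illustration of the formula `initial_lr * gamma ** (epoch // step_size)`; it assumes zero-based epochs and uses the initial learning rate of `0.001` and the 25 epochs that this tutorial trains with. The exact epoch at which the callback applies each drop may be offset by one, so treat the indexing as an assumption.

```python
def step_lr(initial_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate under step decay at a zero-based epoch index."""
    return initial_lr * gamma ** (epoch // step_size)


# Learning rate for each of the 25 training epochs.
schedule = [step_lr(0.001, epoch) for epoch in range(25)]

# Print the rate at the start and around each drop.
for epoch in (0, 6, 7, 14, 21):
    print(f"epoch {epoch:2d}: lr = {schedule[epoch]:.0e}")
```

Under these assumptions the rate drops three times during training, so the last few epochs make only very small updates.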

### Filtered optimizer

Since we froze some layers in our `ResNet18` neural network, we need to configure our optimizer to update only the parameters of the final fully connected layer. Luckily, `skorch` provides two functions that make this simple:

In [65]:
optimizer = filtered_optimizer(
    optim.SGD, filter_requires_grad
)
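For intuition, the filtering roughly amounts to handing the optimizer only those parameters whose `requires_grad` flag is still set. The sketch below does this by hand in plain PyTorch on a hypothetical two-layer model (`toy_model`); it illustrates the idea and is not skorch's implementation.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical two-layer model whose first layer is frozen.
toy_model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))
for param in toy_model[0].parameters():
    param.requires_grad_(False)

# Keep only the parameters that still require gradients.
trainable_params = [p for p in toy_model.parameters() if p.requires_grad]
sgd = optim.SGD(trainable_params, lr=0.001, momentum=0.9)

# The optimizer tracks just the weight and bias of the last layer.
n_tracked = sum(len(group['params']) for group in sgd.param_groups)
print(f"tensors tracked by the optimizer: {n_tracked}")
```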

### Working with torch.utils.data.Dataset

We have already split our data into training and validation datasets: `train_ds` and `val_ds`. In order for skorch to use these datasets we define a helper function:

In [66]:
def train_valid_split(train, valid, **kwargs):
    return train, valid

This function does not do any processing and returns the two datasets. 

### skorch.NeuralNetClassifier

With all the preparations out of the way, we can now define our `NeuralNetClassifier`:

In [67]:
net = NeuralNetClassifier(
    model_ft, 
    criterion=nn.CrossEntropyLoss,
    lr=0.001,
    batch_size=4,
    max_epochs=25,
    optimizer=optimizer,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=4,
    iterator_valid__shuffle=True,
    iterator_valid__num_workers=4,
    train_split=train_valid_split,
    callbacks=callbacks,
#     device='cuda' # uncomment to train on gpu
)

That is quite a few parameters! Let's walk through each one:

1. `model_ft`: Our `ResNet18` neural network
2. `criterion=nn.CrossEntropyLoss`: loss function
3. `lr`: Initial learning rate
4. `batch_size`: Size of a batch
5. `max_epochs`: Number of epochs to train
6. `optimizer`: Our filtered optimizer
7. `optimizer__momentum`: The momentum passed to the optimizer
8. `iterator_{train,valid}__{shuffle,num_workers}`: Parameters that are passed to the dataloader.
9. `train_split`: Our custom `train_valid_split` function
10. `callbacks`: Our callbacks 
11. `device`: Set to `'cuda'` to train on a GPU.
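The double-underscore names in the list above follow a prefix convention: `optimizer__momentum` means "pass `momentum` to the optimizer". The sketch below imitates that routing in plain Python. The helper `params_for` and the `net_kwargs` dictionary are hypothetical, written only to illustrate the naming scheme; this is not skorch's own code.

```python
def params_for(prefix, kwargs):
    """Collect `prefix__name=value` entries as `{name: value}`."""
    marker = prefix + '__'
    return {key[len(marker):]: value
            for key, value in kwargs.items()
            if key.startswith(marker)}


# A subset of the keyword arguments used in this tutorial.
net_kwargs = {
    'lr': 0.001,
    'optimizer__momentum': 0.9,
    'iterator_train__shuffle': True,
    'iterator_train__num_workers': 4,
}

# Arguments destined for the optimizer and the training data loader.
optimizer_kwargs = params_for('optimizer', net_kwargs)
loader_kwargs = params_for('iterator_train', net_kwargs)

print(optimizer_kwargs)
print(loader_kwargs)
```

Arguments without a prefix, such as `lr`, are not picked up by either call in this sketch.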

Now we are ready to train our neural network:

In [68]:
net.fit(train_ds, val_ds);

Checkpoint! Saving model to best_model.pt.
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.3510       0.9477        0.1671  1.2843
      2        0.4117       0.9477        0.2174  1.2961
      3        0.3629       0.9412        0.1957  1.3684
      4        0.4517       0.9477        0.2357  1.3041
      5        0.4432       0.9346        0.1920  1.3495
      6        0.4343       0.7974        0.6507  1.2952
      7        0.5948       0.9281        0.1940  1.3564
      8        0.3813       0.9412        0.1663  1.3061
      9        0.2655       0.9412        0.1742  1.4253
     10        0.2923       0.9412        0.1584  1.3351
     11        0.3905       0.9412        0.1710  1.3842
     12        0.2401       0.9346        0.1669  1.3937
     13        0.3938       0.9281        0.1752  1.3792
     14        0.3182       0.9346     

The best model is stored at `best_model.pt`, with a validation accuracy of `0.9608`.

Congratulations! You now know how to fine-tune a neural network using `skorch`. Feel free to explore the other tutorials to learn more about using `skorch`.