# Transfer learning with PyTorch
We're going to train a neural network to classify dogs and cats.

## Init, helpers, utils, ...

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

# Device every model and batch is moved to: the GPU when one is available,
# otherwise the CPU. Later cells reference this constant as `DEVICE`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import matplotlib.pyplot as plt
from pprint import pprint
import numpy as np
from IPython.core.debugger import set_trace

import utils  # little helpers

In [None]:
# Training helpers
def get_trainable(model_params):
    """Lazily select the parameters that still receive gradient updates.

    Args:
        model_params: iterable of objects exposing a ``requires_grad``
            attribute, e.g. ``model.parameters()``.

    Returns:
        A generator over the entries whose ``requires_grad`` is truthy,
        in their original order.
    """
    return (
        candidate
        for candidate in model_params
        if candidate.requires_grad
    )


def get_frozen(model_params):
    """Lazily select the parameters that are excluded from gradient updates.

    Args:
        model_params: iterable of objects exposing a ``requires_grad``
            attribute, e.g. ``model.parameters()``.

    Returns:
        A generator over the entries whose ``requires_grad`` is falsy,
        in their original order.
    """
    return (
        candidate
        for candidate in model_params
        if not candidate.requires_grad
    )


def all_trainable(model_params):
    """Tell whether every parameter still receives gradient updates.

    Args:
        model_params: iterable of objects exposing a ``requires_grad``
            attribute, e.g. ``model.parameters()``.

    Returns:
        ``True`` when no entry has a falsy ``requires_grad`` (this includes
        an empty iterable), ``False`` otherwise.
    """
    # "all are trainable" is the same statement as "none is frozen"
    return not any(not entry.requires_grad for entry in model_params)


def all_frozen(model_params):
    """Tell whether every parameter is excluded from gradient updates.

    Args:
        model_params: iterable of objects exposing a ``requires_grad``
            attribute, e.g. ``model.parameters()``.

    Returns:
        ``True`` when no entry has a truthy ``requires_grad`` (this includes
        an empty iterable), ``False`` otherwise.
    """
    # "all are frozen" is the same statement as "none is trainable"
    return not any(entry.requires_grad for entry in model_params)


def freeze_all(model_params):
    """Switch off gradient tracking for every given parameter, in place.

    Args:
        model_params: iterable of objects with a writable ``requires_grad``
            attribute, e.g. ``model.parameters()``.

    Returns:
        None. Each entry is mutated by setting ``requires_grad = False``.
    """
    for entry in model_params:
        entry.requires_grad = False


# list(get_trainable(model.parameters()))
# list(get_frozen(model.parameters()))
# all_trainable(model.parameters())
# all_frozen(model.parameters())

# The Data - DogsCatsDataset

## Transforms

In [None]:
from torchvision import transforms

# Side length of the square crop that is fed to the network
_image_size = 224
# Per-channel mean / std used for normalisation (the usual ImageNet statistics)
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]


# Training pipeline: upscale, then apply random augmentations before normalising
train_trans = transforms.Compose([
    transforms.Resize(size=256),  # some images are pretty small
    transforms.RandomCrop(size=_image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=.3, contrast=.3, saturation=.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=_mean, std=_std),
])

# Validation pipeline: deterministic resize + centre crop, no augmentation
val_trans = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=_image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=_mean, std=_std),
])

## Dataset

In [None]:
from utils import DogsCatsDataset

In [None]:
# Mini-batch size used by the data loaders and number of target classes
batch_size = 2
n_classes = 2

# Small "sample" subset of the data: the training split goes through the
# augmenting transform, the validation split through the deterministic one
train_ds = DogsCatsDataset("../data/raw", "sample/train", transform=train_trans)
val_ds = DogsCatsDataset("../data/raw", "sample/valid", transform=val_trans)

Use the following if you want to use the full dataset:

In [None]:
# train_ds = DogsCatsDataset("../data/raw", "train", transform=train_trans)
# val_ds = DogsCatsDataset("../data/raw", "valid", transform=val_trans)
# batch_size = 512

In [None]:
# Number of samples in the training and validation sets, in that order
tuple(len(dataset) for dataset in (train_ds, val_ds))

## DataLoader
Batch loading for datasets with multiprocessing and different sampling strategies.

In [None]:
from torch.utils.data import DataLoader


# Training batches are reshuffled every epoch; validation keeps a fixed order.
# Both loaders use 4 worker processes to load samples in parallel.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)

# The Model
PyTorch offers quite a few [pre-trained networks](https://pytorch.org/docs/stable/torchvision/models.html) such as:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3

And there are more available via [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch):
- NASNet,
- ResNeXt,
- InceptionV4,
- InceptionResnetV2, 
- Xception, 
- DPN,
- ...

We'll use a simple resnet18 model:

In [None]:
from torchvision import models

# Load ResNet-18 together with its pretrained weights
model = models.resnet18(pretrained=True)

In [None]:
# A bare last expression renders the model's repr, i.e. its layer hierarchy
model

In [None]:
import torchsummary

# torchsummary builds its dummy input on `device`, which defaults to "cuda"
# whenever a GPU is visible. The model has not been moved yet at this point, so
# pass the device its weights actually live on to keep input and weights together.
# The input shape reuses `_image_size` so it matches the transforms' crop size.
torchsummary.summary(
    model,
    (3, _image_size, _image_size),
    device=next(model.parameters()).device.type,
)

In [None]:
# Manual variant: walk over every parameter of the model and
# switch off gradient tracking one by one
for p in model.parameters():
    p.requires_grad = False

In [None]:
# Or use our convenient functions from before
freeze_all(model.parameters())
# Sanity check: no parameter may require gradients any more
assert all_frozen(model.parameters())

Replace the last layer with a linear layer. New layers have `requires_grad = True`.

In [None]:
# Swap the classification head for a fresh linear layer. The input width is read
# from the layer being replaced instead of hard-coding 512, so the cell keeps
# working when the backbone is exchanged for one with a different feature size.
model.fc = nn.Linear(model.fc.in_features, n_classes)

In [None]:
# The freshly created `fc` layer is trainable, so the model is no longer fully frozen
assert not all_frozen(model.parameters())

In [None]:
def get_model(n_classes=2, device=None):
    """Build a pretrained ResNet-18 with a frozen backbone and a new linear head.

    Args:
        n_classes: number of outputs of the replacement ``fc`` layer.
        device: device to move the model to. ``None`` (the default) uses the
            notebook-wide ``DEVICE``.

    Returns:
        The model on the target device. All pretrained parameters have
        ``requires_grad = False``; only the new ``fc`` layer is trainable.
    """
    model = models.resnet18(pretrained=True)
    freeze_all(model.parameters())
    # Take the feature width from the layer being replaced rather than hard-coding 512
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model.to(DEVICE if device is None else device)


# Tie the head size to the notebook's `n_classes` setting
model = get_model(n_classes=n_classes)

# The Loss

In [None]:
# Cross-entropy loss for the classification task; it is applied to the raw model outputs
criterion = nn.CrossEntropyLoss()

# The Optimizer

In [None]:
# Hand Adam only the parameters that still require gradients.
# Note: Adam has no `momentum` argument; that is an SGD option.
optimizer = torch.optim.Adam(params=get_trainable(model.parameters()), lr=1e-4)

# The Train Loop

In [None]:
# Number of full passes over the training set
N_EPOCHS = 2

for epoch in range(N_EPOCHS):
    print(f"Epoch {epoch+1}/{N_EPOCHS}")

    # Train
    model.train()  # IMPORTANT: put the model into training mode

    # Summed per-sample loss and number of correct predictions for this epoch
    running_loss, correct = 0.0, 0
    for X, y in train_dl:
        X, y = X.to(DEVICE), y.to(DEVICE)

        optimizer.zero_grad()  # clear the gradients left over from the previous step
        y_ = model(X)
        loss = criterion(y_, y)
        loss.backward()
        optimizer.step()

        # Statistics
        print(f"    batch loss: {loss.item():0.3f}")
        _, y_label_ = torch.max(y_, 1)  # index of the largest score = predicted class
        correct += (y_label_ == y).sum().item()
        # The loss is a batch mean (CrossEntropyLoss default), so weight it by the
        # batch size; dividing by the dataset size below then gives a per-sample mean.
        running_loss += loss.item() * X.shape[0]

    print(f"  Train Loss: {running_loss / len(train_dl.dataset)}")
    print(f"  Train Acc:  {correct / len(train_dl.dataset)}")


    # Eval
    model.eval()  # IMPORTANT: put the model into evaluation mode

    running_loss, correct = 0.0, 0
    with torch.no_grad():  # IMPORTANT: no gradient bookkeeping during validation
        for X, y in val_dl:
            X, y = X.to(DEVICE), y.to(DEVICE)

            y_ = model(X)

            # Statistics
            _, y_label_ = torch.max(y_, 1)
            correct += (y_label_ == y).sum().item()
            loss = criterion(y_, y)
            running_loss += loss.item() * X.shape[0]

    print(f"  Valid Loss: {running_loss / len(val_dl.dataset)}")
    print(f"  Valid Acc:  {correct / len(val_dl.dataset)}")
    print()

# Exercise
- Create your own module which takes any ImageNet model (used unmodified as the backbone) and adds a problem-specific head.

In [None]:
class Net(nn.Module):
    """Exercise skeleton: wrap a backbone network and add a problem-specific head.

    The class must derive from ``nn.Module`` (capital M); ``nn.module`` does not
    exist and made the class definition itself fail.

    Args:
        backbone: the network to use as feature extractor. Not stored yet; see
            the placeholders below.
        n_classes: number of outputs the head should produce. Not used yet.
    """

    def __init__(self, backbone: nn.Module, n_classes: int):
        super().__init__()
        # self.backbone
        # self.head = init_head(n_classes)

    def forward(self, x):
        """Placeholder forward pass: currently returns the input unchanged."""
        # TODO
        return x