# 7. Transfer Learning for CNNs

In this notebook we load a small dataset that contains pictures of dolphins and elephants. We classify the images with CNNs and compare two approaches to see which works better:
1. Training a CNN from scratch.
2. Finetuning a pretrained ResNet.

In [1]:
import torch
from torchvision import transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets

torch.manual_seed(0)

<torch._C.Generator at 0x7f8fab59c4b0>

Let's load our data and have a look at the shape of some images:

In [2]:
dataset = datasets.ImageFolder(root='./data/animals')
for i, data in enumerate(dataset):
    print(data)
    if i == 5:
        break

(<PIL.Image.Image image mode=RGB size=300x179 at 0x7F8FA982ADF0>, 0)
(<PIL.Image.Image image mode=RGB size=300x179 at 0x7F8FA982AB20>, 0)
(<PIL.Image.Image image mode=RGB size=300x166 at 0x7F8FA982A760>, 0)
(<PIL.Image.Image image mode=RGB size=300x259 at 0x7F8FA982A580>, 0)
(<PIL.Image.Image image mode=RGB size=300x225 at 0x7F8FA982ADF0>, 0)
(<PIL.Image.Image image mode=RGB size=300x277 at 0x7F8FA982A550>, 0)


All pictures have `width=300` but varying heights. To use them for transfer learning, we bring them to the standard input size of `(224, 224)` pixels, the format of the ImageNet images that most pretrained models were trained on. A common size is also needed to stack the images into mini-batches.

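To get them into this shape, we define a transformation that first resizes each image so that its shorter side (here the height) is 224 pixels while keeping the aspect ratio, which scales the width by the same factor, and then crops the central 224×224 square. Finally, it converts the image to a tensor and normalizes each color channel with the ImageNet mean and standard deviation: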

In [3]:
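image_transforms = transforms.Compose([
    transforms.Resize(size=224),      # shorter side -> 224 px, aspect ratio kept
    transforms.CenterCrop(size=224),  # central 224x224 square
    transforms.ToTensor(),            # PIL image -> float tensor in [0, 1], shape (3, H, W)
    # ImageNet per-channel mean and std: the standard normalization for pretrained models
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])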

With this transformation in place, we create a dataset of all images on disk; the transformation is applied each time an image is loaded.

Next, we split the data into a training set and a test set and define the data loaders that load the data from disk.

In [5]:
data = datasets.ImageFolder(root='./data/animals', transform=image_transforms)
# random_split needs lengths that sum to len(data): 129 images -> 100 train, 29 test
train_set, test_set = torch.utils.data.random_split(data, [100, 29])

batch_size = 10

trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle=True)
# A single batch of 29 holds the whole test set
testloader = torch.utils.data.DataLoader(test_set, batch_size=29,
                                         shuffle=False)
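Since `random_split` requires lengths that add up to the size of the dataset, `[100, 29]` implies that the folder contains 129 images. The sketch below uses a synthetic `TensorDataset` of 129 random tensors (tiny 32×32 stand-ins, not our pictures) to show what the split and the two loaders produce: 10 training batches of 10 images per epoch, and one test batch with all 29 test images.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

torch.manual_seed(0)

# Hypothetical stand-in for the animals dataset: 129 small fake images with labels 0/1
images = torch.randn(129, 3, 32, 32)
labels = torch.randint(0, 2, (129,))
data = TensorDataset(images, labels)

train_set, test_set = random_split(data, [100, 29])
trainloader = DataLoader(train_set, batch_size=10, shuffle=True)
testloader = DataLoader(test_set, batch_size=29, shuffle=False)

print(len(train_set), len(test_set))      # 100 29
print(len(trainloader), len(testloader))  # 10 1

# Each iteration over trainloader yields one mini-batch of inputs and labels
x_batch, y_batch = next(iter(trainloader))
print(x_batch.shape, y_batch.shape)  # torch.Size([10, 3, 32, 32]) torch.Size([10])
```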

## Tasks
### Task 1
Train a CNN from scratch to identify the object in the image (dolphin or elephant). For this, use the same CNN architecture as in cell 7 of last week's notebook `06_CNNs.ipynb`. To make this work, you need to change a few things:
1. You need to change the input size of the fully-connected layer to match the new image dimensions.
2. You need to change the output dimension of the fully-connected layer so that it classifies two classes instead of ten.
3. We now use a data loader to load the data (see the cell above), which allows us to train the model in mini-batches (aka "mini-batch gradient descent"). See [here](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#train-the-network) how to train the network with mini-batches using the `trainloader` from above.
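For point 1, instead of computing the new input size of the fully-connected layer by hand, you can pass a dummy batch of shape `(1, 3, 224, 224)` through the convolutional part and read off the number of flattened features. The architecture of `06_CNNs.ipynb` is not reproduced here, so the sketch uses a *hypothetical* conv stack (two 5×5 convolutions without padding, each followed by 2×2 max pooling) purely for illustration. For that stack the spatial size shrinks as 224 → 220 → 110 → 106 → 53, so the layer sees 16·53·53 = 44944 features instead of 16·5·5 = 400 for 32×32 inputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HypotheticalConvFeatures(nn.Module):
    """Illustrative conv stack; NOT the architecture from 06_CNNs.ipynb."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)   # 5x5 conv, no padding: -4 pixels
        self.pool = nn.MaxPool2d(2, 2)                # halves height and width
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return torch.flatten(x, 1)  # keep the batch dimension, flatten the rest

features = HypotheticalConvFeatures()
with torch.no_grad():
    n_flat = features(torch.zeros(1, 3, 224, 224)).shape[1]
    n_flat_small = features(torch.zeros(1, 3, 32, 32)).shape[1]
print(n_flat)        # 44944
print(n_flat_small)  # 400

# Point 2: two output units, one per class (dolphin, elephant)
fc = nn.Linear(n_flat, 2)
```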

Train for 20 epochs on the train data and afterwards compute the accuracy on the test data.
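A minimal sketch of mini-batch training and test accuracy, following the pattern of the PyTorch CIFAR-10 tutorial linked above. The helpers `train` and `accuracy` are hypothetical names of our own, and using `CrossEntropyLoss` with SGD plus momentum (as in that tutorial) is an assumption, not a requirement. To keep the example fast, the usage at the bottom runs on a tiny synthetic dataset with a toy linear model and measures accuracy on that same toy data; in the notebook you would pass your CNN, `trainloader`, `epochs=20`, and then `testloader` to `accuracy`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train(model, loader, epochs, lr=0.001, momentum=0.9):
    """Mini-batch gradient descent: one optimizer step per batch."""
    criterion = nn.CrossEntropyLoss()
    # Only optimize parameters that require gradients (frozen ones are skipped)
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad], lr=lr, momentum=momentum
    )
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in loader:
            optimizer.zero_grad()            # reset gradients from the previous batch
            loss = criterion(model(inputs), labels)
            loss.backward()                  # gradients for this mini-batch only
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        print(f"epoch {epoch + 1}: loss {running_loss / len(loader.dataset):.4f}")

def accuracy(model, loader):
    """Fraction of correctly classified samples in `loader`."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)  # index of the highest score
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Usage on toy data: class 1 iff the first feature is positive
torch.manual_seed(0)
x = torch.randn(100, 4)
y = (x[:, 0] > 0).long()
toy_loader = DataLoader(TensorDataset(x, y), batch_size=10, shuffle=True)
toy_model = nn.Linear(4, 2)
train(toy_model, toy_loader, epochs=20, lr=0.1)
acc = accuracy(toy_model, toy_loader)
print(acc)
```

Filtering on `requires_grad` makes the same loop usable when parts of a model are frozen, and switching between `model.train()` and `model.eval()` matters for layers such as batch normalization.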

### Task 2
Instead of training a CNN from scratch, we now want to load a pretrained **ResNet18** model and retrain only its last layer for our classification task.
PyTorch has a [tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor) on transfer learning that shows how this works (note: it is enough to read the section `ConvNet as fixed feature extractor`).

Train the last layer of the pretrained ResNet model for 20 epochs and compare the results with Task 1.