# Transfer Learning with ResNet

In this notebook we load a small dataset that contains images of dolphins and elephants. We classify the images using CNNs and compare two approaches to see which works better:
1. Training a CNN from scratch.
2. Fine-tuning a pretrained ResNet.

In [1]:
import torch
from torchvision import transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets

torch.manual_seed(0)

<torch._C.Generator at 0x7fc76ca13710>

Let's load our data and have a look at the shape of some images:

In [2]:
dataset = datasets.ImageFolder(root='./data/anmials')
for i, data in enumerate(dataset):
    print(data)
    if i == 10:
        break

(<PIL.Image.Image image mode=RGB size=300x179 at 0x7FC76ED10160>, 0)
(<PIL.Image.Image image mode=RGB size=300x179 at 0x7FC76C2A2250>, 0)
(<PIL.Image.Image image mode=RGB size=300x166 at 0x7FC76ED10220>, 0)
(<PIL.Image.Image image mode=RGB size=300x259 at 0x7FC76C2A2610>, 0)
(<PIL.Image.Image image mode=RGB size=300x225 at 0x7FC76ED101F0>, 0)
(<PIL.Image.Image image mode=RGB size=300x277 at 0x7FC76EBEDCA0>, 0)
(<PIL.Image.Image image mode=RGB size=300x183 at 0x7FC76ED10160>, 0)
(<PIL.Image.Image image mode=RGB size=300x225 at 0x7FC76EBEDCA0>, 0)
(<PIL.Image.Image image mode=RGB size=300x214 at 0x7FC76ED10280>, 0)
(<PIL.Image.Image image mode=RGB size=300x223 at 0x7FC76EBEDCA0>, 0)
(<PIL.Image.Image image mode=RGB size=300x277 at 0x7FC76C273E20>, 0)


We see that all pictures have width=300 but varying heights. To use them in transfer learning they need to have the standard shape (224, 224), which is the input format commonly used for ImageNet (on which most pretrained models are trained).  

To get them into this shape, we first resize each image so that its shorter side (here the height) is 224, which scales the width by the same factor, and then take the 224x224 square centered in the middle of the image.

In [3]:
image_transforms = transforms.Compose([
    transforms.Resize(size=224),      # shorter side -> 224, aspect ratio preserved
    transforms.CenterCrop(size=224),  # central 224x224 square
    transforms.ToTensor(),
    # ImageNet channel means and standard deviations, as expected by pretrained models
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
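
To make the resize-then-crop arithmetic concrete, here is a small sketch in plain Python that mimics what `Resize(224)` followed by `CenterCrop(224)` does to the image size. The helper `resized_size` is hypothetical and written only for illustration. It assumes that the longer side is scaled by the same factor and truncated to an integer, which is my understanding of torchvision's behavior; exact rounding may differ between versions.

```python
def resized_size(width, height, size=224):
    """Approximate the (width, height) produced by Resize(size) with an int size.

    Assumption: the shorter side becomes `size` and the longer side is scaled
    by the same factor and truncated to an integer.
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def discarded_by_center_crop(width, height, size=224):
    """Number of (columns, rows) that CenterCrop(size) cuts away in total."""
    return max(width - size, 0), max(height - size, 0)


# Two of the image sizes printed above (width x height)
for w, h in [(300, 179), (300, 277)]:
    new_w, new_h = resized_size(w, h)
    cut_cols, cut_rows = discarded_by_center_crop(new_w, new_h)
    print(f"{w}x{h} -> resized to {new_w}x{new_h}, "
          f"crop removes {cut_cols} columns and {cut_rows} rows")
```

Flat, wide images lose more of their left and right borders in the crop than images that are already close to square.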

In [4]:
data = datasets.ImageFolder(root='./data/anmials', transform=image_transforms)

print(len(data), "data points")

129 data points


Next, we split the data into train and test and define the data loaders.

In [5]:
train_set, test_set = torch.utils.data.random_split(data, [100, 29])

batch_size = 10
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(test_set, batch_size=29,
                                         shuffle=False)

## Tasks:
### Task 1:
Train a CNN from scratch to identify the object in the image (dolphin or elephant). For this, use the same CNN architecture as in notebook `5_CNN_CIFAR10.ipynb` (to make this work, you need to adjust some network parameters for this dataset). Train for 20 epochs on the train data and afterwards compute the accuracy on the test data.
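
As a possible starting point, the sketch below shows which parameters typically need adjusting. The architecture of `5_CNN_CIFAR10.ipynb` is not shown here, so this assumes a small LeNet-style network (two 5x5 convolutions with 2x2 max pooling, then three linear layers); the class name `SmallCNN`, the layer widths, and the helper `accuracy` are illustrative assumptions. The two changes relative to a CIFAR-10 setup are the flattened feature size, which grows because the input is 224x224 instead of 32x32, and the number of output classes, which drops to 2.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallCNN(nn.Module):
    """LeNet-style CNN; layer sizes are assumptions, not taken from the referenced notebook."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Spatial size: 224 -> conv 5x5 -> 220 -> pool -> 110 -> conv 5x5 -> 106 -> pool -> 53
        self.fc1 = nn.Linear(16 * 53 * 53, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # raw logits, suitable for nn.CrossEntropyLoss


def accuracy(model, loader):
    """Fraction of correctly classified samples over all batches in `loader`."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Shape check with a dummy batch of four 224x224 RGB images
model = SmallCNN()
dummy = torch.zeros(4, 3, 224, 224)
print(model(dummy).shape)
```

After training, `accuracy(model, testloader)` would give the test accuracy for the loaders defined earlier.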

### Task 2:
Instead of training a CNN from scratch, load a pretrained ResNet18 and only train the last layer (see the lecture slides for how that works). Train again for 20 epochs and compare the results.