<img src="img/saturn_logo.png" width="300" />


# Transfer Learning

After Notebooks 3 and 4, we know how to run fast, parallelized inference with Dask. But what if we need to train a model? Let's work through a transfer learning task to see how that might work.

We are still using the Stanford Dogs dataset and starting from ResNet50, and we will use transfer learning to make the model better at identifying dog breeds.

In order to make this work, we have a few steps to carry out:
* Preprocessing our data appropriately
* Applying infrastructure for parallelizing the learning process
* Running the transfer learning workflow and generating evaluation data


To start, you know the drill by now: get our cluster connected.

In [None]:
from dask_saturn import SaturnCluster
from dask.distributed import Client
import s3fs
import re
from torchvision import transforms

cluster = SaturnCluster()
client = Client(cluster)
client.wait_for_workers(3)
client

In [None]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

***

## Preprocessing Data

We are using `dask-pytorch-ddp` to handle much of the work involved in training across the whole cluster. It abstracts away many worker-management tasks and sets up a tidy infrastructure for managing model output. If you're interested in learning more, we maintain the [codebase and documentation on GitHub](https://github.com/saturncloud/dask-pytorch).

Because we want to load our images directly from S3 without first downloading them to disk (and wasting space and time!), we are going to use `S3ImageFolder`, a custom class from `dask-pytorch-ddp` that inherits from PyTorch's `Dataset` class.

The preprocessing steps are quite short: we load the images using `S3ImageFolder` and apply the transformations of our choosing. If you like, you can make the transformations an argument to this function and pass them in.


In [None]:
from dask_pytorch_ddp import results, data, dispatch
from torch.utils.data.sampler import SubsetRandomSampler

In [None]:
def prepro_batches(bucket, prefix):
    '''Load images from S3 with the S3ImageFolder Dataset class and apply transformations.'''
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(250),
        transforms.ToTensor()])
    whole_dataset = data.S3ImageFolder(
        bucket,
        prefix,
        transform=transform,
        anon=True  # anonymous access, suitable for public buckets
    )
    return whole_dataset

### Checking Data Labels

Because our task is transfer learning, we're going to start with the pretrained ResNet50 model. To take full advantage of the training the model already has, we need to make sure that the label indices in our Stanford Dogs dataset match their equivalents in the ResNet50 label data. (Hint: they aren't going to match, but we'll fix it!)

In [None]:
s3 = s3fs.S3FileSystem(anon=True)  # public bucket, so no credentials needed

with s3.open('s3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt') as f:
    imagenetclasses = [line.strip() for line in f.readlines()]

whole_dataset = prepro_batches(bucket = "saturn-public-data", prefix = "dogs/Images")

Any dataset loaded as a PyTorch `ImageFolder`-style object has a few useful attributes, including `class_to_idx`, a dictionary mapping each class name to its assigned index. Let's look at the one for our dog images.

In [None]:
list(whole_dataset.class_to_idx.items())[0:5]

So let's look at the ImageNet classes. Do they match?

In [None]:
imagenetclasses[0:5]

Well, that's not going to work! Our model thinks 1 = goldfish while our dataset thinks 1 = Japanese Spaniel. Fortunately, this is a pretty easy fix. 
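To see where the mismatch comes from, here is a minimal sketch of how an `ImageFolder`-style dataset assigns indices: torchvision's `ImageFolder` sorts the class folder names and numbers them from 0, and we assume `S3ImageFolder` follows the same convention. The five folder names are an illustrative excerpt of the Stanford Dogs folders (not the full list), and the ImageNet indices are hand-written for these five breeds.

```python
# Illustrative excerpt of Stanford Dogs class folders (not the full list of 120)
dog_folders = [
    "n02086240-Shih-Tzu",
    "n02085620-Chihuahua",
    "n02086079-Pekinese",
    "n02085782-Japanese_spaniel",
    "n02085936-Maltese_dog",
]

# ImageFolder-style indexing: sort the class names, then number them from 0
classes = sorted(dog_folders)
class_to_idx = {name: i for i, name in enumerate(classes)}

# Hand-written ImageNet indices for the same five breeds
imagenet_idx = {
    "Chihuahua": 151,
    "Japanese spaniel": 152,
    "Maltese dog": 153,
    "Pekinese": 154,
    "Shih-Tzu": 155,
}

for folder, idx in class_to_idx.items():
    # "n02085782-Japanese_spaniel" -> "Japanese spaniel"
    breed = folder.split("-", 1)[1].replace("_", " ")
    print(f"{breed:<18} dataset index {idx}  ImageNet index {imagenet_idx[breed]}")
```

The dataset counts from 0, while ImageNet's dog classes sit much further along its 1,000 labels, so the raw indices disagree for every breed.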

I've created a function called `replace_label()` that matches the labels by name using regular expressions, so that we can be sure we pair them up correctly. This is important because we can't assume our dog labels appear in exactly the same consecutive order in the ImageNet labels.

In [None]:
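def replace_label(dataset_label, model_labels):
    '''Return the ImageNet label line and index that match a Stanford Dogs class name.'''
    # Class names look like "n02085782-Japanese_spaniel"; keep the breed after the ID
    label_string = re.search(r'n[0-9]+-([^/]+)', dataset_label).group(1)

    for i in model_labels:
        # Each line was read as bytes, e.g. b"{152: 'Japanese spaniel',"
        i = str(i).replace('{', '').replace('}', '')
        model_label_str = re.search(r'''b["'][0-9]+: ["']([^\/]+)["'],["']''', i)
        model_label_idx = re.search(r'''b["']([0-9]+):''', i).group(1)
        if model_label_str is None:
            continue  # skip lines the name pattern can't parse

        # Compare breed names with underscores treated as spaces
        if re.search(label_string.replace('_', ' '), model_label_str.group(1).replace('_', ' ')):
            return i, model_label_idx

    raise ValueError(f'No ImageNet label found for {dataset_label}')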
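As an alternative to matching raw lines with regex, here is a minimal sketch that assumes the label file is a single Python dict literal such as `{0: 'tench, Tinca tinca', 1: ...}` (the `{`/`}` stripping in `replace_label` hints at this, but the format is an assumption here). It parses the text with `ast.literal_eval` and compares each folder name exactly against the comma-separated ImageNet synonyms, instead of the substring search used above. `build_class_to_idx` is a hypothetical helper, and `label_text` is a short hand-typed excerpt, not the real file.

```python
import ast
import re

def build_class_to_idx(dataset_classes, label_text):
    '''Hypothetical helper: map Stanford Dogs folder names to ImageNet indices.'''
    # Assumes the label file is one Python dict literal of {index: "name, synonym, ..."}
    imagenet = ast.literal_eval(label_text)

    # Lowercased synonym -> index; keep the first index if a synonym repeats
    synonym_to_idx = {}
    for idx, names in sorted(imagenet.items()):
        for name in names.split(","):
            synonym_to_idx.setdefault(name.strip().lower(), idx)

    mapping = {}
    for folder in dataset_classes:
        # "n02085782-Japanese_spaniel" -> "japanese spaniel"
        breed = re.search(r"n[0-9]+-([^/]+)", folder).group(1).replace("_", " ").lower()
        if breed not in synonym_to_idx:
            raise ValueError(f"No ImageNet label found for {folder!r}")
        mapping[folder] = synonym_to_idx[breed]
    return mapping

# Short excerpt written in the assumed label-file format
label_text = """{0: 'tench, Tinca tinca',
 1: 'goldfish, Carassius auratus',
 151: 'Chihuahua',
 152: 'Japanese spaniel',
 153: 'Maltese dog, Maltese terrier, Maltese'}"""

print(build_class_to_idx(
    ["n02085620-Chihuahua", "n02085782-Japanese_spaniel", "n02085936-Maltese_dog"],
    label_text,
))
# {'n02085620-Chihuahua': 151, 'n02085782-Japanese_spaniel': 152, 'n02085936-Maltese_dog': 153}
```

With the real file, `label_text` would be the decoded file contents (for example, the bytes from `s3.open(...)` passed through `.decode()`). Exact matching raises an error when a folder name differs from every synonym, rather than accepting the first partial match.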

We can use this function in a dictionary comprehension to build our new `class_to_idx` mapping. Now the indices are assigned to match the ImageNet labels!

In [None]:
new_class_to_idx = {x: int(replace_label(x, imagenetclasses)[1]) for x in whole_dataset.classes}

In [None]:
list(new_class_to_idx.items())[0:5]

In [None]:
imagenetclasses[151:156]

Let's also make sure the old and new mappings have the same number of classes, so that nothing got missed.

In [None]:
len(new_class_to_idx.items()) == len(whole_dataset.class_to_idx.items())
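Equal lengths show that every class received an index, but not that the mapping is one-to-one: two breeds matched to the same ImageNet index would still pass the check above. Here is a minimal extra check, shown on a small hand-written dictionary; `duplicate_indices` is a hypothetical helper, not part of `dask-pytorch-ddp`.

```python
from collections import Counter

def duplicate_indices(class_to_idx):
    '''Return {index: [classes]} for any index assigned to more than one class.'''
    counts = Counter(class_to_idx.values())
    return {
        idx: sorted(name for name, i in class_to_idx.items() if i == idx)
        for idx, n in counts.items()
        if n > 1
    }

# Hand-written example: the second entry wrongly reuses index 151
example = {
    "n02085620-Chihuahua": 151,
    "n02085782-Japanese_spaniel": 151,
    "n02085936-Maltese_dog": 153,
}
print(duplicate_indices(example))
# {151: ['n02085620-Chihuahua', 'n02085782-Japanese_spaniel']}
```

In the notebook, `duplicate_indices(new_class_to_idx)` should return an empty dictionary if every breed found its own ImageNet label.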

***

### Select Training and Evaluation Samples

To run our training, we'll split the data into training and evaluation samples and wrap each in a DataLoader object that we can iterate over. We'll use the two loaders later to train the model and monitor its learning.

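Note the `num_workers` and `multiprocessing_context` arguments in the DataLoader objects; these are important! With `num_workers` greater than zero, each DataLoader loads data in several worker processes at once, which saves us a lot of time on large batch jobs, and `multiprocessing_context` sets how those worker processes are started (here, by forking).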

In [None]:
import math
import multiprocessing as mp

import numpy as np

def get_splits_parallel(train_pct, data, batch_size, num_workers=4):
    '''Select two disjoint samples of data for training and evaluation.'''
    # num_workers default of 4 is an assumption; tune it to each worker's CPU count
    train_size = math.floor(len(data) * train_pct)
    indices = list(range(len(data)))
    np.random.shuffle(indices)
    train_idx = indices[:train_size]
    test_idx = indices[train_size:]

    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    # Worker processes load batches in parallel; 'fork' sets how they are started
    train_loader = torch.utils.data.DataLoader(
        data,
        sampler=train_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context=mp.get_context('fork'))

    test_loader = torch.utils.data.DataLoader(
        data,
        sampler=test_sampler,  # evaluation draws only from the held-out indices
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context=mp.get_context('fork'))

    return train_loader, test_loader
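The index-splitting step in `get_splits_parallel` is easy to get wrong, so here is a minimal standard-library sketch of it that you can check in isolation: shuffle the indices, cut at `floor(n * train_pct)`, and confirm the two parts are disjoint and together cover the whole dataset. `split_indices` and its `seed` argument are hypothetical additions for reproducibility; the notebook's function uses NumPy's global random state instead.

```python
import math
import random

def split_indices(n, train_pct, seed=None):
    '''Hypothetical helper mirroring the index split in get_splits_parallel.'''
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded generator gives reproducible splits
    train_size = math.floor(n * train_pct)
    return indices[:train_size], indices[train_size:]

train_idx, test_idx = split_indices(n=100, train_pct=0.75, seed=42)

# Each sample should land in exactly one of the two sets
assert set(train_idx).isdisjoint(test_idx)
assert sorted(train_idx + test_idx) == list(range(100))
print(len(train_idx), len(test_idx))  # 75 25
```

Building the evaluation sampler from `test_idx` (not `train_idx`) is what keeps evaluation data separate from the training data.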

Aside from the custom data class, this should look very similar to other PyTorch workflows. While I am using `S3ImageFolder` here, you definitely don't have to in your own work: any standard PyTorch `Dataset` should be compatible with the Dask work we're doing next.

Now, it's time for learning, in [Notebook 6a](06a-transfer-training-s3.ipynb)!

<img src="https://media.giphy.com/media/mC7VjtF9sYofs9DUa5/giphy.gif" alt="learn" style="width: 300px;"/>