<img src="img/saturn_logo.png" width="300" />

# Set Up Training

We don't need to run all of Notebook 5 again; we'll just call `setup2.py` in the next chunk to get back to the right state. It also includes the reindexing work from Notebook 5 and a couple of visualization functions that we'll talk about later.

***
**Note: This notebook assumes you have an S3 bucket where you can store your model performance statistics.**  
If you don't have access to an S3 bucket, but would still like to train your model and review results, please visit [Notebook 6b](06b-transfer-training-local.ipynb) and [Notebook 7](07-learning-results.ipynb) to see detailed examples of how you can do that.
***

In [None]:
%run -i setup2.py

display(HTML(gpu_links))

In [None]:
import torch
from tensorboardX import SummaryWriter

from torch import nn, optim
from torch.nn.parallel import DistributedDataParallel as DDP

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import torch.distributed as dist
from torch.optim import lr_scheduler

In [None]:
client

We're ready to do some learning! 

## Regular Model Details

Aside from the Special Elements noted below, we can write this section essentially the same way we write any other PyTorch training loop. We use:
* Cross Entropy Loss for our loss function
* SGD (Stochastic Gradient Descent) for our optimizer

We're also using a learning rate scheduler called `ReduceLROnPlateau`, which leaves our base learning rate alone until the model hits a plateau and the loss is no longer decreasing, and then reduces the learning rate.
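
To make the plateau rule concrete, here is a minimal pure-Python sketch of the idea. It is a simplified stand-in, not PyTorch's implementation: the `PlateauLR` class and its attribute names are made up for illustration, and the real scheduler also supports thresholds and cooldown periods that are left out here. The `factor=0.1` value matches the PyTorch default, and `patience=2` matches the value used in the training function later in this notebook.

```python
class PlateauLR:
    """Simplified stand-in for a reduce-on-plateau rule (not PyTorch's code)."""

    def __init__(self, lr, factor=0.1, patience=2):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        if loss < self.best:
            # Loss improved: remember it and reset the counter.
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            # Plateau detected: shrink the learning rate and start counting again.
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


sched = PlateauLR(lr=0.01, patience=2)
epoch_losses = [1.0, 0.8, 0.8, 0.8, 0.8, 0.8]
history = [sched.step(loss) for loss in epoch_losses]

for epoch, (loss, lr) in enumerate(zip(epoch_losses, history)):
    print(f"epoch {epoch}: loss={loss:.2f} lr={lr:.4f}")
```

The learning rate stays at 0.01 while the loss improves and through the first two flat epochs, then drops by a factor of ten once the loss has failed to improve for more than `patience` epochs in a row.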

Each epoch also has two stages: training and evaluation. We run the full training set in batches before moving to the evaluation step, where we run the full evaluation set in batches of the same size. The batch size is set in the parameters further down as `math.ceil(100/3)`, which is 34 images per batch.

***

## Special Elements

Most of the training workflow function shown below is pretty standard for users of PyTorch. However, there are a couple of elements that are different.

### Tensorboard Writer

We're using TensorBoard to monitor the model's performance, so we'll create a `SummaryWriter` object in our training function and use it to write out statistics and sample image classifications.


### Worker Rank
```
worker_rank = int(dist.get_rank())
```

This checks which worker in the cluster we're on, so that each results record can show which worker produced it.


### Model to GPU Resources

```
device = torch.device(0)
net = models.resnet50(pretrained=True)
model = net.to(device)
```

As you'll recall from Notebook 4, we need to make sure our model is assigned to a GPU resource. Here we do it once, before the training loops begin. We will also assign each image and its label to a GPU resource within the training and evaluation loops; see if you can find this spot (hint: look in `iterate_model`).


### DDP Wrap
```
device_ids = [0]
model = DDP(model, device_ids=device_ids)
```

And finally, we need to enable the DistributedDataParallel framework. To do this, we wrap the model in `DDP()`, the alias we imported for the PyTorch class `torch.nn.parallel.DistributedDataParallel`. There is a lot to know about this, but for our purposes the important thing is that it lets the model training run in parallel on our cluster. See the [PyTorch DDP notes](https://pytorch.org/docs/stable/notes/ddp.html) for more detail.
***

## Discussing DDP
It's helpful to know what this framework is really doing under the hood.

The official PyTorch documentation tells us this:

>This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension (other objects will be copied once per device). In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.

Clear as mud, right? Let’s try to break it down.

>This container parallelizes the application of the given module

This is just indicating that we’re parallelizing a deep learning workflow — transfer learning in our case.

>by splitting the input across the specified devices by chunking in the batch dimension

Input for a transfer learning workflow is the dataset! Ok, so it is chunking our image batches and that’s what gets to be parallel.

>(other objects will be copied once per device)

For example, our starting model (ResNet50 for us) doesn’t get broken up at all; each device gets a full copy. Good to know.

>In the forward pass, the module is replicated on each device, and each replica handles a portion of the input.

Ok, so the training task, our module, is replicated on each device. We have multiple copies of the job working simultaneously, and each one gets a chunk of the input images rather than the entire dataset.

>During the backwards pass, gradients from each replica are summed into the original module.

And then each of these duplicate tasks passes the results (the gradients) back home to the original! The learning is happening out in the workers/child processes, and then they all return the results to the original module/training process to be aggregated.

The essential difference with DDP, then, is that it is optimized for multiple machines instead of a single machine with multiple threads. It’s able to communicate across different machines effectively, so we can use a GPU cluster for our computation.
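
The split-then-combine idea can be checked with a few lines of plain Python. This is only a sketch under stated assumptions: it uses a one-parameter toy model (`y_hat = w * x`) instead of ResNet50, the helper names `grad_mse` and `chunk` are invented for illustration, and the "workers" are just list slices. As far as I know, current versions of DDP combine gradients with an all-reduce that averages them, so every replica ends up with the same values; the sketch uses averaging for that reason.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean squared error with respect to w, for y_hat = w * x."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def chunk(seq, n_chunks):
    """Split a batch into n_chunks equal, contiguous pieces.

    Assumes len(seq) is divisible by n_chunks.
    """
    size = len(seq) // n_chunks
    return [seq[i * size:(i + 1) * size] for i in range(n_chunks)]


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
n_workers = 3

# Each "worker" holds the same weight and sees only its slice of the batch.
local_grads = [
    grad_mse(w, x_part, y_part)
    for x_part, y_part in zip(chunk(xs, n_workers), chunk(ys, n_workers))
]

# Combining step: average the per-worker gradients.
combined = sum(local_grads) / n_workers

# Reference: the gradient computed on the whole batch at once.
full_batch = grad_mse(w, xs, ys)

print(local_grads, combined, full_batch)
```

With equally sized chunks, the averaged gradient matches the full-batch gradient, which is why splitting the batch across workers does not change what the model learns from each step.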

For a more detailed discussion and more tips about this same workflow, [you can visit our blog to read more!](https://www.saturncloud.io/s/combining-dask-and-pytorch-for-better-faster-transfer-learning/)

***


## Launch Board

### If you save files to S3
Open a terminal on your local machine and run `tensorboard --logdir=s3://[NAMEOFBUCKET]/runs`. Make sure your AWS credentials are set in your bash profile/environment.

#### Example of creds you should have
export AWS_SECRET_ACCESS_KEY=`your secret key`   
export AWS_ACCESS_KEY_ID=`your access key id`     
export S3_REGION=us-east-2 `substitute your region`   
export S3_ENDPOINT=https://s3.us-east-2.amazonaws.com `match to your region`   
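
If TensorBoard can't read from the bucket, a quick first check is whether those four variables are actually visible in your environment. The snippet below is a small hypothetical helper (the name `missing_aws_vars` is not part of any library); it only reports which names are unset and never prints the values themselves.

```python
import os

REQUIRED_VARS = [
    "AWS_SECRET_ACCESS_KEY",
    "AWS_ACCESS_KEY_ID",
    "S3_REGION",
    "S3_ENDPOINT",
]


def missing_aws_vars(env=None):
    """Return the names from REQUIRED_VARS that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_aws_vars()
if missing:
    print("Set these before launching TensorBoard:", ", ".join(missing))
else:
    print("All four variables are set.")
```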

### If you save files locally

When you are ready to start viewing the board, run this at the terminal inside JupyterLab:

`tensorboard --logdir=runs`

Then, in a terminal on your local machine, run: 

`ssh -L 6006:localhost:6006 -i ~/.ssh/PATHTOPRIVATEKEY SSHURLFORJUPYTER`

You'll find the private key path on your local machine, and the SSH URL on the project page for this project. You can change the local port (the first 6006) if you like.

At this stage, you'll likely not have any data, but the board will update itself every thirty seconds.

***


# Training time!
Our whole training process is going to be contained in one function, here named `run_transfer_learning`.



## Modeling Step Functions

Setting these pretty basic steps into functions just helps us ensure perfect parity between our train and evaluation steps.

In [None]:
def iterate_model(inputs, labels, model, device):
    # Pass items to GPU
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Run model iteration
    outputs = model(inputs)

    # Format results
    _, preds = torch.max(outputs, 1)
    perct = [torch.nn.functional.softmax(el, dim=0)[i].item() for i, el in zip(preds, outputs)]
    
    return inputs, labels, outputs, preds, perct
    
def calculate_performance(outputs, labels, preds, criterion):
    loss = criterion(outputs, labels)
    correct = (preds == labels).sum().item()
    
    return loss, correct
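
To see what `preds`, `perct`, and `correct` represent, here is a pure-Python sketch of the same steps on made-up numbers. It is an illustration only: the logits and labels are invented, lists stand in for tensors, and the helper names `softmax` and `predict_with_confidence` are hypothetical.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_with_confidence(batch_logits):
    """Return (predicted class, probability of that class) for each image."""
    results = []
    for logits in batch_logits:
        # Same role as torch.max(outputs, 1): index of the largest score.
        pred = max(range(len(logits)), key=lambda i: logits[i])
        results.append((pred, softmax(logits)[pred]))
    return results


# Two made-up images scored against three classes.
batch_logits = [[2.0, 0.5, 0.1], [0.0, 0.0, 3.0]]
labels = [0, 1]

results = predict_with_confidence(batch_logits)
preds = [p for p, _ in results]

# Same role as (preds == labels).sum().item(): count of matching predictions.
correct = sum(int(p == y) for p, y in zip(preds, labels))

print(results, correct)
```

The second value in each pair is the softmax probability of the predicted class, which is what `perct` holds for each image in the batch.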

In [None]:
s3 = s3fs.S3FileSystem()

with s3.open('s3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt') as f:
    imagenetclasses = [line.strip() for line in f.readlines()]

In [None]:
def run_transfer_learning(bucket, prefix, train_pct, batch_size, n_epochs, base_lr, imagenetclasses, n_workers = 1, subset = False):
    '''Load basic Resnet50, run transfer learning over given epochs.
    Uses dataset from the path given as the pool from which to take the 
    training and evaluation samples.'''
    
    worker_rank = int(dist.get_rank())
    
    # Set results writer
    writer = SummaryWriter(f's3://pytorchtraining/pytorch_workshop/learning_worker{worker_rank}')
    executor = ThreadPoolExecutor(max_workers=64)
    
    # --------- Format model and params --------- #
    device = torch.device("cuda")
    net = models.resnet50(pretrained=True) # True means we start with the imagenet version
    model = net.to(device)
    model = DDP(model)
    
    criterion = nn.CrossEntropyLoss().cuda()    
    lr = base_lr
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', patience = 2)
    
    # --------- Retrieve data for training and eval --------- #
    whole_dataset = prepro_batches(bucket, prefix)
    new_class_to_idx = {x: int(replace_label(x, imagenetclasses)[1]) for x in whole_dataset.classes}
    whole_dataset.class_to_idx = new_class_to_idx
    
    train, val = get_splits_parallel(train_pct, whole_dataset, batch_size=batch_size, subset = subset, workers = n_workers)
    dataloaders = {'train' : train, 'val': val}

    # --------- Start iterations --------- #
    count = 0
    t_count = 0
    
    for epoch in range(n_epochs):
    # --------- Training section --------- #    
        model.train()  # Set model to training mode
        for inputs, labels in dataloaders["train"]:
            dt = datetime.datetime.now().isoformat()

            inputs, labels, outputs, preds, perct = iterate_model(inputs, labels, model, device)
            loss, correct = calculate_performance(outputs, labels, preds, criterion)
            
            # zero the parameter gradients
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            count += 1
            
            # Track statistics
            for param_group in optimizer.param_groups:
                current_lr = param_group['lr']

            if ((count % 3) == 0):
                # Pass the callable and its arguments separately so the write runs in the thread pool
                future = executor.submit(
                    writer.add_scalars,
                    'training-count',
                    {'hparam/lr': current_lr,
                     'hparam/batchsize': batch_size,
                     'metric/correct': correct,
                     'metric/loss': loss.item(),
                     'hparam/worker': worker_rank},
                    global_step=count
                )

            # Save a matplotlib figure showing a small sample of actual preds for spot check
            # Functions used here are in setup2.py
            if ((count % 10) == 0):
                # The figure is built here; only the write is handed to the thread pool
                future = executor.submit(
                    writer.add_figure,
                    'predictions vs. actuals, training',
                    plot_classes_preds(model, inputs, labels, preds, perct, imagenetclasses),
                    global_step=count
                )
                
            if (count % 50) == 0 and worker_rank == 0:
                # Use a context manager so the S3 file is flushed and closed after each checkpoint
                with s3.open(f"pytorchtraining/pytorch_workshop/model_epoch{epoch}_iter{count}_{dt}.pkl", 'wb') as f:
                    pickle.dump(model.state_dict(), f)
                
    # --------- Evaluation section --------- #   
        with torch.no_grad():
            model.eval()  # Set model to evaluation mode
            for inputs_t, labels_t in dataloaders["val"]:
                dt = datetime.datetime.now().isoformat()
                
                inputs_t, labels_t, outputs_t, pred_t, perct_t = iterate_model(inputs_t, labels_t, model, device)
                loss_t, correct_t = calculate_performance(outputs_t, labels_t, pred_t, criterion)
                
                t_count += 1

                # Track statistics
                for param_group in optimizer.param_groups:
                    current_lr = param_group['lr']
                
                if ((t_count % 3) == 0):
                    # Pass the callable and its arguments separately so the write runs in the thread pool
                    future = executor.submit(
                        writer.add_scalars,
                        'eval-count',
                        {'hparam/lr': current_lr,
                         'hparam/batchsize': batch_size,
                         'metric/correct': correct_t,
                         'metric/loss': loss_t.item(),
                         'hparam/worker': worker_rank},
                        global_step=t_count
                    )

        scheduler.step(loss)
        
        # Pass the callable and its arguments separately so each write runs in the thread pool
        future = executor.submit(
            writer.add_scalars,
            'training-epoch',
            {'hparam/lr': current_lr,
             'hparam/batchsize': batch_size,
             'metric/correct': correct,
             'metric/loss': loss.item(),
             'hparam/worker': worker_rank},
            global_step=epoch
        )

        future = executor.submit(
            writer.add_scalars,
            'eval-epoch',
            {'hparam/lr': current_lr,
             'hparam/batchsize': batch_size,
             'metric/correct': correct_t,
             'metric/loss': loss_t.item(),
             'hparam/worker': worker_rank},
            global_step=epoch
        )

Now we've done all the hard work, and just need to run our function! Using `dispatch.run` from `dask-pytorch-ddp`, we pass in the transfer learning function so that it gets distributed correctly across our cluster. This creates futures and starts computing them.

### Define Model Parameters

As with any PyTorch model, you'll want to define the number of training epochs, the batch size if using batches, and the starting learning rate. We're also able to assign the train/test split here because of how the functions above are written.

(We're using only 6 epochs here to save time, but of course you can increase this.)

In [None]:
import math
import numpy as np
import multiprocessing as mp
import datetime
import json 
import pickle
from concurrent.futures import ThreadPoolExecutor

num_workers = 64
client.restart()

In [None]:
startparams = {'n_epochs': 6, 
                'batch_size': math.ceil(100/3),
                'train_pct': .8,
                'base_lr': 0.01,
                'imagenetclasses':imagenetclasses,
                'subset': True,
                'n_workers': 3} #only necessary if you select subset
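
The `math.ceil(100/3)` expression is worth a quick look. The arithmetic below assumes the intent is for the three workers together to process roughly 100 images per step; that reading is an assumption, and the variable names are illustrative only.

```python
import math

target_global_batch = 100  # assumed target: images per step across the cluster
n_workers = 3              # matches 'n_workers' in startparams

# Round up so the workers together cover at least the target.
per_worker_batch = math.ceil(target_global_batch / n_workers)
effective_global_batch = per_worker_batch * n_workers

print(per_worker_batch)        # 34
print(effective_global_batch)  # 102
```

Rounding up means the combined batch slightly overshoots 100. If you change the number of workers, the per-worker batch size changes with it, which can affect GPU memory use.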

## Kick Off Job

### Send Tasks to Workers
 
We talked in Notebook 2 about how we distribute tasks to the workers in our cluster, and now you get to see it firsthand. Inside the `dispatch.run()` function in `dask-pytorch-ddp`, we are actually using the `client.submit()` method to pass tasks to our workers, and collecting these as futures in a list. We can prove this by looking at the results, here named "futures", where we can see they are in fact all pending futures, one for each of the workers in our cluster.

> Why don't we use `.map()` in this function?

Recall that `.map` allows the cluster to decide where the tasks are completed: it can choose which worker is assigned any task. That means we don't have the control we need to ensure one and only one job per GPU. This matters for our methodology because DDP is designed to run one process per GPU, each with its own rank.

Instead we use `.submit` and manually assign it to the workers by number. This way, each worker is attacking the same problem - our transfer learning problem - and pursuing a solution simultaneously. We'll have one and only one job per worker.
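
The pattern of "one explicit submit per worker, collected as a list of futures" can be sketched with the standard library alone. This is an analogy rather than Dask code: `ThreadPoolExecutor` stands in for the cluster, and `train_job` is a hypothetical placeholder for the training function. In Dask, the pinning itself is done by telling `client.submit()` which worker to use (its `workers=` keyword), which a local thread pool has no equivalent for.

```python
from concurrent.futures import ThreadPoolExecutor


def train_job(rank, world_size):
    """Placeholder for the training function; reports which slot it ran in."""
    return f"rank {rank} of {world_size} finished"


world_size = 3  # one job per worker

with ThreadPoolExecutor(max_workers=world_size) as pool:
    # One explicit submit call per rank, so every rank runs exactly once.
    futures = [pool.submit(train_job, rank, world_size) for rank in range(world_size)]
    # .result() blocks until that particular job is done.
    results = [f.result() for f in futures]

print(results)
```

Because each job is submitted individually with its own rank, there is never any ambiguity about how many jobs exist or which slot each one occupies.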

In [None]:
%%time    
futures = dispatch.run(client, run_transfer_learning, bucket = "saturn-public-data", prefix = "dogs/Images", **startparams)
futures

In [None]:
futures

In [None]:
futures[0].result()

In [None]:
display(HTML(gpu_links))

<img src="https://media.giphy.com/media/VFDeGtRSHswfe/giphy.gif" alt="parallel" style="width: 200px;"/>

Now we let our workers run for a while. This step will take time, so you may not be able to see the full results during our workshop. Use the GPU links displayed above to watch the GPUs' efforts as the job runs. This won't hold up your Jupyter environment here, but I promise the cluster is working hard!

***

If you don't have access to an S3 bucket, but would still like to do model performance review, please visit [Notebook 6b](06b-transfer-training-local.ipynb) and [Notebook 7](07-learning-results.ipynb) to see detailed examples of how you can do that.