# Practical Example - Distributed Calculations
This example runs hyperparameter tuning on a PyTorch model. It runs multiple models simultaneously using Dask.

We will show how to:

* Connect to the Dask cluster
* Load the data from a shared file system
* Run the models
* Use the dashboard and logging capabilities of Dask
* Implement checkpointing and recovery


## Import Libraries

In [None]:
import datetime
import traceback

import dask
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from dask.distributed import Client
from dask_saturn import SaturnCluster
from distributed.client import FIRST_COMPLETED, wait
from distributed.worker import logger
from functions import *
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

## Start the Dask Cluster
Here we are starting the Dask cluster. 
We specify:
* The number of workers
* Their types
* A Client for the cluster

We then wait for all of the workers to start before continuing, which is usually good practice, and restart them so that they begin from a clean state.

The output of this block will be a listing of the client properties including the cluster, scheduler, and workers.

In [None]:
# Cluster shape: a large scheduler plus `n_workers` workers of size g4dnxlarge.
n_workers = 3
cluster = SaturnCluster(
    worker_size="g4dnxlarge",
    scheduler_size="large",
    n_workers=n_workers,
)
client = Client(cluster)

# Block until every requested worker has joined, then restart the workers so
# each one starts without state left over from earlier sessions.
client.wait_for_workers(n_workers)
client.restart()

## Upload code files to the cluster
We need to upload `functions.py` to the cluster so that the workers can import the helper functions it defines (such as `load_data` and `define_model`). Otherwise, tasks that call these functions will fail on the workers.

Because we are just loading one file, we will use the `client.upload_file()` function.

We could also use a worker plugin if we required a directory to be uploaded.
```python
from dask_saturn import RegisterFiles, sync_files

client.register_worker_plugin(RegisterFiles())
sync_files(client, "functions")
client.restart()
```

In [None]:
# Send the helper module to every worker so tasks can import it.
functions_module_path = "functions.py"
client.upload_file(functions_module_path)

## Define the Delayed Function
* Input: a learning rate
* Output: a list of per-epoch results (one dictionary per epoch)

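All of the code should look familiar except for the `@dask.delayed` decorator at the beginning. This decorator makes the function lazy: calling it returns a `Delayed` object that describes the task instead of running it immediately, and Dask can then execute many such tasks in parallel on the cluster.
The function returns its per-epoch results, and it saves a checkpoint of the model and optimizer at the end of every epoch so that training can be recovered after a failure.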
To load the model, simply use the standard PyTorch loading methods:

``` python
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
```

In [None]:
@dask.delayed
def train_model(
    lr,
    data_dir="/home/jovyan/shared/nathan/poc-gsa/datasets/birds",
    model_dir="/home/jovyan/shared/nathan/poc-gsa/models",
    model_version=1,
    batch_size=100,
    num_epochs=5,
):
    """Train one model with learning rate ``lr``, checkpointing every epoch.

    Wrapped in ``dask.delayed``, so calling it builds a task rather than running
    it. Each epoch runs a training phase and then a validation phase. After the
    validation phase the model and optimizer state are saved under ``model_dir``
    so training can be resumed after a failure.

    Parameters
    ----------
    lr : float
        Initial SGD learning rate.
    data_dir : str
        Dataset location passed to ``load_data``; must be readable by the worker.
    model_dir : str
        Directory where per-epoch checkpoints are written.
    model_version : int
        Version tag embedded in checkpoint filenames.
    batch_size : int
        Batch size passed to ``load_data``.
    num_epochs : int
        Number of train/validation epochs.

    Returns
    -------
    list of dict
        One record per epoch with keys ``lr``, ``epoch``, ``loss``,
        ``accuracy``, ``elapsed_time_sec`` and ``model_path``. ``loss`` and
        ``accuracy`` are the validation-phase values for that epoch.
    """
    # Use a built-in float so results and checkpoints do not carry NumPy scalar
    # types (lr may come from a NumPy array); its string form is unchanged.
    lr = float(lr)

    training_start_time = datetime.datetime.now()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataloaders, dataset_sizes, classes = load_data(data_dir, batch_size)

    model = define_model(len(classes), True)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # Lowers the learning rate when the monitored (validation) loss plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)

    results = []

    for epoch in range(num_epochs):
        # Each epoch has a training and a validation phase.
        for phase in ["train", "valid"]:
            is_training = phase == "train"
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_training:
                    optimizer.zero_grad()

                # Track gradient history only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # CrossEntropyLoss() returns the batch mean; weight it by the
                # batch size so the phase loss is a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            if not is_training:
                # Step once per epoch on the epoch-level validation loss, which is
                # what ReduceLROnPlateau is meant to monitor (not a single noisy
                # training-batch loss).
                scheduler.step(epoch_loss)

            logger.info(
                f"{datetime.datetime.now().isoformat()} - lr {lr} - epoch {epoch} - phase {phase} - loss {epoch_loss} - accuracy {epoch_acc}"
            )

        # "valid" is the last phase, so epoch_loss / epoch_acc are validation values.
        model_path = f"{model_dir}/model-v{model_version}-lr{lr}-epoch{epoch}-accuracy{round(epoch_acc, 2)}"
        new_results = {
            "lr": lr,
            "epoch": epoch,
            "loss": epoch_loss,
            "accuracy": epoch_acc,
            "elapsed_time_sec": (
                datetime.datetime.now() - training_start_time
            ).total_seconds(),
            "model_path": model_path,
        }

        # Checkpoint every epoch so a failed run can be resumed from here.
        torch.save(
            {
                "epoch": epoch,
                "lr": lr,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss,
                "accuracy": epoch_acc,
            },
            model_path,
        )

        results.append(new_results)
    return results

## Set up the futures to orchestrate the delayed function

To handle errors, we are going to run the delayed function as futures. This allows us to keep track of the function's state and make sure it is computed on the correct resources.

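First, we map the learning rates over the function with `client.map`. Because `train_model` is a delayed function, these first futures resolve to `Delayed` objects rather than to trained models.
We then gather these `Delayed` objects back to the client so that we can compute them.
Finally, `client.compute` runs the actual training on the worker processes and returns a new set of futures.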

Note that we specify the resources each task requires in order to limit the number of tasks that can run simultaneously on each worker. If this were not specified, Dask could assign several tasks to a worker at once, resulting in CUDA out-of-memory errors on the GPUs. By specifying that each task requires one GPU, Dask knows to run only one task at a time on a worker that has a single GPU. This relies on the workers advertising a matching `gpu` resource; tasks whose resource requirements no worker can satisfy are never scheduled. We could alternatively reduce the batch size to accommodate the memory requirements of several concurrent calculations.

>**Note**: The functions will run on the Dask cluster even if the client kernel dies, unless you specify otherwise (`shutdown_on_close` option in cluster setup). It is important that you run `client.restart()` to clear the Dask task graph if you want to discontinue calculations.

In [None]:
# Six learning rates: 0.0005 to 0.003 in steps of 0.0005. The grid is built from
# integers and rounded because np.arange with a float step can gain or lose the
# endpoint through rounding error, and its values can carry representation noise
# that would then appear in the checkpoint filenames.
learning_rates = np.round(np.arange(1, 7) * 0.0005, 4)

# train_model is a dask.delayed function, so each mapped task only builds a
# Delayed object; gathering brings those Delayed objects back to the client.
train_future = client.map(train_model, learning_rates)
futures_gathered = client.gather(train_future)

# Run the actual training. Requiring one "gpu" resource per task limits each
# worker to one concurrent run (workers must advertise a "gpu" resource).
futures_computed = client.compute(futures_gathered, resources={"gpu": 1})

## Handle Errors
Dask is typically resilient to infrastructure failures while running a task graph (for instance, losing a worker), but it cannot recover from exceptions raised by your own code. To handle these, we create a queue of futures and record, for each one, either its result or the error it raised.

Since this code looks fairly complicated at first glance, let’s break it down.

``` python
queue = c.compute(results)
futures_to_index = {fut: i for i, fut in enumerate(queue)}
results = [None for x in range(len(queue))]
```

We call `compute` on `results`, but because we don’t pass `sync=True`, we immediately get back futures that represent computations which have not completed yet. We also create a mapping from each future to the index _n_ of the input argument that generated it. Finally, we create a results list filled with `None` placeholders for now.
```python
while queue:
    result = wait(queue, return_when=FIRST_COMPLETED)
```

Next, we wait for results and process them as they come in. `wait` splits the futures into two sets: those that are `done` and those that are `not_done`.
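```python
    for future in result.done:
        index = futures_to_index[future]
        if future.status == 'finished':
            print(f'finished computation #{index}')
            results[index] = future.result()
```
For each completed future, we look up the index of the input that produced it. If the future finished successfully, we print a message and store its result.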
```python
        else:
            print(f'errored #{index}')
            try:
                future.result()
            except Exception as e:
                results[index] = e
                traceback.print_exc()
```
Otherwise, we store the exception and print the stack trace.
```python
    queue = result.not_done
```
Finally, we set the queue to those futures that have not yet been completed.

In [None]:
# Collect each training run as soon as it completes. Slot i of `results` ends up
# holding the return value of run i, or the exception that run raised.
queue = client.compute(futures_computed)
futures_to_index = dict(zip(queue, range(len(queue))))
results = [None] * len(queue)

while queue:
    result = wait(queue, return_when=FIRST_COMPLETED)
    for future in result.done:
        index = futures_to_index[future]
        if future.status != "finished":
            # Re-raise locally to capture the exception and print its traceback.
            print(f"Error #{index}")
            try:
                future.result()
            except Exception as e:
                results[index] = e
                traceback.print_exc()
            continue
        print(f"Finished computation #{index}")
        results[index] = future.result()
    # Keep waiting only on the runs that are still in progress.
    queue = result.not_done

## Shut Down the Cluster
At this point, all cluster calculations are complete, so we can shut down the cluster.

In [None]:
# Disconnect the client first, then shut down the cluster, so the client does
# not keep trying to reach a scheduler that no longer exists.
client.close()
cluster.close()

## View the Results

In [None]:
# A successful run stores its list of per-epoch dicts; a failed run stores the
# exception it raised. Skip failed runs so one failure does not break the
# flattening below (iterating an exception raises TypeError).
failed_runs = [run for run in results if not isinstance(run, list)]
if failed_runs:
    print(f"Skipping {len(failed_runs)} failed run(s): {failed_runs}")

results_concatenated = [
    record for run in results if isinstance(run, list) for record in run
]
if not results_concatenated:
    raise RuntimeError("No successful training runs to plot.")

results_df = pd.DataFrame(results_concatenated)
# Treat each learning rate as a label so seaborn draws one column per value.
results_df["lr"] = results_df["lr"].astype(str)

sns.set_theme()

sns.relplot(data=results_df, x="epoch", y="loss", col="lr", kind="line")

sns.relplot(data=results_df, x="elapsed_time_sec", y="loss", col="lr", kind="line")

## Load the Best Model
Finally, we want to load the best model and see what our test scores are.


In [None]:
# Row with the lowest recorded loss across all runs and epochs; show its checkpoint path.
best_idx = results_df["loss"].idxmin()
results_df.loc[best_idx, "model_path"]

In [None]:
# Local copies of the dataset and of the best checkpoint.
dataloaders, dataset_sizes, classes = load_data(
    "C:/Users/narob/Documents/Code/SaturnCloud/data/birds", 100
)
model_path = "C:/Users/narob/Documents/Code/SaturnCloud/data/models/model-v1-lr0.003-epoch4-accuracy1.0"

# For use on hosted notebook
# NOTE(review): train_model reads its data from
# /home/jovyan/shared/nathan/poc-gsa/datasets/birds (no "-folder"); confirm
# which directory is correct before enabling the lines below.
#
# dataloaders, dataset_sizes, classes = load_data(
#     "/home/jovyan/shared/nathan/poc-gsa-folder/datasets/birds", 100
# )
# model_path = results_df.loc[results_df["loss"].idxmin(), "model_path"]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# map_location remaps tensors saved on a GPU worker onto the local device.
checkpoint = torch.load(model_path, map_location=device)

model = define_model(len(classes), False)
model = model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])

# Build the optimizer with the checkpoint's own learning rate rather than a
# hard-coded value that only matches one particular checkpoint.
optimizer = optim.SGD(model.parameters(), lr=float(checkpoint["lr"]), momentum=0.9)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

In [None]:
def get_predictions(model, iterator, device=None):
    """Run ``model`` over ``iterator`` and collect true labels and class probabilities.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier; it is switched to eval mode before inference.
    iterator : iterable of (inputs, labels) pairs
        For example a DataLoader split such as ``dataloaders["test"]``.
    device : torch.device, optional
        Device the inputs are moved to. Defaults to the device of the model's
        first parameter (CPU if the model has no parameters).

    Returns
    -------
    labels : torch.Tensor
        Labels from every batch, concatenated along dim 0 on the CPU.
    probs : torch.Tensor
        Softmax (over the last dimension) of the model output for every batch,
        concatenated along dim 0 on the CPU.

    Raises
    ------
    RuntimeError
        If ``iterator`` yields no batches (``torch.cat`` of an empty list).
    """
    if device is None:
        # Follow the model's weights instead of relying on a notebook-level
        # `device` variable being defined.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    labels = []
    probs = []
    with torch.no_grad():
        for x, y in iterator:
            logits = model(x.to(device))
            labels.append(y.cpu())
            probs.append(nn.functional.softmax(logits, dim=-1).cpu())
    return torch.cat(labels, dim=0), torch.cat(probs, dim=0)

In [None]:
# Predict on the held-out test split.
labels, probs = get_predictions(model, iterator=dataloaders["test"])

In [None]:
# Predicted class = index of the highest probability in each row.
pred_labels = probs.argmax(dim=1)

In [None]:
def plot_confusion_matrix(labels, pred_labels, classes):
    """Draw a large confusion matrix of true vs. predicted classes.

    Parameters
    ----------
    labels : array-like of int
        True class index for each sample.
    pred_labels : array-like of int
        Predicted class index for each sample.
    classes : sequence of str
        Class names; ``classes[i]`` labels row/column ``i``.
    """
    fig = plt.figure(figsize=(50, 50))
    ax = fig.add_subplot(1, 1, 1)
    # Pin the label set to every class index so the matrix is always
    # len(classes) x len(classes) and matches display_labels, even when a class
    # never occurs in `labels` or `pred_labels`.
    # Assumes labels are integer indices into `classes` — TODO confirm.
    matrix = confusion_matrix(labels, pred_labels, labels=list(range(len(classes))))
    display = ConfusionMatrixDisplay(matrix, display_labels=classes)
    # colorbar=False skips the colorbar instead of deleting its axes afterwards.
    display.plot(
        values_format="d", cmap="Blues", ax=ax, colorbar=False, xticks_rotation=90
    )
    ax.set_xlabel("Predicted Label", fontsize=50)
    ax.set_ylabel("True Label", fontsize=50)

In [None]:
# Render the test-set confusion matrix.
plot_confusion_matrix(labels=labels, pred_labels=pred_labels, classes=classes)