# Ray Train - A Library for Distributed Deep Learning

[Ray Train](https://docs.ray.io/en/latest/train/train.html) is a lightweight library for distributed deep learning. It provides thin wrappers around [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), and [Horovod](https://horovod.ai/) native modules for data-parallel training.

> **NOTE**: Ray SGD has been renamed to Ray Train.
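In data-parallel training, every worker holds a full replica of the model and computes gradients on its own slice of the data. The gradients are then averaged across workers (an all-reduce) before each update. In Ray Train's PyTorch integration, this is handled by `DistributedDataParallel`, which `train.torch.prepare_model` applies. The stdlib-only sketch below uses a one-parameter linear model as an illustrative stand-in. It shows why averaging the gradients of equal-sized shards reproduces the full-batch gradient. The model, data, and helper names are assumptions made for illustration only.

```python
from statistics import mean

def grad_mse(w, xs, ys):
    """Gradient of the mean squared error for the model y_hat = w * x."""
    return mean(2 * (w * x - y) * x for x, y in zip(xs, ys))

def shard(seq, rank, world_size):
    """Round-robin shard: worker `rank` gets every `world_size`-th element."""
    return seq[rank::world_size]

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # ground truth: w = 2
w, world_size = 0.5, 4      # current weight and number of simulated workers

full_grad = grad_mse(w, xs, ys)
worker_grads = [grad_mse(w, shard(xs, r, world_size), shard(ys, r, world_size))
                for r in range(world_size)]
all_reduced = mean(worker_grads)  # what the all-reduce step averages

print(full_grad, all_reduced)  # -76.5 -76.5
```

The equality holds because every shard has the same size. With uneven shards, a plain average of per-shard means would weight samples unequally.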

## PyTorch Fashion MNIST for Distributed Training

<img src="images/fashion-mnist-sprite.jpeg" width="70%" height="60%"> 

We will use Ray Train to distribute the training of two models, and we will evaluate which of them gives the better accuracy and the lower loss.

As an exercise, you can further investigate how to improve the model, for example with regularization techniques, CNN layers, or different loss functions.

The steps we follow are the same as in the previous notebooks. There may be slight variations, but the essence is unchanged.

So let's go!

First, do the necessary imports, as before, and start a Ray cluster.

In [1]:
import os
from typing import Dict

import torch
import torch.nn.functional as F

import ray
import ray.train as train
import ray.train.torch  # exposes train.torch.prepare_model and train.torch.prepare_data_loader
from ray.train.trainer import Trainer
from ray.train.callbacks import JsonLoggerCallback, TBXLoggerCallback
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Launch Ray worker nodes on the Spark cluster, then connect this driver to them
setup_ray_cluster(
    num_worker_nodes=2,
    num_cpus_per_node=4,
    collect_log_to_path="/dbfs/path/to/ray_collected_logs"
)
ray.init()

### Step 1: Download the training and test datasets

In [2]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="~/data",
    train=True,
    download=True,
    transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="~/data",
    train=False,
    download=True,
    transform=ToTensor(),
)
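Each FashionMNIST sample is a 1x28x28 grayscale image with an integer label from 0 to 9. The training split has 60,000 images and the test split has 10,000. The dataset fixes the label-to-garment mapping, and torchvision also exposes it as `training_data.classes`. The small stdlib-only helper below (`decode_labels` is a hypothetical name) turns predicted class ids into readable names.

```python
# Label names in index order, matching torchvision's FashionMNIST.classes
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def decode_labels(label_ids):
    """Map integer class ids (0-9) to human-readable garment names."""
    return [FASHION_MNIST_CLASSES[i] for i in label_ids]

print(decode_labels([9, 0, 7]))  # ['Ankle boot', 'T-shirt/top', 'Sneaker']
```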

### Step 2: Define the neural network models

The first model is a simple fully connected network:

In [3]:
# Define model-1
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
            # Note: this final ReLU clamps all 10 output logits to be non-negative
            nn.Linear(512, 10), nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
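`NeuralNetwork` ends with `nn.ReLU()`, so all 10 of its outputs are non-negative. The sketch below uses plain-Python `relu` and `softmax` helpers, not the PyTorch modules, to illustrate the consequence. Classes that the last linear layer scores below zero are all clamped to 0. They become indistinguishable, and their softmax probabilities cannot drop below what a logit of 0 receives.

```python
import math

def relu(values):
    """Element-wise max(0, v), like nn.ReLU."""
    return [max(0.0, v) for v in values]

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

raw_logits = [4.0, -3.0, -6.0]
print(softmax(raw_logits))        # classes 1 and 2 get different probabilities
print(softmax(relu(raw_logits)))  # classes 1 and 2 collapse to the same probability
```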

Next, define a second model: a narrower fully connected network with dropout after each hidden layer.

<img src="https://miro.medium.com/max/1400/1*2SHOuTUK51_Up3D9JMAplA.png" width="70%" height="50%">

[source](https://medium.com/@aaysbt/fashion-mnist-data-training-using-pytorch-7f6ad71e96f4)

In [4]:
# Define model-2
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 120)
        self.fc2 = nn.Linear(120, 120)
        self.fc3 = nn.Linear(120, 10)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = x.view(x.shape[0], -1)                # flatten to (batch, 784)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = F.log_softmax(self.fc3(x), dim=1)     # log-probabilities, suited to nn.NLLLoss
        return x
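Both models stack three `nn.Linear` layers. They differ in hidden width (512 vs. 120) and in the use of dropout. Each `nn.Linear(in, out)` layer has `in * out` weights plus `out` biases. Counting them makes the size difference concrete, as the stdlib sketch below does. In PyTorch you could instead sum `p.numel()` over `model.parameters()`. `nn.ReLU`, `nn.Dropout`, and `nn.Flatten` add no parameters.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one nn.Linear(in_features, out_features) layer."""
    return in_features * out_features + out_features

def count_params(layer_sizes):
    """Total parameters of a stack of Linear layers, e.g. [784, 512, 512, 10]."""
    return sum(linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))

model_1 = count_params([28 * 28, 512, 512, 10])  # NeuralNetwork
model_2 = count_params([784, 120, 120, 10])      # Classifier
print(model_1, model_2)  # 669706 109930
```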

In [5]:
# Define accuracy function
def accuracy_fn(y_pred, y_true):
    n_correct = torch.eq(y_pred, y_true).sum().item()
    acc = (n_correct / len(y_pred)) * 100
    return acc
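`accuracy_fn` counts the element-wise matches between predicted and true class ids and scales the result to a percentage. The framework-free sketch below applies the same logic to plain Python lists (`accuracy_pct` is a hypothetical name), which makes it easy to check by hand.

```python
def accuracy_pct(y_pred, y_true):
    """Percentage of positions where the predicted class id equals the true one."""
    if len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must have the same length")
    n_correct = sum(p == t for p, t in zip(y_pred, y_true))
    return n_correct / len(y_pred) * 100

print(accuracy_pct([3, 1, 4, 1], [3, 1, 5, 9]))  # 50.0
```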

### Step 3: Define per-epoch training and validation functions

In [6]:
def train_epoch(dataloader, model, loss_fn, optimizer, epoch):
    # Switch back to training mode (enables dropout); validate_epoch sets eval mode
    model.train()
    for X, y in dataloader:
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
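Each iteration of `train_epoch` makes the same three calls. `optimizer.zero_grad()` clears the old gradients, `loss.backward()` computes new ones, and `optimizer.step()` applies the update. For `torch.optim.SGD` with `momentum=0.9` and the default `dampening=0`, the update keeps a velocity buffer `v = momentum * v + grad` and then sets `p = p - lr * v`. The sketch below applies that rule by hand to the toy loss f(p) = p**2, whose gradient is 2p. The toy loss, starting point, and `lr=0.1` are illustrative choices, not values from this notebook.

```python
def sgd_momentum(p, grad_fn, lr=0.1, momentum=0.9, steps=300):
    """Plain-Python SGD with momentum (dampening=0, no Nesterov, no weight decay)."""
    velocity = 0.0
    for _ in range(steps):
        grad = grad_fn(p)                      # the role of loss.backward()
        velocity = momentum * velocity + grad  # momentum buffer update
        p = p - lr * velocity                  # the role of optimizer.step()
    return p

# Minimize f(p) = p**2 (gradient 2p), starting from p = 5.0
p_final = sgd_momentum(5.0, lambda p: 2 * p)
print(p_final)  # very close to 0.0
```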

In [7]:
def validate_epoch(dataloader, model, loss_fn, epoch):
    num_batches = len(dataloader)
    model.eval()                      # disable dropout for evaluation
    test_loss, acc = 0.0, 0.0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            predictions = pred.argmax(dim=1)
            acc += accuracy_fn(predictions, y)
    # Per-batch averages over this worker's shard of the test set
    test_loss /= num_batches
    acc /= num_batches
    if epoch > 0 and epoch % 50 == 0:
        print(f"Epoch: {epoch}, Avg validation loss: {test_loss:.2f}, Avg validation accuracy: {acc:.2f}%")
        print("--" * 40)
    return test_loss
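`validate_epoch` divides the summed loss and accuracy by `num_batches`. As a result, a short final batch counts as much as a full one. With 10,000 test images and `batch_size=128`, a single process would see 78 full batches plus one batch of 16 images. The sketch below compares the per-batch mean with a mean weighted by batch size. The per-batch accuracies in it are made up for illustration only.

```python
def unweighted_mean(values):
    """Mean over batches, as validate_epoch computes it."""
    return sum(values) / len(values)

def weighted_mean(values, weights):
    """Mean over samples: each batch weighted by its size."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)

print(divmod(10_000, 128))  # (78, 16): 78 full batches and a last batch of 16

# Hypothetical per-batch accuracies (%) for two full batches and a short last batch
batch_sizes = [128, 128, 16]
batch_accuracy = [90.0, 90.0, 50.0]

print(unweighted_mean(batch_accuracy))             # per-batch mean
print(weighted_mean(batch_accuracy, batch_sizes))  # per-sample mean
```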

### Step 4: Define the Ray Train training function
This function is passed to `trainer.run(...)` and runs on every worker.

In [8]:
def train_func(config: Dict):
    batch_size = config.get("batch_size", 64) 
    lr = config.get('lr', 1e-3)
    epochs = config.get("epochs", 20)
    momentum = config.get("momentum", 0.9)
    model_type = config.get("model_type", 0)  # 0 -> NeuralNetwork, 1 -> Classifier
    # Get our objective loss function
    loss_fn = config.get("loss_fn", nn.NLLLoss())

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    # Prepare to use Ray integrated wrappers around PyTorch's DataLoaders
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create the model and prepare it with Ray's integrated PyTorch wrapper
    model = Classifier() if model_type else NeuralNetwork()
    model = train.torch.prepare_model(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    loss_results = []

    for e in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer, e)
        loss = validate_epoch(test_dataloader, model, loss_fn, e)
        train.report(loss=loss)
        loss_results.append(loss)

    return loss_results
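According to the Ray Train API docs, `train.torch.prepare_data_loader` adds a PyTorch `DistributedSampler` to the loader when there is more than one worker, so that each worker reads a different shard. With `drop_last=False`, `DistributedSampler` first pads the index list by wrapping around to its start until the list divides evenly by the number of replicas. It then gives worker `rank` every `num_replicas`-th index, starting at `rank`. The stdlib sketch below reproduces that split and ignores shuffling. `shard_indices` is a hypothetical helper, not a Ray or PyTorch function.

```python
import math

def shard_indices(dataset_len, rank, num_replicas):
    """Indices worker `rank` reads, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)  # per-worker sample count
    total_size = num_samples * num_replicas
    padding = total_size - len(indices)
    if padding:
        # Wrap around to the start so every worker gets the same number of samples
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Round-robin: start at `rank`, take every `num_replicas`-th index
    return indices[rank:total_size:num_replicas]

for rank in range(3):
    print(rank, shard_indices(10, rank, 3))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]
```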

### Step 5: Wrap the Trainer in a main driver function

In [9]:
def train_fashion_mnist(num_workers=12, use_gpu=False):
    trainer = Trainer(
        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    result = trainer.run(
        train_func=train_func,
        config={
            "lr": 1e-3,
            "batch_size": 128,
            "epochs": 150,
            "momentum": 0.9,
            "model_type": 0,                     # change to 1 for the second NN model
            "loss_fn": nn.CrossEntropyLoss()     # or try nn.NLLLoss()
        },
        callbacks=[JsonLoggerCallback(), TBXLoggerCallback()])
    trainer.shutdown() 
    return result
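The config passes `nn.CrossEntropyLoss()`, while `train_func` falls back to `nn.NLLLoss()` if no loss is given. `CrossEntropyLoss` applies `log_softmax` to its input and then takes the negative log-likelihood, so it expects raw scores. `NLLLoss` expects log-probabilities, which is what `Classifier` returns. Applying `log_softmax` to values that are already log-probabilities leaves them unchanged, so `Classifier` gives the same loss under either function. `NeuralNetwork` outputs raw scores, so it needs `CrossEntropyLoss`. The stdlib check below verifies the identity. The helper names are hypothetical stand-ins for the PyTorch losses.

```python
import math

def log_softmax(scores):
    """log(softmax(x)) = x - logsumexp(x), computed stably."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - lse for s in scores]

def nll(log_probs, target):
    """Per-sample negative log-likelihood, as NLLLoss computes it."""
    return -log_probs[target]

def cross_entropy(scores, target):
    """Per-sample CrossEntropyLoss: log_softmax followed by NLL."""
    return nll(log_softmax(scores), target)

scores = [2.0, -1.0, 0.5]
log_probs = log_softmax(scores)  # what Classifier.forward returns
print(nll(log_probs, 0), cross_entropy(log_probs, 0))  # equal up to rounding
```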

### Step 6: Define some parallelism parameters
Also set the address of a Ray cluster to connect to if you are running on Anyscale.

In [10]:
number_of_workers = 8
use_gpu = False                              # change to True if using a Ray cluster with GPUs
address = "anyscale://ray_train_ddp_cluster" # use your anyscale cluster here
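In this notebook, each worker's `train_func` builds its own `DataLoader` with `batch_size`. Because gradients are averaged across workers, one optimizer step effectively covers `batch_size * number_of_workers` samples. The back-of-the-envelope sketch below applies this to the driver's settings: 8 workers, `batch_size=128`, 150 epochs, and 60,000 training images. `training_plan` is a hypothetical helper.

```python
import math

def training_plan(num_samples, num_workers, batch_size, epochs):
    """Per-worker shard size, steps per epoch, and global batch size for data-parallel training."""
    shard = math.ceil(num_samples / num_workers)     # DistributedSampler pads to this size
    steps_per_epoch = math.ceil(shard / batch_size)  # the last partial batch is kept
    return {
        "samples_per_worker": shard,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * epochs,
        "global_batch_size": batch_size * num_workers,
    }

print(training_plan(60_000, num_workers=8, batch_size=128, epochs=150))
# {'samples_per_worker': 7500, 'steps_per_epoch': 59, 'total_steps': 8850, 'global_batch_size': 1024}
```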

### Step 6 (continued): Connect to the Ray cluster

In [15]:
CONNECT_TO_ANYSCALE = False
# Shut down any existing Ray connection, then reconnect
if ray.is_initialized():
    ray.shutdown()
if CONNECT_TO_ANYSCALE:
    ray.init(address)
else:
    ray.init(ignore_reinit_error=True)

2022-03-16 16:30:27,154	INFO services.py:1412 -- View the Ray dashboard at http://127.0.0.1:8268
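`ray.is_initialized` is a function, so it has to be called. Testing the bare name checks the function object itself, and Python always treats a function object as true. The stdlib illustration below uses a hypothetical local stub named `is_initialized`, not Ray's own function.

```python
def is_initialized():
    """Hypothetical stand-in for ray.is_initialized(); pretend no cluster is running."""
    return False

print(bool(is_initialized))    # True: a function object is always truthy
print(bool(is_initialized()))  # False: the actual return value
```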


### Step 7: Run the main Trainer driver

In [12]:
%%time
results = train_fashion_mnist(num_workers=number_of_workers, use_gpu=use_gpu)

2022-03-16 16:19:11,394	INFO trainer.py:199 -- Trainer logs will be logged in: /Users/jules/ray_results/train_2022-03-16_16-19-11
(BaseWorkerMixin pid=61696) 2022-03-16 16:19:13,471	INFO torch.py:66 -- Setting up process group for: env:// [rank=0, world_size=8]
(BaseWorkerMixin pid=61697) 2022-03-16 16:19:13,458	INFO torch.py:66 -- Setting up process group for: env:// [rank=1, world_size=8]
(BaseWorkerMixin pid=61699) 2022-03-16 16:19:13,466	INFO torch.py:66 -- Setting up process group for: env:// [rank=2, world_size=8]
(BaseWorkerMixin pid=61690) 2022-03-16 16:19:13,473	INFO torch.py:66 -- Setting up process group for: env:// [rank=5, world_size=8]
(BaseWorkerMixin pid=61692) 2022-03-16 16:19:13,481	INFO torch.py:66 -- Setting up process group for: env:// [rank=4, world_size=8]
(BaseWorkerMixin pid=61691) 2022-03-16 16:19:13,465	INFO torch.py:66 -- Setting up process group for: env:// [rank=3, world_size=8]

(BaseWorkerMixin pid=61696) Epoc: 50, Avg validation loss: 1.20, Avg validation accuracy: 58.51%
(BaseWorkerMixin pid=61696) --------------------------------------------------------------------------------
(BaseWorkerMixin pid=61697) Epoc: 50, Avg validation loss: 1.29, Avg validation accuracy: 55.41%
(BaseWorkerMixin pid=61697) --------------------------------------------------------------------------------
(BaseWorkerMixin pid=61699) Epoc: 50, Avg validation loss: 1.26, Avg validation accuracy: 56.29%
(BaseWorkerMixin pid=61699) --------------------------------------------------------------------------------
(BaseWorkerMixin pid=61694) Epoc: 50, Avg validation loss: 1.23, Avg validation accuracy: 56.61%
(BaseWorkerMixin pid=61694) --------------------------------------------------------------------------------
(BaseWorkerMixin pid=61691) Epoc: 50, Avg validation loss: 

### Step 8: Observe metrics in TensorBoard

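Substitute `<train_path>` below with the directory name (for example, `train_2022-03-16_16-19-11`) from the `Trainer logs will be logged in:` line printed in the cell above.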

In [None]:
!tensorboard --logdir ~/ray_results/<train_path>

In [16]:
shutdown_ray_cluster()

### Exercises

Have a go at this in your spare time and observe the results:

 1. Change the learning rate and batch size in `config`
 2. Try changing the number of workers to half the number of cores on your localhost or laptop
 3. Change the `batch_size` and `epochs`
 4. Try the second model by changing `model_type` in `config` to 1
 5. Did it improve the accuracy or minimize the loss?
 6. Can you try some deep learning regularization techniques to bring the loss down?
 7. Change the loss function and test whether that helps
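For exercises 1, 3, and 4, one option is to generate a small grid of `config` dictionaries and run each one in turn. For example, you could adapt `train_fashion_mnist` to accept a `config` argument, which is a hypothetical change. The hypothetical helper below only builds the dictionaries, and the value ranges are arbitrary examples. `loss_fn` is left out so the sketch needs only the standard library. Add it back, for example `nn.CrossEntropyLoss()`, before passing a config to `trainer.run`.

```python
from itertools import product

BASE_CONFIG = {"lr": 1e-3, "batch_size": 128, "epochs": 150, "momentum": 0.9, "model_type": 0}

def config_grid(base, **options):
    """Yield one config per combination of option values, overriding keys in `base`."""
    keys = list(options)
    for values in product(*(options[k] for k in keys)):
        yield {**base, **dict(zip(keys, values))}

grid = list(config_grid(BASE_CONFIG, lr=[1e-3, 1e-2], batch_size=[64, 128], model_type=[0, 1]))
print(len(grid))  # 8
print(grid[0])
```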
