# Distributed Training for PyTorch Models with Ray on Anyscale
The purpose of this notebook is to demonstrate how developpers can implement distributed training methods on PyTorch models at scale with the open-source framework Ray (running on Anyscale platform).
The primary focus will be utilizing the Ray Train API.

![piyc](https://images.ctfassets.net/xjan103pcp94/QGnrgOJx9rGd8EfSnVehx/e8080f8a43268238ff3557fdbbbadb4a/RayStack.png)

# Steps for this Notebook
##### 1. Prepare Dataset
##### 2. Model Build
##### 3. Distributed Training 

# 01 - Prepare Dataset
We'll be training a simple image classifer on the classic MNIST dataset. The [MNIST dataset](https://production-media.paperswithcode.com/datasets/MNIST-0000000001-2e09631a_09liOmx.jpg) is a source of handwritten digit images consisting of 60000 training samples and 10000 test samples.


![MNIST](https://production-media.paperswithcode.com/datasets/MNIST-0000000001-2e09631a_09liOmx.jpg)

In [4]:
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [5]:
def get_dataloaders(batch_size):
    """generates train and test DataLoaders for model training and evaluation"""
    # define required tranformations for images
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.38081))
    ])
    # download dataset
    train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST('../data', train=False, transform=transform)
    
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=True)
    
    return train_loader, test_loader

# 02 - Model Build


In [6]:
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class DigitClassifier(nn.Module):
    """simple classifier model with convolutions"""
    def __init__(self):
        super(DigitClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x)

# 03 - Model Training with Ray Train
[Ray Train](https://docs.ray.io/en/latest/train/train.html) supports all the most popular frameworks for building machine learning models, falling under the following groups. Trainers are meant for ML practitioners to execute training workloads at scale.
* Deep Learning Trainers (PyTorch, TensorFLow, JAX, Horovod)
* Tree Based Trainers (XGBoost, LightGBM)
* General/Other (Scikit-Learn, HuggingFace)

Trainers run training loops on multiple [Ray Actors](https://docs.ray.io/en/latest/ray-core/actors.html#actor-guide) (workers).

We will be executing a distributed training job with Ray Train. Since were training a PyTorch based model we'll utilize the [TorchTrainer](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html) which essentially runs Distributed Data Parallel (DDP) or Fully-Sharded Data Parallel (FSDP) under the hood.

![TT](https://docs.ray.io/en/latest/_images/train.svg)

In [8]:
import torch.optim as optim
import ray
from ray import train
from ray.air import session, Checkpoint, RunConfig
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

### Steps for Running Distributed Training with Ray
1. Define train, evaluation, and job execution functions
2. Wrap the following training components with the appropriate method for enabling distributed execution
    * **Device** ([``ray.train.torch.get_device``](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.get_device.html)): assigns the correct GPU for each process
    * **DataLoader** ([``ray.train.torch.prepare_data_loader``](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html)): moves tensors from CPU to GPU and adds [DistributedSampler](https://pytorch.org/docs/stable/data.html?highlight=distributedsampler#torch.utils.data.distributed.DistributedSampler) to the DataLoaders
    * **Model** ([``ray.train.torch.prepare_model``](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html)): runs DDP/FSDP under the hood
3. Set required configurations for running training job:
    * **RunConfig** ([``ray.air.RunConfig``](https://docs.ray.io/en/latest/ray-air/api/doc/ray.air.RunConfig.html#ray.air.RunConfig)): defines specs for running a given experiment such as:
        * experiment name
        * output storage path
        * stopping conditions
        * checkpoint configurations
        * logging
    * **ScalingConfig** ([``ray.air.config.ScalingConfig``](https://docs.ray.io/en/latest/ray-air/api/doc/ray.air.ScalingConfig.html#ray.air.ScalingConfig)): allows developpers to specify scaling configurations such as:
        * number of workers
        * use GPU or not
        * max CPU usage per node
        * scheduling of workers
    * **TorchConfig** ([``ray.train.torch.TorchConfig``](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchConfig.html#ray-train-torch-torchconfig)): configurations for torch process group:
        * backend ([PyTorch backends](https://pytorch.org/docs/stable/distributed.html))
        * timeout (seconds)
4. Initialize TorchTrainer and run training job
    

In [None]:
def train_job(config):
    """
    function for executing distributing training job
    """
    # prepare train, test sets for distributed execution
    train_loader, test_loader = get_dataloaders(batch_size_per_worker)
    train_loader = train.torch.prepare_data_loader(
        data_loader=train_loader, 
        add_dist_sampler=True, 
        move_to_device=True, 
        auto_transfer=True
    )
    test_loader = train.torch.prepare_data_loader(
        data_loader=test_loader, 
        add_dist_sampler=True, 
        move_to_device=True, 
        auto_transfer=True
    )
    # wrap model to prepare for distributed training
    model = DigitClassifier()
    model = train.torch.prepare_model(
        model=model,
        move_to_device=train.torch.get_device(),
        parallel_strategy=config['parallel_strategy'],    
    )
    # initialize optimizer
    optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'])
    # track training time elapsed
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    # begin training iterations
    for epoch in range(1, config['epochs'] + 1):
        train_model(model, train_loader, optimizer, epoch)
        evaluate_model(model, test_loader)
    end.record()
    print(f'Training Time Elapsed: {start.elapsed_time(end) / 1000}')

In [9]:
def train_model(model, train_loader, optimizer, epoch):
    """executes training iteration for a given epoch"""
    model.train()
    ddp_loss = torch.zeros(2).to()
    for batch_idx, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        # generate predictions for the given batch
        preds = model(x)
        # compute the loss with respect to the target variable
        loss = F.nll_loss(preds, y, reduction='sum')
        loss.backward()
        # update model parameters
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(x)
    # print and record metrics
    print(f'Epoch: {epoch} \tTrain Loss: {ddp_loss[0] / ddp_loss[1]}')
    session.report(
        metrics={'epoch': epoch, 'train_loss': ddp_loss[0].tolist() / ddp_loss[1].tolist()}, 
        checkpoint=train.torch.TorchCheckpoint.from_state_dict(model.state_dict())
    )

In [28]:
def evaluate_model(model, test_loader):
    """runs model evaluation with test set, records loss"""
    model.eval()
    correct = 0
    ddp_loss = torch.zeros(3).to()
    with torch.no_grad():
        for x, y in test_loader:
            # generate predictions for the given batch
            preds = model(x)
            # sum the batch losses
            ddp_loss[0] += F.nll_loss(preds, y, reduction='sum').item()
            # get index of max log-prob
            pred = preds.argmax(dim=1, keepdim=True)
            ddp_loss[1] += pred.eq(y.view_as(pred)).sum().item()
            ddp_loss[2] += len(x)

    # print and record metrics
    print(f'Test Loss: {ddp_loss[0] / ddp_loss[1]}')
    session.report(
        metrics={'loss': ddp_loss[0].tolist() / ddp_loss[1].tolist()}, 
    )

In [30]:
# set configuration parameters
EXPERIMENT_NAME = "distributed-training-test"
NUM_WORKERS = 2 # number of GPU's
MAX_CPU_ALLOCATION = 0.7 # max fraction of CPU used before spinning up another node
TIMEOUT = 1800
if torch.cuda.is_available():
    BACKEND = 'NCCL'
else:
    BACKEND = 'GLOO'

In [31]:
# setup required configurations for running TorchTrainer
scaling_config = ScalingConfig(
    num_workers=NUM_WORKERS, # number of Ray Actors
    use_gpu=True, # utilizes GPU during session
    _max_cpu_fraction_per_node=MAX_CPU_ALLOCATION # max fraction of CPU's per node for scheduling Actors   
)
run_config = RunConfig(
    name=EXPERIMENT_NAME,
)
torch_config = train.torch.TorchConfig(
    backend=BACKEND,
    timeout_s=TIMEOUT
)

In [56]:
trainer = TorchTrainer(
    train_loop_per_worker=train_job, 
    train_loop_config={
        'batch_size': 32,
        'epochs': 2,
        'learning_rate': 0.001,
        'parallel_strategy': 'ddp' # DDP/FSDP
    },
    torch_config=torch_config,
    scaling_config=scaling_config,
    run_config=run_config
)
results = trainer.fit()
print(results.metrics)
print(results.checkpoint)

0,1
Current time:,2023-08-24 10:51:20
Running for:,00:01:44.13
Memory:,10.5/62.1 GiB

Trial name,status,loc,iter,total time (s),loss
TorchTrainer_974a6_00000,TERMINATED,10.0.11.100:314619,4,24.4592,0.103125


[2m[1m[36m(autoscaler +2h24m9s)[0m Adding 1 node(s) of type worker-node-type-0.




[2m[1m[36m(autoscaler +2h25m15s)[0m Resized to 32 CPUs, 2 GPUs.


[2m[36m(RayTrainWorker pid=314705)[0m 2023-08-24 10:50:56,200	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=2]
100%|██████████| 4542/4542 [00:00<00:00, 68774472.09it/s][32m [repeated 8x across cluster][0m
100%|██████████| 1648877/1648877 [00:00<00:00, 32787303.00it/s][32m [repeated 4x across cluster][0m
[2m[36m(RayTrainWorker pid=2663, ip=10.0.46.94)[0m 2023-08-24 10:50:56,200	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=2]
[2m[36m(RayTrainWorker pid=2663, ip=10.0.46.94)[0m 2023-08-24 10:50:56,200	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=2]
[2m[36m(RayTrainWorker pid=2663, ip=10.0.46.94)[0m 2023-08-24 10:50:56,200	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=2]


[2m[36m(RayTrainWorker pid=2663, ip=10.0.46.94)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=2605, ip=10.0.10.120)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=314705)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=314705)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=2605, ip=10.0.10.120)[0m Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
[2m[36m(RayTrainWorker pid=314705)[0m 
[2m[36m(RayTrainWorker pid=314705)[0m Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


[2m[36m(RayTrainWorker pid=314705)[0m 2023-08-24 10:50:58,236	INFO train_loop_utils.py:286 -- Moving model to device: cuda:0
[2m[36m(RayTrainWorker pid=314705)[0m 2023-08-24 10:50:58,237	INFO train_loop_utils.py:346 -- Wrapping provided model in DistributedDataParallel.


[2m[36m(RayTrainWorker pid=314705)[0m Epoch: 1 	Train Loss: 0.7651033401489258
[2m[36m(RayTrainWorker pid=2605, ip=10.0.10.120)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz[32m [repeated 5x across cluster][0m
[2m[36m(RayTrainWorker pid=314705)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz[32m [repeated 7x across cluster][0m
[2m[36m(RayTrainWorker pid=314705)[0m Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw[32m [repeated 7x across cluster][0m
[2m[36m(RayTrainWorker pid=314705)[0m [32m [repeated 7x across cluster][0m


Trial name,date,done,experiment_tag,hostname,iterations_since_restore,loss,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,training_iteration,trial_id
TorchTrainer_974a6_00000,2023-08-24_10-51-18,True,0,ip-10-0-11-100,4,0.103125,10.0.11.100,314619,24.4592,0.70495,24.4592,1692899478,4,974a6_00000


[2m[36m(RayTrainWorker pid=314705)[0m Test Loss: 0.15693329274654388
[2m[36m(RayTrainWorker pid=314705)[0m Epoch: 2 	Train Loss: 0.3557094931602478[32m [repeated 2x across cluster][0m
[2m[36m(RayTrainWorker pid=2605, ip=10.0.10.120)[0m Epoch: 2 	Train Loss: 0.3557094931602478
[2m[36m(RayTrainWorker pid=2605, ip=10.0.10.120)[0m Test Loss: 0.1238589659333229
[2m[36m(RayTrainWorker pid=2605, ip=10.0.10.120)[0m Training Time Elapsed: 19.07108984375


2023-08-24 10:51:20,666	INFO tune.py:945 -- Total run time: 104.14 seconds (104.13 seconds for the tuning loop).


{'loss': 0.10312475409389528, 'timestamp': 1692899478, 'time_this_iter_s': 0.7049496173858643, 'done': True, 'training_iteration': 4, 'trial_id': '974a6_00000', 'date': '2023-08-24_10-51-18', 'time_total_s': 24.459227800369263, 'pid': 314619, 'hostname': 'ip-10-0-11-100', 'node_ip': '10.0.11.100', 'config': {'train_loop_config': {'batch_size': 32, 'epochs': 2, 'learning_rate': 0.001, 'parallel_strategy': 'ddp'}}, 'time_since_restore': 24.459227800369263, 'iterations_since_restore': 4, 'experiment_tag': '0'}
TorchCheckpoint(local_path=/home/ray/ray_results/distributed-training-test/TorchTrainer_974a6_00000_0_2023-08-24_10-50-51/checkpoint_000001)
[2m[1m[36m(autoscaler +2h30m41s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +2h30m46s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +2h30m51s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +2h30m56s)[0m Removing 1 nodes of type worke

### Generate Predictions with TorchPredictor

In [55]:
results.metrics

{'loss': 0.09873404598020276,
 'timestamp': 1692891306,
 'time_this_iter_s': 1.258814811706543,
 'done': True,
 'training_iteration': 4,
 'trial_id': 'bd0ae_00000',
 'date': '2023-08-24_08-35-07',
 'time_total_s': 25.0549635887146,
 'pid': 4472,
 'hostname': 'ip-10-0-21-71',
 'node_ip': '10.0.21.71',
 'config': {'train_loop_config': {'batch_size': 32,
   'epochs': 2,
   'learning_rate': 0.001,
   'parallel_strategy': 'ddp'}},
 'time_since_restore': 25.0549635887146,
 'iterations_since_restore': 4,
 'experiment_tag': '0'}

[2m[1m[36m(autoscaler +34m43s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +34m53s)[0m Resized to 16 CPUs, 1 GPUs.


In [37]:
model_checkpoint = results.checkpoint.get_model(DigitClassifier())

[2m[1m[36m(autoscaler +15m8s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m13s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m19s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m24s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m29s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m34s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m39s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m44s)[0m Removing 1 nodes of type worker-node-type-0 (idle).
[2m[1m[36m(autoscaler +15m54s)[0m Resized to 16 CPUs, 1 GPUs.


In [44]:
predictor = train.torch.TorchPredictor(
    model=model_checkpoint,
    use_gpu=True
)

In [53]:
train_loader, test_loader = get_dataloaders(32)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 85115374.64it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 90534898.22it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 33736384.73it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17639378.49it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






In [None]:
for x, y in test_loader:
    x, y = x.to(torch.device('cuda')), y.to(torch.device('cuda'))
    pred = predictor.predict(x)
    print(pred)