References 
- [Ray Train](https://docs.ray.io/en/latest/train/train.html#)
- [TensorBoard & PyTorch](https://pytorch.org/docs/stable/tensorboard.html)

Now let’s convert this to a distributed multi-worker training function!

We keep the model unchanged.

In [None]:
import torch
import torch.nn as nn

# Problem dimensions shared by the model and the synthetic dataset below.
num_samples = 20
input_size = 10
layer_size = 15
output_size = 5


class NeuralNetwork(nn.Module):
    """Two fully connected layers with a ReLU between them.

    Maps inputs of width ``input_size`` to outputs of width ``output_size``
    through a single hidden layer of width ``layer_size``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        """Return ``layer2(relu(layer1(input)))``."""
        hidden = self.layer1(input)
        activated = self.relu(hidden)
        return self.layer2(activated)


# In this example we use a randomly generated dataset: one random feature
# tensor and one random target tensor, drawn in that order.
input = torch.randn(num_samples, input_size)
labels = torch.randn(num_samples, output_size)


First, update the training function code to use PyTorch’s **DistributedDataParallel**. With Ray Train, you just pass in your distributed data parallel code as you would normally run it with torch.distributed.launch.

In [None]:
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
import ray.train as train

# Writer will output to ./runs/ directory by default
# NOTE: the writer is created as soon as this cell runs; the only call that
# would use it in train_func (writer.add_graph) is currently commented out.
writer = SummaryWriter()


def train_func():
    """Train ``NeuralNetwork`` with DistributedDataParallel on one worker.

    Runs three epochs of SGD over the module-level ``input`` / ``labels``
    tensors. After every epoch the loss is reported through
    ``train.report``, a checkpoint holding the integer epoch index and the
    unwrapped model is saved, and a progress line is printed.
    """
    num_epochs = 3
    # Looked up once instead of on every epoch.
    rank = train.world_rank()

    model = NeuralNetwork()
    # Optional: add the model graph to TensorBoard (written to ./runs by
    # default).
    # writer.add_graph(model, input)
    model = DistributedDataParallel(model)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(num_epochs):
        output = model(input)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train.report(loss=loss_value)
        # Store the epoch as a plain int: code that resumes from a checkpoint
        # computes ``checkpoint["epoch"] + 1``. ``model.module`` is the
        # underlying network without the DDP wrapper.
        train.save_checkpoint(epoch=epoch, model=model.module)
        print(f"rank: {rank}, epoch: {epoch}, loss: {loss_value}")

Then, instantiate a Trainer that uses a "torch" backend with 4 workers, and use it to run the new training function!

In [None]:
from ray.train import Trainer

logdir = "raylog"
trainer = Trainer(backend="torch", logdir=logdir, num_workers=4)
trainer.start()
try:
    results = trainer.run(train_func)
finally:
    # Always release the workers, even when training raises.
    trainer.shutdown()

TensorBoard

In [None]:
! tensorboard --logdir=runs

# Why Is Ray Train Correct?

In [None]:
from ray import train
from ray.train import Trainer

def train_func(config):
    """Toy training loop that sums epoch indices and checkpoints each epoch.

    Args:
        config: Mapping with a ``"num_epochs"`` entry giving the number of
            epochs to run.
    """
    total_epochs = config["num_epochs"]
    # This should be replaced with a real model.
    model = 0
    for epoch in range(total_epochs):
        model = model + epoch
        print(f"rank: {train.world_rank()}; epoch: {epoch}; model: {model}")
        train.save_checkpoint(epoch=epoch, model=model)

trainer = Trainer(backend="torch", num_workers=2, logdir="raylog")
trainer.start()
try:
    trainer.run(train_func, config={"num_epochs": 5})
finally:
    # Always release the workers, even when training raises.
    trainer.shutdown()

print(trainer.latest_checkpoint)
# Expected output (model = 0 + 1 + 2 + 3 + 4):
# {'epoch': 4, 'model': 10}

In [None]:
from ray import train
from ray.train import Trainer

def train_func(config):
    """Resume the toy training loop from the loaded checkpoint, if any.

    With a checkpoint, training continues from the epoch after the stored
    one and starts from the stored ``model`` value. Without one, it starts
    from ``model = 0`` at epoch 0.

    Args:
        config: Mapping with a ``"num_epochs"`` entry giving the exclusive
            upper bound on the epoch index.
    """
    checkpoint = train.load_checkpoint() or {}
    print(checkpoint)
    # This should be replaced with a real model.
    # The defaults keep the empty-checkpoint case working instead of failing
    # on ``None + 1`` / ``None += epoch``.
    model = checkpoint.get("model", 0)
    start_epoch = checkpoint.get("epoch", -1) + 1
    for epoch in range(start_epoch, config["num_epochs"]):
        model += epoch
        train.save_checkpoint(epoch=epoch, model=model)

trainer = Trainer(backend="torch", num_workers=1)
trainer.start()
try:
    # Resume once from each of the five checkpoints written by the earlier
    # run (checkpoint_000001 ... checkpoint_000005).
    for checkpoint_index in range(1, 6):
        print(f"Model {checkpoint_index}:")
        trainer.run(
            train_func,
            config={"num_epochs": 5},
            checkpoint=(
                "~/ray_results/raylog/run_001/checkpoints/"
                f"checkpoint_{checkpoint_index:06d}"
            ),
        )
finally:
    # Always release the workers, even when a run raises.
    trainer.shutdown()

print(f"Final model we are expecting is {trainer.latest_checkpoint}")
# {'epoch': 4, 'model': 10}