References 
- [Ray Train](https://docs.ray.io/en/latest/train/train.html#)
- [TensorBoard & PyTorch](https://pytorch.org/docs/stable/tensorboard.html)

Now let’s convert this to a distributed multi-worker training function!

We keep the model unchanged.

In [None]:
import torch
import torch.nn as nn

num_samples = 20
input_size = 10
layer_size = 15
output_size = 5

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# In this example we use a randomly generated dataset.
input = torch.randn(num_samples, input_size)
labels = torch.randn(num_samples, output_size)
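
To get a feel for how small this model is, the sketch below counts its trainable parameters by hand. It uses plain Python only, and the helper name `linear_param_count` is made up for this illustration. It assumes each `nn.Linear` layer keeps its default bias term, as in the cell above.

```python
# Sizes copied from the model definition above.
input_size = 10
layer_size = 15
output_size = 5


def linear_param_count(in_features, out_features):
    """Weights plus biases of one fully connected layer (nn.Linear with bias)."""
    return in_features * out_features + out_features


# Hypothetical helper applied to the two layers of NeuralNetwork.
total_params = (
    linear_param_count(input_size, layer_size)
    + linear_param_count(layer_size, output_size)
)
print(total_params)
```

These are the values for which gradients are computed on every training step.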


First, update the training function code to use PyTorch’s **DistributedDataParallel**. With Ray Train, you pass in your distributed data parallel code just as you would normally run it with `torch.distributed.launch`.

In [None]:
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel

# Writer will output to ./runs/ directory by default
writer = SummaryWriter()


def train_func():
    num_epochs = 3
    model = NeuralNetwork()
    # Optionally log the model graph to TensorBoard (written to ./runs by default).
    # writer.add_graph(model, input)
    model = DistributedDataParallel(model)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(num_epochs):
        output = model(input)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")
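
The idea behind data-parallel training can be sketched without any framework. The toy below is a conceptual illustration only, not how `DistributedDataParallel` is implemented: it simulates 4 workers in a single process, each holding a replica of a one-parameter model `y = weight * x`, and replaces the cross-process gradient exchange with a plain average. The data, the helper names `local_gradient` and `data_parallel_step`, and the learning rate are all assumptions made for this example.

```python
# Conceptual sketch: simulates data-parallel SGD in one process.
# Real DDP exchanges gradients between separate worker processes instead.


def local_gradient(weight, shard):
    """Gradient of the mean squared error for y = weight * x on one data shard."""
    return sum(2 * (weight * x - y) * x for x, y in shard) / len(shard)


def data_parallel_step(weights, shards, lr):
    """One synchronized step: every worker applies the same averaged gradient."""
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    avg_grad = sum(grads) / len(grads)  # stands in for the gradient exchange
    return [w - lr * avg_grad for w in weights]


# Hypothetical toy data generated from y = 3x, split across 4 simulated workers.
data = [(x / 10, 3 * x / 10) for x in range(1, 9)]
shards = [data[i::4] for i in range(4)]
weights = [0.0] * 4  # every replica starts from identical parameters

for _ in range(200):
    weights = data_parallel_step(weights, shards, lr=0.1)

print(weights)
```

Because every replica starts from the same value and applies the same averaged gradient, the replicas stay identical after each step, which is the property that lets the training loop look the same as in the single-worker case.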

Then, instantiate a Trainer that uses a "torch" backend with 4 workers, and use it to run the new training function!

In [None]:
from ray.train import Trainer

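# Set up 4 workers that use the "torch" backend.
trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
# Run train_func on every worker; results holds one return value per worker.
results = trainer.run(train_func)
# Release the workers once training is done.
trainer.shutdown()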