# PyTorch Lightning with Ray

The `RayPlugin` from the `ray_lightning` package provides Distributed Data Parallel (DDP) training on a Ray cluster. PyTorch DDP is used as the distributed training protocol, and Ray launches and manages the training worker processes.
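To make the DDP protocol concrete, here is a toy, pure-Python sketch of one idea behind it (no PyTorch or Ray involved; every name below is hypothetical). On each step every worker computes a gradient on its own shard of the data, the gradients are averaged across workers with an all-reduce, and every worker applies the same update, so the model replicas stay identical. Averaging the per-shard mean gradients equals the full-batch gradient here only because the shards have equal size.

```python
# Hypothetical pure-Python sketch of data-parallel training with gradient
# averaging (the core idea of DDP). It does not use torch.distributed or Ray.

def local_gradient(w, shard):
    """Gradient of the mean squared error for the 1-D model y = w * x on one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for the collective all-reduce: every worker receives the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

def ddp_step(weights, shards, lr=0.01):
    """One synchronous step: local gradients, all-reduce, identical updates."""
    grads = [local_gradient(w, s) for w, s in zip(weights, shards)]
    synced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, synced)]

# Toy data for y = 3x, split round-robin into equal shards for 4 "workers".
data = [(float(x), 3.0 * x) for x in range(1, 9)]
num_workers = 4
shards = [data[i::num_workers] for i in range(num_workers)]

weights = [0.0] * num_workers  # every replica starts from the same initialization
for _ in range(200):
    weights = ddp_step(weights, shards)

print(weights)  # all replicas hold the same value, close to 3.0
```

Because every replica starts from the same weights and receives the same averaged gradient, no replica can drift from the others; this is why DDP only has to synchronize gradients, not the full model, after the initial broadcast.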


In this notebook we reproduce the training from the PyTorch Lightning MNIST hello-world tutorial:
https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb

More details about DDP are in the PyTorch tutorial: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

The source code of the Ray DDP implementation (`ray_lightning.ray_ddp`) is available at: https://docs.ray.io/en/master/_modules/ray_lightning/ray_ddp.html



In [2]:
import pytorch_lightning as pl
from ray_lightning import RayPlugin

In [3]:
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 32

In [4]:

class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        # torchmetrics < 0.11 API; newer releases require
        # Accuracy(task="multiclass", num_classes=10)
        self.accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
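The model summary in the training output below is cut off before the parameter counts. As a sanity check, the trainable parameters of this MLP can be counted by hand from the three `nn.Linear` layers (Flatten, ReLU and Dropout have no parameters). The sketch below is plain arithmetic that mirrors the layer sizes in `LitMNIST` with the default `hidden_size=64`; with `torch` installed, `sum(p.numel() for p in LitMNIST().parameters())` should report the same total.

```python
# Worked parameter count for the LitMNIST MLP (default hidden_size=64).
# Pure-Python arithmetic; the layer sizes mirror the nn.Sequential above.

def linear_params(in_features, out_features):
    """An nn.Linear layer has an (out x in) weight matrix plus one bias per output."""
    return in_features * out_features + out_features

channels, width, height = 1, 28, 28
hidden_size, num_classes = 64, 10

layers = [
    (channels * width * height, hidden_size),  # 784 -> 64
    (hidden_size, hidden_size),                 # 64 -> 64
    (hidden_size, num_classes),                 # 64 -> 10
]
per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)

print(per_layer, total)  # [50240, 4160, 650] 55050
```

Almost all of the roughly 55 K parameters sit in the first layer, which maps the flattened 784-pixel image to the hidden size.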

In [5]:
plugin = RayPlugin(num_workers=8, num_cpus_per_worker=1)

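# Don't pass a nonzero ``gpus`` to the ``Trainer``: with ``RayPlugin`` the number of
# GPUs is determined by ``num_workers`` (each worker gets one GPU when ``use_gpu=True``).
# This example trains on CPU only: 8 workers with 1 CPU each.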

model = LitMNIST()
trainer = Trainer(
    gpus=0,
    max_epochs=3,
    progress_bar_refresh_rate=20,
    plugins=[plugin]  # <- only add this line
)
trainer.fit(model)

2021-12-16 18:21:34,557	INFO services.py:1340 -- View the Ray dashboard at http://127.0.0.1:8265
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
(RayExecutor pid=176) initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
(RayExecutor pid=180) initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
(RayExecutor pid=178) initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
(RayExecutor pid=179) initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
(RayExecutor pid=177) initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
(RayExecutor pid=181) initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
(RayExecutor pid=175) initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
(RayExecutor pid=182) initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
(RayExecutor pid=176) 
(RayExecutor pid=176)   | Name     | Type       | Params
(RayExecut

Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]
(RayExecutor pid=176)   f"The dataloader, {name}, does not have many workers which may be a bottleneck."

Training: -1it [00:00, ?it/s]
Epoch 0:   0%|          | 0/235 [00:00<00:00, 2460.00it/s]
Epoch 0:   9%|▊         | 20/235 [00:01<00:12, 17.30it/s, loss=2.19, v_num=6]
Epoch 0:  17%|█▋        | 40/235 [00:02<00:10, 18.46it/s, loss=1.89, v_num=6]
Epoch 0:  26%|██▌       | 60/235 [00:03<00:09, 18.19it/s, loss=1.55, v_num=6]
Epoch 0:  34%|███▍      | 80/235 [00:05<00:09, 16.12it/s, loss=1.23, v_num=6]
Epoch 0:  43%|████▎     | 100/235 [00:06<00:08, 15.11it/s, loss=0.994, v_num=6]
Epoch 0:  51%|█████     | 120/235 [00:07<00:07, 15.52it/s, loss=0.834, v_num=6]
Epoch 0:  60%|█████▉    | 140/235 [00:08<00:05, 15.99it/s, loss=0.775, v_num=6]
Epoch 0:  68%|██████▊   | 160/235 [00:09<00:04, 16.15it/s, loss=0.697, v_num=6]
Epoch 0:  77%|███████▋  | 180/235 [00:11<00:03, 16.34it/s, loss=0.602, v_num=6]
Epoch 0:  85%|████████▌ | 200/235 [00:12<00:02, 16.50it/s, loss=0.559, v_num=6]
Epoch 0:  94%|█████████▎| 220/235 [00:13<00:00, 16.83it/s, loss=0.559, v_num=6]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/20 [00:00<?, ?it/s]
(RayExecutor pid=176) 
Va
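The progress bar above reports 235 steps for epoch 0 and 20 validation batches, far fewer than a single process would need for 55,000 training images at `BATCH_SIZE = 32`. The sketch below reconstructs those numbers, assuming Lightning wraps each dataloader in a `DistributedSampler` (its default under DDP, with `drop_last=False`, so each worker gets `ceil(N / num_workers)` samples) and that the epoch bar counts training plus validation batches on one worker. The helper `batches_per_worker` is hypothetical.

```python
import math

# Hypothetical reconstruction of the per-worker batch counts in the progress
# bars, assuming each worker sees ceil(N / num_workers) samples
# (DistributedSampler with drop_last=False pads shards to equal size).

def batches_per_worker(num_samples, num_workers, batch_size):
    samples_per_worker = math.ceil(num_samples / num_workers)
    return math.ceil(samples_per_worker / batch_size)

NUM_WORKERS = 8   # RayPlugin(num_workers=8, ...)
BATCH_SIZE = 32   # per-worker batch size used by every DataLoader

train_batches = batches_per_worker(55_000, NUM_WORKERS, BATCH_SIZE)  # mnist_train split
val_batches = batches_per_worker(5_000, NUM_WORKERS, BATCH_SIZE)     # mnist_val split
test_batches = batches_per_worker(10_000, NUM_WORKERS, BATCH_SIZE)   # MNIST test set

print(train_batches, val_batches, train_batches + val_batches, test_batches)
# 215 20 235 40

# Each optimizer step averages gradients over all workers' batches.
effective_batch_size = NUM_WORKERS * BATCH_SIZE
print("effective global batch size:", effective_batch_size)  # 256
```

The same arithmetic also matches the 40 test batches reported by `trainer.test` below. Note that the effective batch size grows with `num_workers`, so changing the number of workers changes the optimization dynamics unless `BATCH_SIZE` or the learning rate is adjusted.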

In [13]:
trainer.test(model)

[2m[36m(RayExecutor pid=3140)[0m initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
[2m[36m(RayExecutor pid=3146)[0m initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
[2m[36m(RayExecutor pid=3145)[0m initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
[2m[36m(RayExecutor pid=3141)[0m initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
[2m[36m(RayExecutor pid=3147)[0m initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
[2m[36m(RayExecutor pid=3143)[0m initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
[2m[36m(RayExecutor pid=3142)[0m initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
[2m[36m(RayExecutor pid=3144)[0m initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
[2m[36m(RayExecutor pid=3140)[0m   f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Testing: 0it [00:00, ?it/s]140)[0m 
Testing:  50%|█████     | 20/40 [00:00<00:00, 39.71it/s]


[{'val_loss': 0.2514459788799286, 'val_acc': 0.9279999732971191}]

Testing: 100%|██████████| 40/40 [00:01<00:00, 35.69it/s]--------------------------------------------------------------------------------
[2m[36m(RayExecutor pid=3140)[0m DATALOADER:0 TEST RESULTS
[2m[36m(RayExecutor pid=3140)[0m {'val_acc': 0.9279999732971191, 'val_loss': 0.2514459788799286}
[2m[36m(RayExecutor pid=3140)[0m --------------------------------------------------------------------------------
Testing: 100%|██████████| 40/40 [00:01<00:00, 35.42it/s]
