In [22]:
# References:
# https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
# https://discuss.pytorch.org/t/help-with-ddp-in-kaggle-notebook/213369

In [23]:
%%writefile demo_distributed_training.py
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group that ``setup`` initialized.

    Must be called once per rank after that rank's distributed work is done.
    """
    teardown = dist.destroy_process_group
    teardown()

class Model(nn.Module):
    """Toy two-layer MLP: Linear(10 -> 10) -> ReLU -> Linear(10 -> 5)."""

    def __init__(self):
        super().__init__()
        # Registration order (net1, relu, net2) fixes both the order in which
        # parameters are randomly initialized and the state_dict keys.
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        """Map a ``(..., 10)`` tensor to a ``(..., 5)`` tensor."""
        hidden = self.relu(self.net1(x))
        return self.net2(hidden)

def demo_basic(rank, world_size):
    """Run one forward/backward/optimizer step of ``Model`` under DDP on one rank.

    Intended as the ``fn`` handed to ``mp.spawn``, which supplies ``rank``.

    Args:
        rank: Process index; also used as the CUDA device index, so this
            assumes one visible GPU per rank.
        world_size: Total number of ranks in the process group.
    """
    print(f"Running basic DDP example on rank: {rank}.")
    setup(rank, world_size)
    # Release the process group even if any step below raises; previously an
    # exception skipped cleanup() and left the group alive.
    try:
        model = Model().to(rank)
        ddp_model = DDP(model, device_ids=[rank])
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=1e-3)

        optimizer.zero_grad()
        # Place inputs on this rank's device explicitly (like the labels)
        # instead of relying on DDP to move CPU inputs to device_ids[0].
        # Draw order (inputs, then labels) matches the original.
        inputs = torch.randn(20, 10).to(rank)
        labels = torch.randn(20, 5).to(rank)
        outputs = ddp_model(inputs)
        # Under DDP, backward() also synchronizes gradients across ranks.
        loss_fn(outputs, labels).backward()
        optimizer.step()
    finally:
        cleanup()
    print(f"Finished running basic DDP example on rank: {rank}.")

def run_demo(demo_fn, world_size):
    """Start ``world_size`` worker processes running ``demo_fn`` and wait for them.

    ``mp.spawn`` calls ``demo_fn(rank, world_size)`` in each child, with
    ``rank`` ranging over ``0 .. world_size - 1``.
    """
    worker_args = (world_size,)
    mp.spawn(demo_fn, args=worker_args, nprocs=world_size, join=True)

if __name__ == "__main__":
    # One worker process per visible GPU; demo_basic uses its rank as the
    # CUDA device index.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # mp.spawn(nprocs=0) would start no workers and return immediately,
        # so the script would silently do nothing. Fail loudly instead.
        sys.exit("No CUDA devices found: this DDP demo needs at least one GPU.")
    # mp.spawn already launches children with the 'spawn' start method; also
    # forcing it globally is redundant but harmless, and kept as before.
    mp.set_start_method('spawn', force=True)
    run_demo(demo_basic, world_size)

Overwriting demo_distributed_training.py


In [24]:
!python demo_distributed_training.py

Running basic DDP example on rank: 0.
Running basic DDP example on rank: 1.
Finished running basic DDP example on rank: 1.
Finished running basic DDP example on rank: 0.
