In [21]:
%%writefile main.py

import os
import sys
from time import time_ns

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from tqdm import tqdm

BATCH_SIZE = 512
LR = 0.001
NUM_EPOCHS = 100//2

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)
        self.relu2 = nn.ReLU()
        self.net3 = nn.Linear(5, 1)

    def forward(self, x):
        return self.net3(self.relu2(self.net2(self.relu(self.net1(x)))))
    
    
class ToyDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self,):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]



def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8080'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    
def cleanup():
    dist.destroy_process_group()

    
def train(model, optimizer, train_loader):
    model.train()
    for (x, y) in train_loader:
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)

        optimizer.zero_grad()

        out = model(x)
        loss_fn = nn.MSELoss()
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()

        
def test(model, val_loader):
    model.eval()
    val_loss = 0
    with tqdm(total=len(val_loader.dataset)) as progress_bar:
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(val_loader):
                x = x.cuda(non_blocking=True)
                y = y.cuda(non_blocking=True)

                out = model(x)
                
                loss_fn = nn.MSELoss()
                loss = loss_fn(out, y)
                val_loss += loss.item()
                progress_bar.update(x.size(0))
            
            val_loss /= batch_idx
    
    return val_loss

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    
    torch.cuda.set_device(rank)
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                              shuffle=False, num_workers=1, pin_memory=True, sampler=train_sampler)
    
    val_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=2, pin_memory=True)

    
    
    model = ToyModel().cuda(rank)
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, nesterov=True)
    model = DDP(model, device_ids=[rank])
    cudnn.benchmark = True
    
    for epoch in range(NUM_EPOCHS):
        t0 = time_ns()

        train(model, optimizer, train_loader)

        t1 = time_ns()
        delta = (t1 - t0) / (10 ** 9)
        print(f"Device {rank} - Train time: {delta} sec")
        
        if rank == 0:
            loss = test(model, val_loader)
            print(f"Loss: {loss}%")
        

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
    
    
    
X = torch.randn(10000, 10)
w = torch.randint(10, (10, 1), dtype=torch.float32)
b = torch.randint(10, (1, 1), dtype=torch.float32)
eps = torch.randn((10000, 1))
y = X @ w + b + eps

train_size = 0.8
train_ids = int(X.shape[0] * train_size)

X_train, X_test = X[:train_ids], X[train_ids:]
y_train, y_test = y[:train_ids], y[train_ids:]

train_dataset = ToyDataset(X_train, y_train)
test_dataset = ToyDataset(X_test, y_test)
    
if __name__ == '__main__':
    run_demo(demo_basic,
             2)

Overwriting main.py


In [22]:
!CUDA_VISIBLE_DEVICES=0,1 python main.py

Running basic DDP example on rank 1.
[W socket.cpp:601] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).
Running basic DDP example on rank 0.
Device 1 - Train time: 5.32592719 sec
Device 0 - Train time: 5.339075536 sec
100%|██████████████████████████████████████| 2000/2000 [00:04<00:00, 450.13it/s]
Loss: 506.2530822753906%
Device 1 - Train time: 6.688313583 sec
Device 0 - Train time: 2.235569794 sec
100%|██████████████████████████████████████| 2000/2000 [00:04<00:00, 425.04it/s]
Loss: 488.2272237141927%
Device 0 - Train time: 2.21684737 sec
  0%|                                                  | 0/2000 [00:00<?, ?it/s]Device 1 - Train time: 6.965490918 sec
100%|██████████████████████████████████████| 2000/2000 [00:04<00:00, 423.90it/s]
Loss: 339.5033721923828%
Device 1 - Train time: 6.92169699 sec
Device 0 - Train time: 2.246385183 sec
100%|██████████████████████████████████████| 2000/2000 [00:04<00:00, 439.07it/s]
Loss