# Practice 2. PyTorch Distributed and data-parallel training
In this assignment, we'll go over the distributed package of PyTorch, `torch.distributed`.


This notebook is inspired by an awesome [PyTorch tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html). If you wish to dive deeper into this topic, feel free to read this tutorial, as well as the [docs](https://pytorch.org/docs/stable/distributed.html) themselves.

For now, let's import all the required libraries and define a function that creates the process group and then runs the given function `fn` on each process:

In [1]:
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import random

def init_process(rank, size, fn, master_port, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
    
torch.set_num_threads(1)

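First, we'll run a very simple function that uses `torch.distributed.barrier`. In the cell below, process 0 prints first. Every other process waits at the barrier until process 0 has printed and reached the barrier too, and only then prints. The relative order of processes 1 to N-1 is not guaranteed.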

In [2]:
def run(rank, size):
    """ Distributed function to be implemented later. """
    if rank!=0:
        dist.barrier()
    print(f'Started {rank}',flush=True)
    if rank==0:
        dist.barrier()

if __name__ == "__main__":
    size = 4
    processes = []
    port = random.randint(25000, 30000)
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run, port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Started 0
Started 1
Started 2
Started 3


Let's implement a classic ping-pong application with this paradigm. We have two processes, and the goal is to have rank 0 print 'ping' and rank 1 print 'pong', strictly alternating and without any race conditions.

In [3]:
def run_pingpong(rank, size):
    """ Distributed function to be implemented later. """
    num_iter = 10
    
    for _ in range(num_iter):
        if rank==0:
            dist.barrier()
            print('ping', flush=True)
            dist.barrier()
        if rank==1:
            dist.barrier()
            dist.barrier()
            print('pong', flush=True)


if __name__ == "__main__":
    size = 2
    processes = []
    port = random.randint(25000, 30000)
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_pingpong, port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


ping
pong
ping
pong
ping
pong
ping
pong
ping
pong
ping
pong
ping
pong
ping
pong
ping
pong
ping
pong


## Task 1 (0.1 score)
Generalize the function above so that N processes print sequentially, without race conditions, in the following order:
```
Process 0
Process 1
Process 2
Process 3
...
Process N-1
```
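One possible approach is sketched below. It assumes that every rank calls the barrier exactly once per "turn": rank `i` prints during turn `i`, and all ranks meet at a barrier before the next turn starts. To keep the sketch testable without launching processes, the barrier and the output function are passed in as parameters; `print_in_order`, `barrier` and `emit` are hypothetical names, and threads plus `threading.Barrier` stand in for processes and `dist.barrier`.

```python
import threading


def print_in_order(rank, size, barrier, emit):
    """Emit 'Process <rank>' so that ranks appear in order 0, 1, ..., size - 1.

    `barrier` is a zero-argument callable that blocks until all `size`
    participants call it (in the notebook: `dist.barrier`); `emit` outputs
    one message (in the notebook: `lambda m: print(m, flush=True)`).
    """
    for turn in range(size):
        if rank == turn:
            emit(f'Process {rank}')
        # No rank starts turn + 1 before the owner of `turn` has emitted.
        barrier()


if __name__ == '__main__':
    # Threads and threading.Barrier stand in for processes and dist.barrier.
    size = 4
    sync = threading.Barrier(size)
    lines = []
    workers = [threading.Thread(target=print_in_order,
                                args=(r, size, sync.wait, lines.append))
               for r in range(size)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(lines)  # → ['Process 0', 'Process 1', 'Process 2', 'Process 3']
```

Inside a process group, the worker function could simply call `print_in_order(rank, size, dist.barrier, lambda m: print(m, flush=True))`. The `flush=True` matters: the message must leave the process before it enters the barrier. This costs N barriers per run, which is fine for a handful of processes.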


# Point-to-point communication
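The functions below show how to send data from one process to another with `torch.distributed.send` and `torch.distributed.recv`. Both calls are blocking. They also have immediate (non-blocking) counterparts, `torch.distributed.isend` and `torch.distributed.irecv`, which return a request object right away; calling its `wait()` method blocks until the transfer is done: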

In [4]:
"""Blocking point-to-point communication."""

def run_sendrecv(rank, size):
    tensor = torch.zeros(1)+int(rank==0)
    print('Rank ', rank, ' has data ', tensor[0])
    if rank == 0:
        # Send the tensor to process 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])

if __name__ == "__main__":
    size = 2
    processes = []
    port = random.randint(25000, 30000)
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_sendrecv, port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Rank  0  has data  tensor(1.)
Rank  1  has data  tensor(0.)
Rank  0  has data  tensor(1.)
Rank  1  has data  tensor(1.)


In [5]:
"""Non-blocking point-to-point communication."""

def run_isendrecv(rank, size):
    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        req = dist.isend(tensor=tensor, dst=1)
        print('Rank 0 started sending')
    else:
        # Receive tensor from process 0
        req = dist.irecv(tensor=tensor, src=0)
        print('Rank 1 started receiving')
    req.wait()
    print('Rank ', rank, ' has data ', tensor[0])
    
if __name__ == "__main__":
    size = 2
    processes = []
    port = random.randint(25000, 30000)
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_isendrecv, port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Rank 0 started sending
Rank 1 started receiving
Rank  1  has data  tensor(1.)
Rank  0  has data  tensor(1.)


# Collective communication and All-Reduce
Now, let's run a simple All-Reduce example which computes the sum across all workers:

In [6]:
""" All-Reduce example."""

def run_allreduce(rank, size):
    tensor = torch.full((1,),rank)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f'Rank {rank} has data {tensor[0]}')
    
if __name__ == "__main__":
    size = 10
    processes = []
    port = random.randint(25000, 30000)
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_allreduce, port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Rank 0 has data 45
Rank 5 has data 45
Rank 6 has data 45
Rank 2 has data 45
Rank 7 has data 45
Rank 3 has data 45
Rank 4 has data 45
Rank 9 has data 45
Rank 8 has data 45
Rank 1 has data 45
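As a sanity check on the printed value: rank $r$ contributes $r$, so after a SUM All-Reduce every rank holds the arithmetic series below. (The scripts written with `%%writefile` use `torch.ones(1)` instead, so each rank contributes 1 and the result is simply the world size, 10.)

```latex
\sum_{r=0}^{N-1} r = \frac{N(N-1)}{2}, \qquad
N = 10:\ \frac{10 \cdot 9}{2} = 45, \qquad
N = 50:\ \frac{50 \cdot 49}{2} = 1225.
```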


In [7]:
%%writefile run_allreduce.py
#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

""" All-Reduce example."""

from functools import partial
""" All-Reduce example."""

def run_allreduce(rank, size):
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print('Rank ', rank, ' has data ', tensor[0])
    
def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
    
if __name__ == "__main__":
    size = 10
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_allreduce))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Overwriting run_allreduce.py


In [8]:
! ./run_allreduce.py

Rank  0  has data  tensor(10.)
Rank  9  has data  tensor(10.)
Rank  8  has data  tensor(10.)
Rank  7  has data  tensor(10.)
Rank  6  has data  tensor(10.)
Rank  5  has data  tensor(10.)
Rank  3  has data  tensor(10.)
Rank  4  has data  tensor(10.)
Rank  1  has data  tensor(10.)
Rank  2  has data  tensor(10.)


In [9]:
%%writefile run_allreduce_spawn.py
#!/usr/bin/env python
import os
from functools import partial

import torch
import torch.distributed as dist

""" All-Reduce example."""

def run_allreduce(rank, size):
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print('Rank ', rank, ' has data ', tensor[0])
    
def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
    
if __name__ == "__main__":
    size = 10

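    # spawn calls fn(i) for each process index i in range(nprocs); partial binds the
    # remaining arguments, so i becomes the `rank` argument of init_process.
    fn = partial(init_process, size=size, fn=run_allreduce, backend='gloo')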
    torch.multiprocessing.spawn(fn, nprocs=size)

Overwriting run_allreduce_spawn.py


In [10]:
! ./run_allreduce_spawn.py

Rank  8  has data  tensor(10.)
Rank  6  has data  tensor(10.)
Rank  4  has data  tensor(10.)
Rank  0  has data  tensor(10.)
Rank  9  has data  tensor(10.)
Rank  7  has data  tensor(10.)
Rank  3  has data  tensor(10.)
Rank  5  has data  tensor(10.)
Rank  2  has data  tensor(10.)
Rank  1  has data  tensor(10.)


In [11]:
%%writefile custom_allreduce.py
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import random

def init_process(rank, size, fn, master_port, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

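def allreduce(send):
    """Naive ring All-Reduce: sum `send` over all ranks, in place.

    In each step, every rank passes the value it last received to its right
    neighbour and adds the value coming from its left neighbour; after
    size - 1 steps every rank has added every other rank's contribution.
    """
    rank = dist.get_rank()
    size = dist.get_world_size()
    send_buff = send.clone()
    recv_buff = send.clone()
    accum = send.clone()

    left = ((rank - 1) + size) % size
    right = (rank + 1) % size

    for i in range(size - 1):
        # The two buffers alternate roles so that the tensor being sent with isend
        # is never the one that recv is writing into during the same step.
        if i % 2 == 0:
            # Send send_buff
            send_req = dist.isend(send_buff, right)
            dist.recv(recv_buff, left)
            accum[:] += recv_buff[:]
        else:
            # Send recv_buff
            send_req = dist.isend(recv_buff, right)
            dist.recv(send_buff, left)
            accum[:] += send_buff[:]
        # Make sure the outgoing buffer is free before it is reused as a receive buffer.
        send_req.wait()
    send[:] = accum[:]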

def run_allreduce(rank, size):
    """ Simple point-to-point communication. """
    tensor = torch.full((1,), rank, dtype=torch.float)
    allreduce(tensor)
    print('Rank ', rank, ' has data ', tensor[0])
    
if __name__ == "__main__":
    size = 50
    processes = []
    port = random.randint(25000, 30000)
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_allreduce, port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Overwriting custom_allreduce.py
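To see why `allreduce` above ends with the full sum on every rank, here is a single-process simulation of the same ring schedule in plain Python, without `torch.distributed`. The list `values` stands in for one scalar per rank, and `simulate_ring_allreduce` is a hypothetical helper written for this illustration. In each of the `size - 1` steps every rank forwards the value it most recently received to its right neighbour and adds whatever arrives from its left neighbour, so after `size - 1` hops each rank has seen every other rank's original value exactly once.

```python
def simulate_ring_allreduce(values):
    """Simulate the naive ring All-Reduce from custom_allreduce.py.

    values[r] is the initial contribution of rank r; returns every rank's
    final `accum`.
    """
    size = len(values)
    accum = list(values)      # what each rank has summed so far
    in_flight = list(values)  # what each rank sends next (send_buff / recv_buff)
    for _ in range(size - 1):
        # Every rank sends to the right, i.e. receives from its left neighbour.
        received = [in_flight[(r - 1) % size] for r in range(size)]
        for r in range(size):
            accum[r] += received[r]
        # The value just received is the one forwarded in the next step.
        in_flight = received
    return accum


print(simulate_ring_allreduce(list(range(5))))  # → [10, 10, 10, 10, 10]
```

Note that in this scheme each rank sends the whole tensor `size - 1` times, so the traffic per rank grows linearly with the number of workers.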


In [12]:
sum(range(50))

1225

In [13]:
!python custom_allreduce.py

Rank  22  has data  tensor(1225.)
Rank  27  has data  tensor(1225.)
Rank  28  has data  tensor(1225.)
Rank  29  has data  tensor(1225.)
Rank  40  has data  tensor(1225.)
Rank  20  has data  tensor(1225.)
Rank  38  has data  tensor(1225.)
Rank  45  has data  tensor(1225.)
Rank  23  has data  tensor(1225.)
Rank  42  has data  tensor(1225.)
Rank  32  has data  tensor(1225.)
Rank  24  has data  tensor(1225.)
Rank  46  has data  tensor(1225.)
Rank  21  has data  tensor(1225.)
Rank  30  has data  tensor(1225.)
Rank  7  has data  tensor(1225.)
Rank  9  has data  tensor(1225.)
Rank  19  has data  tensor(1225.)
Rank  49  has data  tensor(1225.)
Rank  16  has data  tensor(1225.)
Rank  39  has data  tensor(1225.)
Rank  15  has data  tensor(1225.)
Rank  47  has data  tensor(1225.)
Rank  18  has data  tensor(1225.)
Rank  31  has data  tensor(1225.)
Rank  2  has data  tensor(1225.)
Rank  5  has data  tensor(1225.)
Rank  8  has data  tensor(1225.)
Rank  0  has data  tensor(1225.)
Rank  14  has data  

# Distributed training

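With All-Reduce in place, we can move on to multi-process distributed (data-parallel) training. The example below does not use our hand-written `allreduce` yet: it relies on PyTorch's `DistributedDataParallel`, which all-reduces the gradients during the backward pass. The model and the dataset are a variant of those in the official MNIST [example](https://github.com/pytorch/examples/blob/master/mnist/main.py):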

In [14]:
%%writefile ddp_example.py
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import random

def init_process(args, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    rank = args.local_rank
    dist.init_process_group(backend, rank=rank)
    size = dist.get_world_size()
    fn(rank, size)
    
torch.set_num_threads(1)

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader 
from torch.utils.data.distributed import DistributedSampler
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
from functools import partial
from argparse import ArgumentParser

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(4608, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output

def run_training(rank, size):
    torch.manual_seed(1234)
    dataset = MNIST('./mnist',transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,))
                             ]))
    loader = DataLoader(dataset, sampler=DistributedSampler(dataset, size, rank), batch_size=16)
    model = Net()
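    device = torch.device('cuda', rank)
    torch.cuda.set_device(device)  # make this rank's GPU the default CUDA device (recommended with NCCL)
    model.to(device)
    # output_device expects a single device (int or torch.device), not a list
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)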
    optimizer = torch.optim.SGD(model.parameters(),
                          lr=0.01, momentum=0.5)

    num_batches = len(loader)
    steps = 0
    epoch_loss = 0
    for data, target in loader:
        data=data.to(device)
        target=target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.cross_entropy(output, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        steps+=1
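        if steps % 100 == 0:
            # Note: this prints the sum of the last 100 losses divided by the total number
            # of batches, not their mean; epoch_loss / 100 would give the mean over the interval.
            print(f'Rank {dist.get_rank()}, loss: {epoch_loss / num_batches}')
            epoch_loss = 0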
        
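def parse_args():
    parser = ArgumentParser()
    # torch.distributed.launch passes --local_rank (--local-rank since PyTorch 2.0);
    # torchrun passes no flag and sets the LOCAL_RANK environment variable instead.
    parser.add_argument('--local_rank', '--local-rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)))
    args = parser.parse_args()
    return args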
        
if __name__ == "__main__":
    args=parse_args()
    print(args.local_rank)
    init_process(args,fn=run_training, backend='nccl')

Overwriting ddp_example.py
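The `DistributedSampler` in `run_training` is what gives each rank its own shard of MNIST. The sketch below re-implements its index assignment in plain Python for the `shuffle=False`, `drop_last=False` case, based on our reading of the PyTorch source (details may differ between versions): the index list is padded by wrapping around until its length is divisible by the number of replicas, and rank `r` then takes every `num_replicas`-th index starting at `r`. `shard_indices` is a hypothetical helper, not a PyTorch function. With the default `shuffle=True`, the indices are permuted first using a seed that depends on the epoch, which is why multi-epoch loops should call `sampler.set_epoch(epoch)`.

```python
import math


def shard_indices(num_samples, num_replicas, rank):
    """Indices for `rank`, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(num_samples))
    per_rank = math.ceil(num_samples / num_replicas)
    total_size = per_rank * num_replicas
    # Pad by wrapping around to the start until the length divides evenly.
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Rank r takes every num_replicas-th index, starting at r.
    return indices[rank:total_size:num_replicas]


print([shard_indices(10, 3, r) for r in range(3)])
# → [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]
print(len(shard_indices(60_000, 2, 0)))  # → 30000
```

For MNIST's 60,000 training images and two ranks, each rank therefore sees 30,000 images, i.e. 1,875 batches of 16, which is the `num_batches` used in the loss printout.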


In [15]:
!python -m torch.distributed.launch --nproc_per_node 2 ddp_example.py

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
1
0
Rank 1, loss: 0.0711741222222646
Rank 0, loss: 0.07069784507751464
Rank 0, loss: 0.03514039077758789
Rank 1, loss: 0.034355521066983544
Rank 1, loss: 0.027626613974571227
Rank 0, loss: 0.028529569840431212
Rank 1, loss: 0.02547050395409266
Rank 0, loss: 0.024831330653031666
Rank 0, loss: 0.02246965136130651
Rank 1, loss: 0.023988804654280344
Rank 0, loss: 0.020719586579004922
Rank 1, loss: 0.020591482353210448
Rank 0, loss: 0.019405658107995986
Rank 1, loss: 0.01995779107809067
Rank 0, loss: 0.019439678021272024
Rank 1, loss: 0.01599694345196088
Rank 1, loss: 0.017450609695911407
Rank 0, loss: 0.016963590412338574
Rank 0, loss: 0.014799549550811449
Rank 1, loss: 0.014166126350561777
Rank 1, lo

## Task 2 (0.6 score)

The above pipeline shows only the basic building blocks of distributed training. Now, let's train something actually interesting! For example, let's take the [CIFAR-100](https://pytorch.org/vision/0.8/datasets.html#torchvision.datasets.CIFAR100) dataset and train a model with **synchronized** Batch Normalization: that is, we average the statistics across workers during each forward pass. Also, implement a version of distributed training which is aware of gradient accumulation: for each batch that doesn't run `optimizer.step`, you can avoid the All-Reduce step altogether. 
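One way to think about synchronized Batch Normalization: per channel, each worker only needs to contribute three numbers, namely its element count, the sum of its activations and the sum of their squares. Summing these across workers (which is what an All-Reduce with `ReduceOp.SUM` does) is enough to recover the mean and the biased variance (the one BN normalizes with) of the whole global batch. The plain-Python sketch below checks this for a single channel; the helper names are hypothetical, and `torch.nn.SyncBatchNorm` may organize its communication differently internally.

```python
import math
import statistics


def local_stats(xs):
    """Per-worker statistics for one channel; these are what would be all-reduced (SUM)."""
    return len(xs), sum(xs), sum(x * x for x in xs)


def global_mean_var(per_worker):
    """Combine the summed statistics into the global mean and biased variance."""
    count = sum(n for n, _, _ in per_worker)
    total = sum(s for _, s, _ in per_worker)
    total_sq = sum(q for _, _, q in per_worker)
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; fine for a sketch, but prone to cancellation in float32.
    return mean, total_sq / count - mean * mean


# Three workers with unequal local batches of one channel's activations.
shards = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
mean, var = global_mean_var([local_stats(s) for s in shards])
flat = [x for s in shards for x in s]
print(mean, statistics.mean(flat))                    # 3.5 3.5
print(math.isclose(var, statistics.pvariance(flat)))  # True
```

Averaging the per-worker means instead would give (2 + 4.5 + 6) / 3, about 4.17, because it ignores that the workers hold different numbers of elements.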

If resources allow, compare the performance (both speed and quality) of your distributed training pipeline with the primitives provided by PyTorch: [`DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel) and [`SyncBatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html?highlight=syncbatchnorm#torch.nn.SyncBatchNorm).
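Skipping the All-Reduce on non-final gradient-accumulation steps is safe because the reduction is linear (whether it sums, or averages as DDP does): all-reducing the locally accumulated gradient once gives the same result as all-reducing after every micro-batch and accumulating afterwards. The sketch below checks this with integer "gradients" in plain Python, one list of micro-batch gradients per worker; `allreduce_sum` is a hypothetical stand-in for `dist.all_reduce`, and the other names are hypothetical too.

```python
def allreduce_sum(values):
    """Stand-in for dist.all_reduce(op=SUM): every worker receives the total."""
    total = sum(values)
    return [total] * len(values)


def sync_every_step(grads):
    """All-Reduce after every micro-batch; grads[w][k] is worker w's gradient on micro-batch k."""
    num_workers, accum_steps = len(grads), len(grads[0])
    acc = [0] * num_workers
    for k in range(accum_steps):
        reduced = allreduce_sum([grads[w][k] for w in range(num_workers)])
        acc = [a + g for a, g in zip(acc, reduced)]
    return acc


def sync_last_step_only(grads):
    """Accumulate locally without communication, then All-Reduce once."""
    return allreduce_sum([sum(g) for g in grads])


def needs_sync(step, accum_steps):
    """True for the micro-batch after which optimizer.step() runs."""
    return (step + 1) % accum_steps == 0


grads = [[1, 2, 3, 4], [10, 20, 30, 40], [5, 6, 7, 8]]
print(sync_every_step(grads), sync_last_step_only(grads))  # [136, 136, 136] [136, 136, 136]
print([needs_sync(s, 4) for s in range(8)])
# → [False, False, False, True, False, False, False, True]
```

For the PyTorch baseline, the analogous pattern is to run `loss.backward()` inside `with model.no_sync():` on micro-batches where `needs_sync` is false, and outside it (followed by `optimizer.step()` and `optimizer.zero_grad()`) on the others. Since accumulation sums gradients, divide the loss by the number of accumulation steps if you want the same scale as one large batch.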

## Task 3 (0.3 score)

For now, we only aggregate the gradients across different workers during training. But what if we want to run distributed validation on a large dataset as well? In this assignment, you have to implement distributed metric aggregation: shard the dataset across different workers (with `scatter`), compute accuracy for each subset on its respective worker and then average the metric values.
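A detail worth handling here: if the shards have different sizes (for example, when the dataset size is not divisible by the number of workers), a plain average of per-worker accuracies is biased. Summing the number of correct predictions and the number of samples across workers (e.g. with an All-Reduce SUM on a two-element tensor) and dividing at the end gives the exact global accuracy. The plain-Python check below uses hypothetical per-worker counts and hypothetical helper names.

```python
def mean_of_accuracies(correct, total):
    """Naive aggregation: average the per-worker accuracies."""
    return sum(c / t for c, t in zip(correct, total)) / len(total)


def global_accuracy(correct, total):
    """Exact aggregation: what all-reducing (correct, total) with SUM and dividing gives."""
    return sum(correct) / sum(total)


# Hypothetical counts from three workers with unequal shards.
correct = [90, 45, 8]
total = [100, 50, 10]
print(global_accuracy(correct, total))     # 0.89375 (= 143 / 160)
print(mean_of_accuracies(correct, total))  # about 0.867
```

The naive average gives the 10-sample shard as much weight as the 100-sample one; with equal shard sizes the two aggregations coincide.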