# Pytorch Distributed Data Parallel
- Playing around with functionality, addressing each topic of interest separately
- By topic of interest I mean things that either weren't clear to me or were pain points when using DDP in my workflow
- Some essential things are not that straightforward in DDP: validation, logging, returning stuff
- Running DDP in Jupyter is not easy. The workaround I'm using is to run the program in a separate process altogether. For this I have to write all imports and helper functions to a .py file. Sounds complicated but it can be neatly done with cell magics. 
- Note: DDP works by running the entire spawned function in separate processes, each with access to a different GPU. The only way we can control what each process executes is through the `rank` variable, which is passed as the first argument. Each process holds its own replica of the model; gradients are averaged across processes under the hood when .backward is triggered, which keeps the replicas in sync. We can also sync and communicate between procs manually using torch.distributed.

### Contents:   
0. Jupyter DDP Cell Magic: to elegantly run DDP in Jupyter  
1. Helper functions: functions used in all other examples
2. Quick Train: to test if bare ddp works  
3. Sampling: How distributed sampling works, looking at indices  
4. Synchronisation: Looking at the timestamps of processes with or without sync  
5. Validation and Logging: How to validate using ddp and informally handle logs  
6. Passing Args: Passing parameters through the spawner without being awkward  
7. Checkpointing and Resuming: Saving and Loading DDP models workaround  
8. Returns: Being able to receive ddp process returns  
9. Summary: Sample script for ddp training with some generic functions solving above issues  

In [1]:
import torch
%config Completer.use_jedi = False

# Report how many CUDA devices PyTorch can see; the examples below spawn one process per device.
print("Available GPUs:", torch.cuda.device_count())

Available GPUs: 2


# 0. Jupyter DDP Cell Magic
- Jupyter doesn't like DDP; it breaks if run directly in a cell or even with %run
- Need to write the executable code to a file and run it in a separate process

In [2]:
import tempfile
import subprocess
from IPython.core.magic import register_cell_magic

@register_cell_magic
def prun(line, cell):
    """
    Run a cell as a different python process and print outputs.
    Ipykernel messes with concurrency so running multiproc in jupyter can be tricky.
    Saves the cell as a temp file and runs python on it with popen.
    Note: the data streams printed out are sorted based on process, not actual time.
    :param line: path to where to create temp file, defaults to the current working directory
    :param cell: source of the cell, executed as a standalone script
    """
    import sys  # local import so the magic does not depend on the cell's import block

    # An empty magic argument arrives as "", which tempfile resolves relative to
    # the working directory; spell that default out instead of relying on it.
    tmp_dir = line.strip() or "."
    with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp:
        tmp.write(cell.encode())
        tmp.flush()
        # sys.executable is the interpreter running this kernel, so the child sees
        # the same installed packages even if "python" on PATH is a different one.
        with subprocess.Popen(
            [sys.executable, tmp.name],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        ) as process:
            # Iterating the pipe runs until EOF, so output that is still buffered
            # when the child exits is printed rather than dropped.
            for raw in process.stdout:
                # Lines keep their own newline; do not let print add a second one.
                print(raw.decode("utf-8", errors="replace"), end="")

# 1. Helper Functions

In [3]:
%%writefile my_ddp.py

import os
import time
import torch
from torch import nn
from datetime import datetime
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from torchvision.models import vgg11_bn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    """
    Join the NCCL process group as process `rank` out of `world_size`.
    The rendezvous address is fixed to localhost:12355.
    """
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"})
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def get_loader(rank=0, batch_size=128, train=True, distributed=True, world_size=None):
    """
    Build a CIFAR10 DataLoader, optionally sharded across DDP processes.
    :param rank: index of the calling process, selects its shard when distributed
    :param batch_size: batch size of the returned loader (per process)
    :param train: training split with crop/flip augmentation if True, test split otherwise
    :param distributed: shard the dataset with a DistributedSampler
    :param world_size: number of shards; defaults to the size of the initialised
        process group, or to the number of visible GPUs if there is none
    :return: DataLoader over the requested split
    """
    norm = T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    if train:
        transform = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            norm,
        ])
    else:
        transform = T.Compose([T.ToTensor(), norm])
    ds = CIFAR10(root="/data/datasets/CIFAR10", train=train, transform=transform)

    sampler = None
    if distributed:
        if world_size is None:
            # The shard count must equal the number of processes; a fixed value
            # silently skips or duplicates data on machines with another GPU count.
            if dist.is_initialized():
                world_size = dist.get_world_size()
            else:
                world_size = max(torch.cuda.device_count(), 1)
        sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
    return DataLoader(ds, batch_size=batch_size, sampler=sampler)


def get_model(rank):
    """
    Build the VGG11-BN feature extractor with a 10-way linear head,
    move it to device `rank` and wrap it in DDP.
    """
    layers = [*vgg11_bn().features, nn.Flatten(), nn.Linear(512, 10)]
    net = torch.nn.Sequential(*layers).to(rank)
    return DDP(net, device_ids=[rank], output_device=rank)


def cleanup():
    """
    Destroy the process group created by `setup`.
    """
    dist.destroy_process_group()


def run_ddp(trainer, *args):
    """
    Spawn one `trainer` process per visible GPU.
    Each process is called as trainer(rank, *args).
    """
    mp.spawn(trainer, args=args, nprocs=torch.cuda.device_count())


Overwriting my_ddp.py


# 2. Quick Test Train
- Just to see if it works
- Print the device of each model

In [38]:
%%prun .

from my_ddp import *


def train(rank, epochs=2):
    """
    Smoke test for DDP: train the wrapped model on CIFAR10 and print its device.
    :param rank: process index passed in by the spawner, also used as the GPU index
    :param epochs: number of passes over this process' shard of the training set
    """
    setup(rank, world_size=torch.cuda.device_count())
    model = get_model(rank)
    train_loader = get_loader(rank, train=True, distributed=True)
    optimiser = torch.optim.Adam(model.parameters(), lr=5e-4)
    print("Model on device:", model.module[0].weight.device)
    for epoch in range(epochs):
        # DistributedSampler derives its shuffle from the epoch number; without
        # this call every epoch iterates the shard in exactly the same order.
        train_loader.sampler.set_epoch(epoch)
        for X, y in train_loader:
            X, y = X.to(rank), y.to(rank)
            out = model(X)
            loss = nn.functional.cross_entropy(out, y)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()

# Entry point: launch `train` in one process per visible GPU.
if __name__ == '__main__':
    run_ddp(train)

Model on device: cuda:1

Model on device: cuda:0



# 3. Sampling
- Visualise the sampling of the dataset using DistributedSampler
- Use a dummy tensordataset with 20 data points and print their indices
- Without the distributed sampler each process just goes over the entire dataset

In [85]:
%%prun .

from my_ddp import *


def train(rank, dist_sampling=True):
    """
    Print which of 20 dummy data points each process sees in every batch.
    :param rank: process index passed in by the spawner
    :param dist_sampling: shard the dataset with a DistributedSampler; when False
        every process iterates the whole dataset
    """
    world_size = torch.cuda.device_count()
    setup(rank, world_size=world_size)

    ds = TensorDataset(torch.tensor(range(20)))
    # The number of shards has to match the number of spawned processes,
    # so derive it from world_size rather than fixing it to 2.
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank) if dist_sampling else None
    loader = DataLoader(ds, batch_size=10, sampler=sampler)

    for batch, (X, ) in enumerate(loader):
        # Indices are sorted only to make the printout easier to compare.
        print(f"Rank: {rank}, Batch: {batch}, Indices:{X.sort()[0].numpy()}")

    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()

# Entry point: run the same demo twice, first with and then without the distributed sampler.
if __name__ == '__main__':
    print("With distributed Sampling")
    run_ddp(train, True)
    print("")
    
    print("Without distributed Sampling")
    run_ddp(train, False)

With distributed Sampling

Rank: 1, Batch: 0, Indices:[ 0  2  5  6  8 12 14 17 18 19]

Rank: 0, Batch: 0, Indices:[ 1  3  4  7  9 10 11 13 15 16]



Without distributed Sampling

Rank: 1, Batch: 0, Indices:[0 1 2 3 4 5 6 7 8 9]

Rank: 1, Batch: 1, Indices:[10 11 12 13 14 15 16 17 18 19]

Rank: 0, Batch: 0, Indices:[0 1 2 3 4 5 6 7 8 9]

Rank: 0, Batch: 1, Indices:[10 11 12 13 14 15 16 17 18 19]



# 4. Synchronisation
- For some reason the processes run sequentially if no sync trigger?
- Both dist.barrier or loss.backward() trigger syncs.
- Backward doesn't seem to work without zero_grad?

In [7]:
%%prun .

from my_ddp import *


def train(rank, sync=True):
    """
    Print a timestamp per step to compare processes with and without a sync point.
    :param rank: process index passed in by the spawner, also used as the GPU index
    :param sync: run a forward/backward pass through the DDP model on every step
    """
    setup(rank, world_size=torch.cuda.device_count())
    model = get_model(rank)

    for batch in range(3):
        model.zero_grad()
        # Random pause so the processes drift apart unless something re-aligns them.
        time.sleep(torch.randn(1).abs().item()*2)
        if sync:
            model(torch.randn(1, 3, 32, 32).to(rank)).sum().backward()  # option 1
            # option 2: call dist.barrier() here instead of the backward pass
        print(f"Time: {datetime.now().strftime('%M:%S')}, Rank: {rank}, Batch: {batch}")

    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()

# Entry point: run the timing demo with the sync point enabled, then disabled.
if __name__ == '__main__':
    print("With Sync")
    run_ddp(train, True)
    print("")
    
    print("Without Sync")
    run_ddp(train, False)

With Sync

Time: 03:44, Rank: 0, Batch: 0

Time: 03:45, Rank: 0, Batch: 1

Time: 03:47, Rank: 0, Batch: 2

Time: 03:44, Rank: 1, Batch: 0

Time: 03:45, Rank: 1, Batch: 1

Time: 03:47, Rank: 1, Batch: 2



Without Sync

Time: 03:54, Rank: 0, Batch: 0

Time: 03:54, Rank: 0, Batch: 1

Time: 03:55, Rank: 0, Batch: 2

Time: 03:53, Rank: 1, Batch: 0

Time: 03:56, Rank: 1, Batch: 1

Time: 03:58, Rank: 1, Batch: 2



# 5. Validation and Logging
- More concerned with logging validation scores rather than random stuff
- Standard practice seems to validate only on one process, while the other ones are idle :-/
- For larger datasets we would ideally have a distributed validation and aggregate the scores
- The workflow would be to use a distributed sampler during validation, and reduce the accuracy onto only one process, which then does the logging
- Note: Reduce only reduces on the destination process, the other ones keep the value of their tensor unchanged
- Note: The default operation for reduce is SUM

In [11]:
%%prun .

from my_ddp import *

# TOY EXAMPLE
def train(rank):
    """
    Toy example: every process draws a random score and the scores are summed onto rank 0.
    :param rank: process index passed in by the spawner, also used as the GPU index
    """
    setup(rank, world_size=torch.cuda.device_count())

    time.sleep(torch.randn(1).abs().item()*2)
    score = torch.randint(0, 10, (1,)).to(rank)
    print(f"Time: {datetime.now().strftime('%M:%S')}, Rank: {rank}, Partial: {score.item()}")

    # No op is passed, so the default reduction is used; the result is written on rank 0.
    dist.reduce(score, dst=0)

    time.sleep(torch.randn(1).abs().item()*2)
    print(f"Time: {datetime.now().strftime('%M:%S')}, Rank: {rank}, Full: {score.item()}")

    if rank == 0:
        print(f"Time: {datetime.now().strftime('%M:%S')}, Rank: {rank}, This gets logged: {score.item()}")

    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()

# Entry point: launch the toy reduce example in one process per visible GPU.
if __name__ == '__main__':
    run_ddp(train)

Time: 43:39, Rank: 1, Partial: 6

Time: 43:40, Rank: 1, Full: 6

Time: 43:36, Rank: 0, Partial: 2

Time: 43:41, Rank: 0, Full: 8

Time: 43:41, Rank: 0, This gets logged: 8



In [29]:
%%prun .

from my_ddp import *

# ACTUAL TRAINING AND VALIDATING
def train(rank):
    """
    Train for 10 epochs, then validate on a sharded test set and report accuracy on rank 0.
    :param rank: process index passed in by the spawner, also used as the GPU index
    """
    setup(rank, world_size=torch.cuda.device_count())
    model = get_model(rank)

    train_loader = get_loader(rank, train=True, distributed=True)
    val_loader = get_loader(rank, train=False, distributed=True)
    optimiser = torch.optim.Adam(model.parameters(), lr=5e-4)
    for epoch in range(10):
        # DistributedSampler derives its shuffle from the epoch number; without
        # this call every epoch iterates the shard in exactly the same order.
        train_loader.sampler.set_epoch(epoch)
        for X, y in train_loader:
            X, y = X.to(rank), y.to(rank)
            out = model(X)
            loss = nn.functional.cross_entropy(out, y)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

    # logging only once at the end to keep things clean
    model.eval()

    # Counters live on the GPU so they can be reduced across processes afterwards.
    n_correct = torch.tensor(0.).to(rank)
    n_obs = torch.tensor(0.).to(rank)
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(rank), y.to(rank)
            out = model(X)
            n_correct += (out.argmax(dim=-1) == y).float().sum()
            n_obs += len(X)

    print(f"Rank: {rank}, n_correct: {n_correct}, n_obs: {n_obs}")

    # Accumulate the per-process counts onto rank 0.
    dist.reduce(n_correct, dst=0)
    dist.reduce(n_obs, dst=0)

    if rank == 0:
        print(f"Final Acc -  n_correct: {n_correct}, n_obs: {n_obs}, acc: {n_correct / n_obs}")

    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()

# Entry point: launch training plus distributed validation in one process per visible GPU.
if __name__ == '__main__':
    run_ddp(train)

Rank: 1, n_correct: 4145.0, n_obs: 5000.0

Rank: 0, n_correct: 4094.0, n_obs: 5000.0

Final Acc -  n_correct: 8239.0, n_obs: 10000.0, acc: 0.8238999843597412



# 6. Passing Args
- Unfortunately spawn accepts only one "args" parameter, which passes the arguments of the executable function
- This means only positional arguments, not named ones can be used - which is muy not elegante for my workflow at least
- This is where some studio engineering comes in handy my hard rocking amigo
- Basically we need a wrapper function that splits the "args" into positional and named parameters. The args we pass to spawn will be a tuple of a list and a dict.
- This way we can still define a trainer and call the runner like a normal person

In [51]:
%%prun .

from my_ddp import *


def run_ddp(*args, **kwargs):
    """
    Spawn `train_wrapper` once per visible GPU.
    Positional and named arguments travel together as a single (args, kwargs) pair.
    """
    n_procs = torch.cuda.device_count()
    payload = (args, kwargs)
    mp.spawn(train_wrapper, args=payload, nprocs=n_procs)

    
def train_wrapper(rank, args: list = None, kwargs: dict = None):
    """
    Unpack the (args, kwargs) pair received from the spawner and call `train` with it.
    Missing values are treated as "no arguments".
    """
    positional = [] if args is None else args
    named = {} if kwargs is None else kwargs
    train(rank, *positional, **named)
    

def train(rank, a=1, b=2, c=3, d=4):
    """
    Dummy trainer that only prints the arguments it received.
    :param rank: process index passed in by the spawner
    :param a: demo argument
    :param b: demo argument
    :param c: demo argument
    :param d: demo argument
    """
    setup(rank, world_size=torch.cuda.device_count())
    print(f"Rank: {rank}, a={a}, b={b}, c={c}, d={d}")
    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()


# Entry point: pass two positional and one named argument through the spawner.
if __name__ == '__main__':
    print("The defaults are: a=1, b=2, c=3, d=4")
    print("Passed args are: '6, 7, d=9'")
    run_ddp(6, 7, d=9)

The defaults are: a=1, b=2, c=3, d=4

Passed args are: '6, 7, d=9'

Rank: 1, a=6, b=7, c=3, d=9

Rank: 0, a=6, b=7, c=3, d=9



# 7. Checkpointing and Resuming
- Official documentation suggests saving and loading the DDP model directly into GPU
- I think it's just easier to work with the model.module and use regular save/load functions then re-DDP it.

In [57]:
%%prun .

from my_ddp import *


def evaluate(rank, model, loader):
    """
    Accuracy of `model` over `loader`, with the counts reduced onto rank 0.
    Switches the model to eval mode for the pass and back to train mode afterwards.
    Only rank 0 receives the reduced counts, so only its return value covers all processes.
    """
    model.eval()
    hits = torch.tensor(0.).to(rank)
    seen = torch.tensor(0.).to(rank)
    with torch.no_grad():
        for X, y in loader:
            X = X.to(rank)
            y = y.to(rank)
            predicted = model(X).argmax(dim=-1)
            hits += torch.eq(predicted, y).float().sum()
            seen += len(X)

    # Accumulate the per-process counts onto rank 0.
    for counter in (hits, seen):
        dist.reduce(counter, dst=0)

    ratio = hits / seen
    model.train()
    return ratio.detach().item()


def train(rank, resume=False):
    """
    Train for 3 epochs, validating and checkpointing the unwrapped model after each one.
    :param rank: process index passed in by the spawner, also used as the GPU index
    :param resume: start from the module stored in "model.pt" instead of a fresh model
    """
    setup(rank, world_size=torch.cuda.device_count())

    if resume:
        # Deserialise onto the CPU first: without map_location every process would
        # restore the tensors onto the GPU they were saved from.
        # NOTE: PyTorch 2.6+ defaults torch.load to weights_only=True, which rejects
        # fully pickled modules; pass weights_only=False there.
        model = torch.load("model.pt", map_location="cpu")
    else:
        model = torch.nn.Sequential(*vgg11_bn().features, nn.Flatten(), nn.Linear(512, 10))

    model = model.to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)

    train_loader = get_loader(rank, train=True, distributed=True)
    val_loader = get_loader(rank, train=False, distributed=True)
    optimiser = torch.optim.Adam(model.parameters(), lr=5e-4)
    for epoch in range(3):
        # DistributedSampler derives its shuffle from the epoch number; without
        # this call every epoch iterates the shard in exactly the same order.
        train_loader.sampler.set_epoch(epoch)
        for X, y in train_loader:
            X, y = X.to(rank), y.to(rank)
            out = model(X)
            loss = nn.functional.cross_entropy(out, y)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

        val_acc = evaluate(rank, model, val_loader)

        if rank == 0:
            print(f"Epoch {epoch}, acc: {val_acc}")
            # Save the plain module (not the DDP wrapper) so it loads with regular torch.load.
            torch.save(model.module, "model.pt")
        # Hold the other processes until rank 0 has finished writing the checkpoint.
        dist.barrier()

    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()

# Entry point: train from scratch, then start a second run that resumes from the checkpoint.
if __name__ == '__main__':
    # Train 3 epochs from scratch
    run_ddp(train, False)
    print("Resuming...")
    # Train another 3 epochs after resuming
    run_ddp(train, True)

Epoch 0, acc: 0.5824000239372253

Epoch 1, acc: 0.6482999920845032

Epoch 2, acc: 0.7166000008583069

Resuming...

Epoch 0, acc: 0.7473000288009644

Epoch 1, acc: 0.7610999941825867

Epoch 2, acc: 0.794700026512146



# 8. Returns
- I would like the trainer to return results e.g my logs
- Spawn doesn't return anything. They could have done it a la joblib
- To make things easy, I will just use process 0 for checkpointing and logging, so only its return is of interest
- Similar to passing args I think the most elegant solution is to factor out that logic into wrappers
- Will create a temp file and let process 0 serialise the return into it, will load from disk after ddp finishes

In [12]:
%%prun .

import tempfile
from my_ddp import *


def run_ddp():
    """
    Spawn `train_wrapper` once per visible GPU and return what rank 0 serialised.
    A named temp file is the hand-over point between the child process and this one.
    """
    n_procs = torch.cuda.device_count()
    with tempfile.NamedTemporaryFile() as handle:
        result_path = handle.name
        mp.spawn(train_wrapper, args=(result_path,), nprocs=n_procs)
        # Read the result back before the temp file is removed on exit from the block.
        return torch.load(result_path)

    
def train_wrapper(rank, tmp_path):
    """
    Run `train` and, on rank 0 only, serialise its return value to `tmp_path`.
    """
    outcome = train(rank)
    if rank != 0:
        return
    torch.save(outcome, tmp_path)
        

def train(rank):
    """
    Dummy trainer that returns a small dict containing its rank.
    :param rank: process index passed in by the spawner
    :return: dict with the rank wrapped in a list under the key "whatever"
    """
    setup(rank, world_size=torch.cuda.device_count())
    res = {"whatever": [rank]}
    # Tear the process group down explicitly instead of relying on process exit.
    cleanup()
    return res


# Entry point: run the spawner and print the value handed back by rank 0.
if __name__ == '__main__':
    res = run_ddp()
    print(f"Return: {res}")
    

Return: {'whatever': [0]}



# 9. Summary
- Putting it all together for some generic functions - run_ddp and proc_wrapper
- They handle passing args and returning stuff
- Using a sample trainer and train func to split model handling from training.

In [34]:
%%prun .

import tempfile
from collections import defaultdict
from my_ddp import *


def run_ddp(trainer, **kwargs):
    """
    Run `trainer` under DDP, one process per visible GPU, and return rank 0's result.
    Keyword arguments are forwarded to the trainer through `proc_wrapper`.
    """
    n_procs = torch.cuda.device_count()
    with tempfile.NamedTemporaryFile() as handle:
        spawn_args = (trainer, handle.name, kwargs)
        mp.spawn(proc_wrapper, args=spawn_args, nprocs=n_procs)
        # Read the result back before the temp file is removed on exit from the block.
        return torch.load(handle.name)

    
def proc_wrapper(rank, trainer, tmp_path: str, kwargs: dict = None):
    """
    Wrapper to set up the process group, run the trainer and serialise its return.
    :param rank: process index passed in by the spawner
    :param trainer: callable invoked as trainer(rank, **kwargs)
    :param tmp_path: file that rank 0 writes the trainer's return value to
    :param kwargs: named arguments for the trainer, None means no arguments
    """
    setup(rank, world_size=torch.cuda.device_count())
    if kwargs is None:
        kwargs = {}

    res = trainer(rank, **kwargs)
    if rank == 0:
        torch.save(res, tmp_path)

    # This wrapper created the process group, so it is also the one to destroy it.
    cleanup()
        

def my_trainer(rank, epochs=10, lr=5e-4, batch_size=128, resume=False):
    """
    Handles model saving/loading, passes to train func.
    :param rank: process index passed in by the spawner, also used as the GPU index
    :param epochs: forwarded to `train`
    :param lr: forwarded to `train`
    :param batch_size: forwarded to `train`
    :param resume: start from the module stored in "model.pt" instead of a fresh model
    :return: the logs returned by `train`
    """
    if resume:
        # Deserialise onto the CPU first: without map_location every process would
        # restore the tensors onto the GPU they were saved from.
        # NOTE: PyTorch 2.6+ defaults torch.load to weights_only=True, which rejects
        # fully pickled modules; pass weights_only=False there.
        model = torch.load("model.pt", map_location="cpu")
    else:
        model = torch.nn.Sequential(*vgg11_bn().features, nn.Flatten(), nn.Linear(512, 10))

    model = model.to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)
    model, logs = train(rank, model, epochs=epochs, lr=lr, batch_size=batch_size)

    if rank == 0:
        # Save the plain module (not the DDP wrapper) so it loads with regular torch.load.
        torch.save(model.module, "model.pt")
    return logs
    
    
def train(rank, model, epochs=10, lr=5e-4, batch_size=128):
    """
    Full training of model, returns logs of training.
    :param rank: process index, also used as the GPU index
    :param model: model to optimise, called on batches moved to device `rank`
    :param epochs: number of passes over this process' shard of the training set
    :param lr: learning rate for Adam
    :param batch_size: batch size of the training loader
    :return: tuple of the trained model and a dict with one "val_acc" entry per epoch
    """
    train_loader = get_loader(rank, batch_size=batch_size, train=True, distributed=True)
    val_loader = get_loader(rank, train=False, distributed=True)

    logs = defaultdict(list)
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        # DistributedSampler derives its shuffle from the epoch number; without
        # this call every epoch iterates the shard in exactly the same order.
        train_loader.sampler.set_epoch(epoch)
        for X, y in train_loader:
            X, y = X.to(rank), y.to(rank)
            out = model(X)
            loss = nn.functional.cross_entropy(out, y)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

        val_acc = evaluate(rank, model, val_loader)
        logs["val_acc"].append(val_acc)
    return model, logs
        

def evaluate(rank, model, loader):
    """
    Accuracy of `model` over `loader`, with the counts reduced onto rank 0.
    Switches the model to eval mode for the pass and back to train mode afterwards.
    Only rank 0 receives the reduced counts, so only its return value covers all processes.
    """
    model.eval()
    hits = torch.tensor(0.).to(rank)
    seen = torch.tensor(0.).to(rank)
    with torch.no_grad():
        for X, y in loader:
            X = X.to(rank)
            y = y.to(rank)
            predicted = model(X).argmax(dim=-1)
            hits += torch.eq(predicted, y).float().sum()
            seen += len(X)

    # Accumulate the per-process counts onto rank 0.
    for counter in (hits, seen):
        dist.reduce(counter, dst=0)

    ratio = hits / seen
    model.train()
    return ratio.detach().item()
        

# Entry point: train for 5 epochs under DDP and print the logs returned by rank 0.
if __name__ == '__main__':
    logs = run_ddp(my_trainer, epochs=5)
    print(logs)

defaultdict(<class 'list'>, {'val_acc': [0.5462999939918518, 0.6912999749183655, 0.7278000116348267, 0.7554000020027161, 0.7720999717712402]})

