
CSYE7105 Parallel Machine Learning & AI

Intructor: Dr. Handan Liu

Lecture: Data Parallelism and Model Parallelism


In [1]:
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

In [2]:
def setup(rank, size, run, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    run(rank, size)
    
def cleanup():
    dist.destroy_process_group()

In [3]:
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.\n")

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

In [4]:
if __name__ == "__main__":
    size = 2
    processes = []
    #mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=setup, args=(rank, size, demo_basic))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Running basic DDP example on rank 0.

Running basic DDP example on rank 1.



In [5]:
def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.\n")

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])


    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        """All processes should see same parameters as they all start from same
           random parameters and gradients are synchronized in backward passes.
           Therefore, saving it in one process is sufficient."""
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process 0 saves it.
    dist.barrier()
    
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)

    loss_fn(outputs, labels).backward()
    optimizer.step()

    """Not necessary to use a dist.barrier() to guard the file deletion below
    as the AllReduce ops in the backward pass of DDP already served as a synchronization."""

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()

In [6]:
if __name__ == "__main__":
    size = 2
    processes = []
    #mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=setup, args=(rank, size, demo_checkpoint))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Running DDP checkpoint example on rank 0.

Running DDP checkpoint example on rank 1.



In [7]:
class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)

In [8]:
def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.\n")
    
    # setup mp_model and devices for this process
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

In [9]:
n_gpus = torch.cuda.device_count()
n_gpus

4

In [10]:
if __name__ == "__main__":
    size = 4
    processes = []
    #mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=setup, args=(rank, size, demo_model_parallel))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Running DDP with model parallel example on rank 3.
Running DDP with model parallel example on rank 1.
Running DDP with model parallel example on rank 2.
Running DDP with model parallel example on rank 0.





Process Process-6:
Process Process-8:
Process Process-7:





Process Process-5:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/flyingsky2007/.conda/envs/py2022/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/flyingsky2007/.conda/envs/py2022/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/flyingsky2007/.conda/envs/py2022/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/flyingsky2007/.conda/envs/py2022/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/flyingsky2007/.conda/envs/py2022/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/flyingsky2007/.conda/envs/py2022/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/flyin