Add Sequence parallel and 2D parallel examples (#1149)

* Add Sequence parallel and 2D parallel examples

fduwjj committed May 10, 2023
1 parent c9ef23f commit 6a64939

Showing 5 changed files with 255 additions and 59 deletions.
83 changes: 83 additions & 0 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -0,0 +1,83 @@
import argparse

import torch
import torch.multiprocessing as mp

from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from utils import cleanup, setup, ToyModel

# SequenceParallel requires a recent (nightly at the time) PyTorch build.
SP_AVAILABLE = False
try:
    from torch.distributed.tensor.parallel import SequenceParallel
    SP_AVAILABLE = True
except BaseException:
    pass


"""
This is the script to test Sequence Parallel(SP) on a toy model in a
Megetron-LM SPMD style. We show an E2E working flow from forward,
backward and optimization.
We use the example of two `nn.Linear` layers with an element-wise `nn.RELU`
in between to show an example of sequence parallel, which was proposed in paper:
https://arxiv.org/pdf/2205.05198.pdf.
Like tensor parallel, we parallelize the first linear layer by column
and also parallelize the second linear layer by row. But the input in each rank
now is different so that we need one all-gather for input and one reduce-scatter
in the end of the second linear layer.
"""


def demo_sp(rank, args):
"""
Main body of the demo of a basic version of sequence parallel by using
PyTorch native APIs.
"""
print(f"Running SP example on rank {rank}.")
setup(rank, args.world_size)

# create a sharding plan based on the given world_size.
device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size))

# create model and move it to GPU with id rank
model = ToyModel().cuda(rank)
    # Create an optimizer for the parallelized module.
LR = 0.25
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Parallelize the module based on the given Parallel Style.
model = parallelize_module(model, device_mesh, SequenceParallel())

    # Perform a number of iterations of forward/backward passes
    # and optimizer steps for the sharded module.
for _ in range(args.iter_nums):
# For SP, input can be different across all ranks.
inp = torch.rand(20, 10).cuda(rank)
output = model(inp)
output.sum().backward()
optimizer.step()

cleanup()


if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
parser = argparse.ArgumentParser()
    # These are passed in via the command line.
parser.add_argument("--world_size", type=int, default=n_gpus)
parser.add_argument("--iter_nums", type=int, default=10)
args = parser.parse_args()
# The main entry point is called directly without using subprocess
if n_gpus < 2:
print("Requires at least 2 GPUs to run.")
elif not SP_AVAILABLE:
        print(
            "PyTorch doesn't have Sequence Parallelism available;"
            " a nightly build is needed."
        )
else:
mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True)
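The communication pattern described in the docstring above can be made concrete with a small, hand-rolled sketch that uses plain torch.distributed collectives instead of the DTensor-based SequenceParallel style. It is illustrative only and not part of the commit; the names and shapes (sp_mlp_forward, x_shard, w1_col_shard, w2_row_shard) are assumptions, and it covers the forward pass only.

import torch
import torch.distributed as dist


def sp_mlp_forward(x_shard, w1_col_shard, w2_row_shard, world_size):
    # x_shard has shape [seq_len // world_size, hidden]: each rank owns a slice of the sequence.
    # 1) All-gather the sequence shards so every rank sees the full input.
    gathered = [torch.empty_like(x_shard) for _ in range(world_size)]
    dist.all_gather(gathered, x_shard)
    x = torch.cat(gathered, dim=0)            # [seq_len, hidden]

    # 2) Column-parallel first linear plus ReLU: each rank holds a column slice of W1.
    h = torch.relu(x @ w1_col_shard)          # [seq_len, ffn // world_size]

    # 3) Row-parallel second linear: each rank produces a partial sum of the output.
    partial = h @ w2_row_shard                # [seq_len, out]

    # 4) Reduce-scatter sums the partials and hands each rank back its own sequence shard.
    chunks = list(partial.chunk(world_size, dim=0))
    out_shard = torch.empty_like(chunks[0])
    dist.reduce_scatter(out_shard, chunks)
    return out_shard                          # [seq_len // world_size, out]

Compared to plain tensor parallelism, the trailing all-reduce is replaced by an all-gather on the way in and a reduce-scatter on the way out, so activations stay sharded along the sequence dimension outside the two linear layers.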
70 changes: 12 additions & 58 deletions distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,22 +1,11 @@
import argparse
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

TP_AVAILABLE = False
try:
from torch.distributed._tensor import (
DeviceMesh,
)
from torch.distributed.tensor.parallel import (
PairwiseParallel,
parallelize_module,
)
TP_AVAILABLE = True
except BaseException as e:
pass

from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
from utils import cleanup, setup, ToyModel


"""
@@ -51,41 +40,16 @@
"""


def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

def cleanup():
dist.destroy_process_group()


class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 32)
self.relu = nn.ReLU()
self.net2 = nn.Linear(32, 5)

def forward(self, x):
return self.net2(self.relu(self.net1(x)))


def demo_tp(rank, args):
"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
print(f"Running basic Megatron style TP example on rank {rank}.")
setup(rank, args.world_size)

# create a sharding plan based on the given world_size.
device_mesh = DeviceMesh(
"cuda",
torch.arange(args.world_size),
)
device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size))

# create model and move it to GPU with id rank
model = ToyModel().cuda(rank)
@@ -97,7 +61,10 @@ def demo_tp(rank, args):

    # Perform a number of iterations of forward/backward passes
    # and optimizer steps for the sharded module.
for _ in range(args.iter_nums):
for i in range(args.iter_nums):
        # For TP, the input needs to be the same across all TP ranks.
        # Setting the random seed mimics the behavior of a dataloader.
        torch.manual_seed(i)
inp = torch.rand(20, 10).cuda(rank)
output = model(inp)
output.sum().backward()
@@ -106,13 +73,6 @@ def demo_tp(rank, args):
cleanup()


def run_demo(demo_fn, args):
mp.spawn(demo_fn,
args=(args,),
nprocs=args.world_size,
join=True)


if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
parser = argparse.ArgumentParser()
@@ -123,11 +83,5 @@ def run_demo(demo_fn, args):
# The main entry point is called directly without using subprocess
if n_gpus < 2:
print("Requires at least 2 GPUs to run.")
elif not TP_AVAILABLE:
print(
"PyTorch doesn't have Tensor Parallelism available,"
" need nightly build."
)
else:
run_demo(demo_tp, args)

mp.spawn(demo_tp, args=(args,), nprocs=args.world_size, join=True)
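For comparison, the tensor-parallel version of the same two-layer MLP needs no all-gather, because every TP rank already holds the full input (hence the shared torch.manual_seed above); a single all-reduce at the end recombines the row-parallel partial sums. Below is a minimal sketch with hand-rolled collectives and assumed names (tp_mlp_forward, w1_col_shard, w2_row_shard), not the PairwiseParallel implementation itself.

import torch
import torch.distributed as dist


def tp_mlp_forward(x, w1_col_shard, w2_row_shard, tp_group=None):
    # x is identical on every TP rank.
    h = torch.relu(x @ w1_col_shard)          # column-parallel: [batch, ffn // tp_size]
    partial = h @ w2_row_shard                # row-parallel partial sum: [batch, out]
    dist.all_reduce(partial, group=tp_group)  # one all-reduce yields the full output
    return partial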
126 changes: 126 additions & 0 deletions distributed/tensor_parallelism/two_d_parallel_example.py
@@ -0,0 +1,126 @@
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
PairwiseParallel,
parallelize_module,
)
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp

from utils import cleanup, setup, ToyModel
# SequenceParallel requires a recent (nightly at the time) PyTorch build.
SP_AVAILABLE = False
try:
    from torch.distributed.tensor.parallel import SequenceParallel
    SP_AVAILABLE = True
except BaseException:
    pass


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an E2E working flow from forward, backward
and optimization.
We enabled Fully Sharded Data Parallel + Tensor Parallel in
separate parallel dimensions:
Data Parallel across hosts
Tensor Parallel within each host
We use a simple diagram to illustrate below:
======================================================================
------------ ------------ ------------ ------------
| Host 1 | | Host 2 | | | | Host N |
| 8 GPUs | | 8 GPUs | | | | 8 GPUs |
| | | | | ... | | |
| (TP) | | (TP) | | | | (TP) |
|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7|
| | | | | | | .., 8N-1]|
| | | | | | | |
------------ ------------ ------------ ------------
FSDP:
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================
More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""


def demo_2d(rank, args):
"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
print(f"Running basic Megatron style TP example on rank {rank}.")
setup(rank, args.world_size)
assert (
args.world_size % args.tp_size == 0
), "World size needs to be divisible by TP size"

# create a sharding plan based on the given world_size.
device_mesh = DeviceMesh(
"cuda", torch.arange(0, args.world_size).view(-1, args.tp_size)
)

# create model and move it to GPU with id rank
model = ToyModel().cuda(rank)
    # Create an optimizer for the parallelized module.
LR = 0.25
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Parallelize the module based on the given Parallel Style.
parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel()
model = parallelize_module(model, device_mesh, parallel_style, tp_mesh_dim=1)

# We need to register hooks for TP + FSDP integration.
assert (
enable_2d_with_fsdp()
), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0"
model = FSDP(model)

    # Perform a number of iterations of forward/backward passes
    # and optimizer steps for the sharded module.
for i in range(args.iter_nums):
        # For TP, the input needs to be the same across all TP ranks,
        # while for SP, the input can be different across all ranks.
        # Setting the random seed mimics the behavior of a dataloader.
dp_rank = (
rank
if args.run_seq_parallel
else dist.get_rank(device_mesh.get_dim_groups()[0])
)
torch.manual_seed(i + dp_rank)
inp = torch.rand(20, 10).cuda(rank)
output = model(inp)
output.sum().backward()
optimizer.step()

cleanup()


if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
parser = argparse.ArgumentParser()
    # These are passed in via the command line.
parser.add_argument("--world_size", type=int, default=n_gpus)
parser.add_argument("--iter_nums", type=int, default=10)
parser.add_argument("--run_seq_parallel", type=bool, default=False)
parser.add_argument("--tp_size", type=int, default=2)
args = parser.parse_args()
# The main entry point is called directly without using subprocess
if n_gpus < 4:
print("Requires at least 4 GPUs to run.")
elif not SP_AVAILABLE:
        print(
            "PyTorch doesn't have Sequence Parallelism available;"
            " a nightly build is needed."
        )
else:
mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True)
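The torch.arange(0, args.world_size).view(-1, args.tp_size) reshape above is what turns the flat rank list into the 2D layout drawn in the docstring: each row of the mesh is one tensor-parallel group and each column is one FSDP/data-parallel group. A standalone illustration follows (the world_size and tp_size values are picked for readability, not taken from the commit).

import torch

world_size, tp_size = 8, 4
mesh = torch.arange(0, world_size).view(-1, tp_size)
print(mesh)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
#
# Row i:    one tensor-parallel group (GPUs on the same host), e.g. [0, 1, 2, 3]
# Column j: one FSDP / data-parallel group (same local GPU across hosts), e.g. [0, 4]

In the example above, parallelize_module(..., tp_mesh_dim=1) applies TP along dimension 1 of this mesh (within a row), while the FSDP wrapper provides the data-parallel dimension across hosts (down a column), matching the docstring's diagram.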
31 changes: 31 additions & 0 deletions distributed/tensor_parallelism/utils.py
@@ -0,0 +1,31 @@
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)


def cleanup():
dist.destroy_process_group()


class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 32)
self.relu = nn.ReLU()
self.net2 = nn.Linear(32, 5)

def forward(self, x):
return self.net2(self.relu(self.net1(x)))
4 changes: 3 additions & 1 deletion run_python_examples.sh
@@ -61,7 +61,9 @@ function dcgan() {

function distributed() {
start
python tensor_parallelism/example.py || error "tensor parallel example failed"
python tensor_parallelism/tensor_parallel_example.py || error "tensor parallel example failed"
python tensor_parallelism/sequence_parallel_example.py || error "sequence parallel example failed"
python tensor_parallelism/two_d_parallel_example.py || error "2D parallel example failed"
python ddp/main.py || error "ddp example failed"
}
