Update TP examples to align with tutorials #1243

Merged
Merged 1 commit on Apr 11, 2024
38 changes: 38 additions & 0 deletions .github/workflows/main_distributed.yaml
@@ -0,0 +1,38 @@
name: Run Distributed Examples

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  schedule:
    # Every day at 3:00am
    - cron: '0 3 * * *'


jobs:
  test:

    runs-on: 4-core-ubuntu-gpu-t4

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.8
      uses: actions/setup-python@v2
      with:
        python-version: 3.8
    - name: Install PyTorch
      run: |
        python -m pip install --upgrade pip
        pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu118/torch_nightly.html
    - name: Run Tests
      run: |
        ./run_distributed_examples.sh "run_all,clean"
    - name: Open issue on failure
      if: ${{ failure() && github.event_name == 'schedule' }}
      uses: rishabhgupta/git-action-issue@v2
      with:
        token: ${{ secrets.GITHUB_TOKEN }}
        title: Daily CI failed
        body: Commit ${{ github.sha }} daily scheduled [CI run](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}) failed, please check why
        assignees: ''
16 changes: 8 additions & 8 deletions distributed/tensor_parallelism/README.md
@@ -1,14 +1,14 @@
-# PyTorch Tensor Parallelism for distributed training
+# PyTorch native Tensor Parallel for distributed training

-This example demonstrates SPMD Megatron-LM style tensor parallel by using
-PyTorch native Tensor Parallelism APIs, which include:
+This example demonstrates SPMD Megatron-LM style Tensor Parallel by using
+PyTorch native Tensor Parallel APIs, which include:

-1. High-level APIs for module-level parallelism with a dummy MLP model.
-2. Model agnostic ops for `DistributedTensor`, such as `Linear` and `RELU`.
-3. A E2E demo of tensor parallel for a given toy model (Forward/backward + optimization).
+1. Simple module-level Tensor Parallelism on a dummy MLP model.
+2. Simple module-level Tensor Parallelism with Sequence Parallel inputs/outputs on a dummy MLP model.
+3. An E2E demo of Fully Sharded Data Parallel + Tensor Parallel (with Sequence Parallel) on an example Llama2 model.

-More details about the design can be found:
-https://github.com/pytorch/pytorch/issues/89884
+For more details about the PyTorch native Tensor Parallel APIs, please see the PyTorch docs:
+https://pytorch.org/docs/stable/distributed.tensor.parallel.html

```
pip install -r requirements.txt
```
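For readers who want to see what item 1 in the new README list looks like before opening the example files: the sketch below is illustrative only (the class and dimension names are made up, not the repo's), but it uses the same `parallelize_module` / `ColwiseParallel` / `RowwiseParallel` APIs that the examples exercise.

```
# Hedged sketch of module-level Tensor Parallel on a toy MLP (names are illustrative).
# Launch with torchrun so WORLD_SIZE/RANK are set, e.g. torchrun --nproc-per-node=4 sketch.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

class ToyMLP(nn.Module):
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.out_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))

tp_mesh = init_device_mesh("cuda", (4,))          # one 1-D mesh over 4 GPUs
model = parallelize_module(
    ToyMLP().to("cuda"),
    tp_mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)
out = model(torch.rand(8, 1024, device="cuda"))   # same call signature as the unparallelized model
```

The repo's actual MLP demo additionally includes a gate projection (SwiGLU), as the removed code in the diff below shows.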
142 changes: 80 additions & 62 deletions distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,20 +1,12 @@
import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 4

@@ -23,13 +15,24 @@
sys.exit()
# ---------------------------
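`verify_min_gpu_count` is imported from the repo's `log_utils` helper, whose body is not part of this diff. A plausible, hedged reading of what the guard checks (an assumption, not the actual helper):

```
# Assumption: a minimal stand-in for log_utils.verify_min_gpu_count, shown only to
# explain the GPU-check guard above; the real helper may differ.
import torch

def verify_min_gpu_count(min_gpus: int = 4) -> bool:
    """Return True when CUDA is available and at least `min_gpus` devices are visible."""
    return torch.cuda.is_available() and torch.cuda.device_count() >= min_gpus
```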

from torch.distributed._tensor.device_mesh import init_device_mesh
from llama2_model import Transformer, ModelArgs

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    PrepareModuleInput,
    SequenceParallel
)


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an E2E working flow from forward, backward
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on an example
Llama2 model. We show an E2E working flow from forward, backward
and optimization.

We enabled Fully Sharded Data Parallel + Tensor Parallel in
@@ -53,41 +56,10 @@
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================

More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
More details can be seen in the PyTorch tutorials:
https://pytorch.org/tutorials/intermediate/TP_tutorial.html
"""


def find_multiple(n: int, k: int) -> int:
    """function to find resizing multiple for SwiGLU MLP"""
    if n % k == 0:
        return n
    return n + k - (n % k)


class MLP_swiglu(nn.Module):
    """SwiGLU to showcase a Llama style MLP model"""

    def __init__(self, mlp_dim: int = 1024) -> None:
        super().__init__()
        hidden_dim = 4 * mlp_dim
        scaled_hidden = int(2 * hidden_dim / 3)
        rounded_hidden = find_multiple(scaled_hidden, 256)

        self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.in_proj(x)) * self.gate_proj(x)
        x = self.out_proj(x)
        return x


"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
tp_size = 2
logger = get_logger()

@@ -120,26 +92,72 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# to mimic the behavior of the dataloader.
dp_rank = dp_mesh.get_local_rank()

# create model and move it to GPU with id rank
_mlp_dim = 1024
base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")


# Custom parallelization plan for the swiglu MLP model
custom_tp_model = parallelize_module(
    module=base_model_swiglu,
    device_mesh=tp_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "gate_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000)

model = Transformer.from_model_args(simple_llama2_config).to("cuda")

# init model weights
model.init_weights()

# parallelize the first embedding and the last linear out projection
model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
        ),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate()
        ),
        "norm": SequenceParallel(),
        "layers.0": PrepareModuleInput(
            input_layouts=(Replicate(), None),
            desired_input_layouts=(Shard(1), None),
            use_local_output=True,
        ),
    }
)
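After `parallelize_module` returns, the parameters touched by the plan are `DTensor`s whose `placements` record whether each one is sharded or replicated over `tp_mesh`. The helper below is an illustrative addition (not part of the PR) for spot-checking that the plan was applied:

```
# Illustrative inspection helper (assumption, not in the example): list TP placements.
from torch.distributed._tensor import DTensor

def show_tp_placements(module, limit: int = 8) -> None:
    shown = 0
    for name, param in module.named_parameters():
        if isinstance(param, DTensor):
            # placements is a tuple such as (Shard(0),) or (Replicate(),), depending on the plan entry
            print(f"{name}: placements={param.placements}, local shape={tuple(param.to_local().shape)}")
            shown += 1
            if shown >= limit:
                return
```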

rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n")
for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        "attention_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
        "feed_forward.w3": ColwiseParallel(),
        "ffn_norm": SequenceParallel(),
    }

    # Adjust attention module to use the local number of heads
    attn_layer = transformer_block.attention
    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

    # Custom parallelization plan for the model
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan
    )
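To make the head adjustment above concrete: with the toy config used in this example (`dim=256`, `n_heads=16`) and `tp_size=2`, each TP rank keeps half the heads, and the column-wise sharded `wq`/`wk`/`wv` projections keep exactly the output features those local heads need. The arithmetic below is illustrative only:

```
# Illustrative arithmetic for the head adjustment above (values taken from the toy config).
dim, n_heads, tp_size = 256, 16, 2

head_dim = dim // n_heads                 # 16, unchanged by TP
local_heads = n_heads // tp_size          # 8 heads handled by each TP rank
local_wq_out_features = dim // tp_size    # 128 output features per rank under ColwiseParallel

assert local_heads * head_dim == local_wq_out_features
```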

# Init FSDP using the dp device mesh
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)
sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)

rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n")

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
@@ -156,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
for i in range(num_iterations):
    # seeding with dp_rank to ensure identical inputs for TP groups
    torch.manual_seed(i + dp_rank)
    inp = torch.rand(batch_size, _mlp_dim, device="cuda")
    inp = torch.randint(32000, (8, 256), device="cuda")

    output = sharded_model(inp)
    output.sum().backward()
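The seeding comment above is the key to correctness of the 2D setup: ranks that share a `dp_rank` belong to the same TP group and must consume identical token batches, while different DP ranks should see different data. A small illustrative check of that property (not part of the example):

```
# Illustrative only: torch.manual_seed(i + dp_rank) gives identical batches within a TP group
# and (with overwhelming probability) different batches across DP ranks.
import torch

def make_batch(step: int, dp_rank: int) -> torch.Tensor:
    torch.manual_seed(step + dp_rank)
    return torch.randint(32000, (8, 256))   # token ids, matching the input shape used above

assert torch.equal(make_batch(0, dp_rank=1), make_batch(0, dp_rank=1))
assert not torch.equal(make_batch(0, dp_rank=0), make_batch(0, dp_rank=1))
```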