[Tensor Parallel] update examples to simplify embedding + first transformer block

ghstack-source-id: 4396c58e93f7bfd53ae261f6e305b86a8da29d5e
Pull Request resolved: #1259
tianyu-l committed May 16, 2024
1 parent 851c4cf commit cd29c12
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions distributed/tensor_parallelism/fsdp_tp_example.py
@@ -107,22 +107,19 @@
     {
         "tok_embeddings": RowwiseParallel(
             input_layouts=Replicate(),
+            output_layouts=Shard(1),
         ),
+        "norm": SequenceParallel(),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
-        "norm": SequenceParallel(),
-        "layers.0": PrepareModuleInput(
-            input_layouts=(Replicate(), None),
-            desired_input_layouts=(Shard(1), None),
-            use_local_output=True,
-        ),
     }
 )
 
 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1), None),
             desired_input_layouts=(Replicate(), None),
@@ -131,15 +128,14 @@
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
 
     # Adjust attention module to use the local number of heads
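
For reference, a minimal sketch of how the module-level plan reads after this change, assuming model and tp_mesh are defined earlier in fsdp_tp_example.py as in the original example. With output_layouts=Shard(1) on tok_embeddings, the embedding already hands a sequence-sharded activation to the first transformer block, so the separate "layers.0": PrepareModuleInput entry is no longer needed:

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# Module-level plan after this commit: the embedding itself emits a
# Shard(1) (sequence-sharded) activation, so the first transformer block
# receives the layout it expects without an extra PrepareModuleInput.
model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate(),
        ),
    },
)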

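And a sketch of the per-block loop with the reordered layer_tp_plan, so each SequenceParallel norm sits next to the sublayer whose input it produces. The head-count adjustment mirrors the comment kept in the diff context; attn_layer.n_heads and attn_layer.n_kv_heads are assumed attribute names based on the Llama-style model used in this example, not something this commit touches:

from torch.distributed.tensor.parallel import PrepareModuleInput

for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {
        "attention_norm": SequenceParallel(),
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        "ffn_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
        "feed_forward.w3": ColwiseParallel(),
    }

    # Adjust the attention module to use the local (per-rank) number of heads,
    # since wq/wk/wv are sharded column-wise across the TP mesh.
    attn_layer = transformer_block.attention                          # assumed module structure
    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()         # assumed attribute name
    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()   # assumed attribute name

    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )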