@@ -10,12 +10,13 @@
     parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
-    SequenceParallel,
 )
 
 import torch.nn as nn
+from torch.distributed._tensor import Replicate, Shard
 from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
+from distributed.utils import logger
 
 
 def apply_tp(
@@ -43,53 +44,55 @@ def apply_tp(
 
     tp_mesh = world_mesh["tp"]
 
-    # 1. Parallelize the first embedding and the last linear proj layer
-    # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
-    model = parallelize_module(
-        model,
-        tp_mesh,
-        {
-            "tok_embeddings": RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
-            "output": ColwiseParallel(
-                input_layouts=Shard(1),
-                output_layouts=Replicate(),
-                use_local_output=True,
-            ),
-            "norm": SequenceParallel(),
-        },
-    )
-
-    # Apply tensor + sequence parallelism to every transformer block
-    for layer_id, transformer_block in model.layers.items():
+    # TODO: Figure out the TP plan for tok_embeddings and the final linear proj layer.
+    # # 1. Parallelize the first embedding and the last linear proj layer
+    # # 2. Shard the first transformer block's inputs
+    # model = parallelize_module(
+    #     model,
+    #     tp_mesh,
+    #     {
+    #         "tok_embeddings": RowwiseParallel(
+    #             input_layouts=Replicate(),
+    #             output_layouts=Replicate(),
+    #         ),
+    #         "output": ColwiseParallel(
+    #             input_layouts=Shard(1),
+    #             output_layouts=Replicate(),
+    #             use_local_output=True,
+    #         ),
+    #     },
+    # )
+
+    # Apply tensor parallelism to every transformer block
+    for transformer_block in model.layers:
         layer_plan = {
-            "attention": prepare_module_input(
-                input_layouts=(Shard(1), None),
+            "attention": PrepareModuleInput(
+                input_layouts=(Replicate(), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
             "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
+            "attention.wo": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Replicate(),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w2": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
             "feed_forward.w3": ColwiseParallel(),
-            "ffn_norm": SequenceParallel(),
         }
 
         # Adjust attention module to use the local number of heads
         attn_layer = transformer_block.attention
         attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
         attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
 
         parallelize_module(
             module=transformer_block,
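
For reference, a minimal sketch of the Colwise/Rowwise pairing that the per-block plan above relies on, applied to a toy feed-forward module outside this diff; the ToyFeedForward class, the 2-GPU mesh size, and the torchrun launch are assumptions rather than code from this repository.

# toy_tp_sketch.py (hypothetical file) -- run with: torchrun --nproc-per-node=2 toy_tp_sketch.py
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFeedForward(nn.Module):
    """Stand-in for feed_forward with the same child names as the plan (w1, w2, w3)."""

    def __init__(self, dim: int = 256, hidden: int = 512):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)  # column-sharded by the plan
        self.w3 = nn.Linear(dim, hidden, bias=False)  # column-sharded by the plan
        self.w2 = nn.Linear(hidden, dim, bias=False)  # row-sharded, output replicated

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


if __name__ == "__main__":
    # torchrun sets LOCAL_RANK; pin each rank to its own GPU before building the mesh.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # One 1-D tensor-parallel mesh over 2 GPUs (assumed size).
    tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
    block = ToyFeedForward().to("cuda")

    # Same style of plan as the diff: Colwise for w1/w3, Rowwise with a replicated
    # local output for w2, so the block returns an ordinary torch.Tensor.
    parallelize_module(
        block,
        tp_mesh,
        {
            "w1": ColwiseParallel(),
            "w3": ColwiseParallel(),
            "w2": RowwiseParallel(output_layouts=Replicate(), use_local_output=True),
        },
    )

    out = block(torch.randn(4, 16, 256, device="cuda"))
    print(out.shape)  # torch.Size([4, 16, 256]), identical on every rank
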
@@ -125,6 +128,6 @@ def parallelize_llama(
     """
 
     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims)
+        model = apply_tp(model, world_mesh)
 
     return model
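
As a usage note, apply_tp expects world_mesh to contain a mesh dimension named "tp". A sketch of how a caller might build and slice such a mesh follows; the 2 x 4 dp/tp layout is an assumed example, and the repository's Transformer model is not constructed here.

from torch.distributed.device_mesh import init_device_mesh

# 8 ranks arranged as 2-way data parallel x 4-way tensor parallel (assumed sizes).
world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = world_mesh["tp"]   # the slice apply_tp() takes internally
assert tp_mesh.size() == 4   # every rank sees a 4-way TP group

# model = apply_tp(model, world_mesh)  # model: the repository's Transformer instance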