
Commit e0921f0

[Auto-Parallel] Optimize the llama-7b benchmark with a temporary solution to avoid unnecessary communication (#10671)
1 parent 9ba4e7e commit e0921f0
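
For context, the placement lists touched by this commit describe how a weight is laid out over a 2-D auto-parallel process mesh: the first entry applies to the first mesh axis and the second to the tensor-parallel axis. The sketch below is illustrative only; the 2x2 mesh, its ["dp", "mp"] dim names, and the tensor shape are assumptions rather than values from the commit, and the shard_tensor call is meant to run under a multi-rank launch.

# Illustrative sketch only: mesh shape, dim names, and tensor shape are
# assumptions, not values from modeling_auto.py. Run under a multi-rank
# launch, e.g. python -m paddle.distributed.launch --gpus=0,1,2,3 script.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# Replicate over the first mesh axis, split dim 1 over the tensor-parallel
# axis (a column-wise split of the weight).
colwise_placements = [dist.Replicate(), dist.Shard(1)]
# Replicate over the first mesh axis, split dim 0 over the tensor-parallel
# axis (a row-wise split of the weight).
rowise_placement = [dist.Replicate(), dist.Shard(0)]
# Fully replicated on both axes: every rank holds the whole weight, so no
# tensor-parallel collectives are needed to use it.
replicate_placements = [dist.Replicate(), dist.Replicate()]

weight = paddle.randn([1024, 4096])
# Mark the weight as column-sharded over the "mp" axis of the mesh.
dist_weight = dist.shard_tensor(weight, mesh, colwise_placements)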

File tree

1 file changed: +11 −10 lines changed

paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 11 additions & 10 deletions
@@ -191,6 +191,7 @@ def scaled_dot_product_attention(
 
 colwise_placements = [dist.Replicate(), dist.Shard(1)]
 rowise_placement = [dist.Replicate(), dist.Shard(0)]
+replicate_placements = [dist.Replicate(), dist.Replicate()]
 
 
 class LlamaRMSNormAuto(nn.Layer):
@@ -241,28 +242,28 @@ def __init__(self, config, ipp: Optional[int] = None):
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
         else:
             self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.gate_proj.weight = dist.shard_tensor(
                 self.gate_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
             self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.up_proj.weight = dist.shard_tensor(
                 self.up_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
         self.down_proj.weight = dist.shard_tensor(
             self.down_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement,
+            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
         )
 
     def forward(self, x):
@@ -322,7 +323,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.qkv_proj.weight = dist.shard_tensor(
                 self.qkv_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
         else:
@@ -334,7 +335,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.q_proj.weight = dist.shard_tensor(
                 self.q_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
             self.k_proj = nn.Linear(
@@ -345,7 +346,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.k_proj.weight = dist.shard_tensor(
                 self.k_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
            )
 
             self.v_proj = nn.Linear(
@@ -356,7 +357,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.v_proj.weight = dist.shard_tensor(
                 self.v_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
         self.o_proj = nn.Linear(
@@ -367,7 +368,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.o_proj.weight = dist.shard_tensor(
             self.o_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement,
+            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
         )
 
         if config.rope:
@@ -1219,7 +1220,7 @@ def __init__(self, config: LlamaConfig):
         self.weight = dist.shard_tensor(
             self.weight,
             get_mesh(-1),
-            colwise_placements,
+            colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
         )
 
     def forward(self, hidden_states, tensor_parallel_output=None):
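
Every hunk applies the same pattern: when config.tensor_parallel_degree is 1 there is nothing to split column- or row-wise, so the weights are marked as fully replicated instead of sharded, which, per the commit message, is where the unnecessary communication came from. A minimal sketch of that selection follows, reusing the placement lists from the diff; pick_placements is a hypothetical helper for illustration, not a function in modeling_auto.py.

# Minimal sketch of the selection pattern this commit applies at every
# dist.shard_tensor call site; `pick_placements` is a hypothetical helper.
import paddle.distributed as dist

colwise_placements = [dist.Replicate(), dist.Shard(1)]
rowise_placement = [dist.Replicate(), dist.Shard(0)]
replicate_placements = [dist.Replicate(), dist.Replicate()]


def pick_placements(sharded_placements, tensor_parallel_degree):
    """Use the sharded layout only when tensor parallelism is active;
    otherwise fall back to full replication so no extra communication
    is introduced for a single-rank tensor-parallel group."""
    if tensor_parallel_degree > 1:
        return sharded_placements
    return replicate_placements


# Mirrors the diff: column-wise projections (q/k/v, gate/up, lm_head weight)
# vs. the row-wise projections (down_proj, o_proj), with no tensor parallelism.
qkv_placements = pick_placements(colwise_placements, tensor_parallel_degree=1)
down_placements = pick_placements(rowise_placement, tensor_parallel_degree=1)
assert qkv_placements is replicate_placements
assert down_placements is replicate_placements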
