
Commit b550133

Revert "[Auto-Parallel] optimize llama-7b benchmark in temporary solution to …" (#10682)
This reverts commit e0921f0.
1 parent: c1322be · commit: b550133

File tree: 1 file changed (+10, -11 lines)


paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 10 additions & 11 deletions
@@ -191,7 +191,6 @@ def scaled_dot_product_attention(
 
 colwise_placements = [dist.Replicate(), dist.Shard(1)]
 rowise_placement = [dist.Replicate(), dist.Shard(0)]
-replicate_placements = [dist.Replicate(), dist.Replicate()]
 
 
 class LlamaRMSNormAuto(nn.Layer):
@@ -242,28 +241,28 @@ def __init__(self, config, ipp: Optional[int] = None):
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
         else:
             self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.gate_proj.weight = dist.shard_tensor(
                 self.gate_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
             self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.up_proj.weight = dist.shard_tensor(
                 self.up_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
         self.down_proj.weight = dist.shard_tensor(
             self.down_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
+            rowise_placement,
         )
 
     def forward(self, x):
@@ -323,7 +322,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.qkv_proj.weight = dist.shard_tensor(
                 self.qkv_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
         else:
@@ -335,7 +334,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.q_proj.weight = dist.shard_tensor(
                 self.q_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
             self.k_proj = nn.Linear(
@@ -346,7 +345,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.k_proj.weight = dist.shard_tensor(
                 self.k_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
             self.v_proj = nn.Linear(
@@ -357,7 +356,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.v_proj.weight = dist.shard_tensor(
                 self.v_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
         self.o_proj = nn.Linear(
@@ -368,7 +367,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.o_proj.weight = dist.shard_tensor(
             self.o_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
+            rowise_placement,
         )
 
         if config.rope:
@@ -1220,7 +1219,7 @@ def __init__(self, config: LlamaConfig):
         self.weight = dist.shard_tensor(
             self.weight,
             get_mesh(-1),
-            colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+            colwise_placements,
         )
 
     def forward(self, hidden_states, tensor_parallel_output=None):
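
For orientation, the annotations this revert restores follow the usual tensor-parallel convention in modeling_auto.py: colwise_placements shards a Linear weight along its second axis (out_features, given Paddle's [in_features, out_features] weight layout) and rowise_placement shards along its first axis, whereas the reverted commit e0921f0 had substituted fully replicated placements whenever tensor_parallel_degree was 1 or less. The snippet below is a minimal sketch of that annotation pattern using paddle.distributed's auto-parallel API as it appears in the diff; the 2x2 process mesh, its dim names, and the layer sizes are illustrative assumptions rather than values from this commit, and the script is meant to be launched across four ranks (for example via python -m paddle.distributed.launch).

# Minimal sketch of the placement convention restored by this revert.
# Hypothetical 2x2 mesh and layer sizes; launch on 4 ranks, e.g.
#   python -m paddle.distributed.launch --devices 0,1,2,3 sketch.py
import paddle.distributed as dist
from paddle import nn

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# Same module-level constants as in modeling_auto.py: replicate over the
# first mesh axis, shard over the second (model-parallel) axis.
colwise_placements = [dist.Replicate(), dist.Shard(1)]  # split out_features
rowise_placement = [dist.Replicate(), dist.Shard(0)]    # split in_features

up_proj = nn.Linear(1024, 4096, bias_attr=False)
down_proj = nn.Linear(4096, 1024, bias_attr=False)

# After the revert the annotation is applied unconditionally; the
# "... if tensor_parallel_degree > 1 else replicate_placements" fallback
# introduced by e0921f0 is gone.
up_proj.weight = dist.shard_tensor(up_proj.weight, mesh, colwise_placements)
down_proj.weight = dist.shard_tensor(down_proj.weight, mesh, rowise_placement)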

0 commit comments
