@@ -191,7 +191,6 @@ def scaled_dot_product_attention(
 
 colwise_placements = [dist.Replicate(), dist.Shard(1)]
 rowise_placement = [dist.Replicate(), dist.Shard(0)]
-replicate_placements = [dist.Replicate(), dist.Replicate()]
 
 
 class LlamaRMSNormAuto(nn.Layer):
@@ -242,28 +241,28 @@ def __init__(self, config, ipp: Optional[int] = None):
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
         else:
             self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.gate_proj.weight = dist.shard_tensor(
                 self.gate_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
             self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.up_proj.weight = dist.shard_tensor(
                 self.up_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
         self.down_proj.weight = dist.shard_tensor(
             self.down_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
+            rowise_placement,
         )
 
     def forward(self, x):
@@ -323,7 +322,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.qkv_proj.weight = dist.shard_tensor(
                 self.qkv_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
         else:
@@ -335,7 +334,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.q_proj.weight = dist.shard_tensor(
                 self.q_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
             self.k_proj = nn.Linear(
@@ -346,7 +345,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.k_proj.weight = dist.shard_tensor(
                 self.k_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
             self.v_proj = nn.Linear(
@@ -357,7 +356,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.v_proj.weight = dist.shard_tensor(
                 self.v_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+                colwise_placements,
             )
 
         self.o_proj = nn.Linear(
@@ -368,7 +367,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.o_proj.weight = dist.shard_tensor(
             self.o_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
+            rowise_placement,
         )
 
         if config.rope:
@@ -1220,7 +1219,7 @@ def __init__(self, config: LlamaConfig):
         self.weight = dist.shard_tensor(
             self.weight,
             get_mesh(-1),
-            colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
+            colwise_placements,
         )
 
     def forward(self, hidden_states, tensor_parallel_output=None):
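For reference, the module-level placements these layers now use unconditionally map onto a 2-D process mesh: colwise_placements = [dist.Replicate(), dist.Shard(1)] replicates a weight across the first mesh axis and splits its second dimension (the out_features columns of a Paddle nn.Linear weight, stored as [in_features, out_features]) across the second axis, while rowise_placement = [dist.Replicate(), dist.Shard(0)] splits the rows instead. A minimal sketch of the same pattern, assuming an illustrative 2x2 ProcessMesh with dim names ["dp", "mp"] and a job launched via paddle.distributed.launch on four devices (the mesh shape, layer sizes, and names below are placeholders, not taken from the diff):

import paddle.distributed as dist
import paddle.nn as nn

# Illustrative 2x2 mesh: axis 0 = "dp" (data parallel), axis 1 = "mp" (tensor parallel).
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# Same placement lists as in the diff: replicate over dp, shard over mp.
colwise_placements = [dist.Replicate(), dist.Shard(1)]  # split weight columns (out_features)
rowise_placement = [dist.Replicate(), dist.Shard(0)]    # split weight rows (in_features)

up = nn.Linear(1024, 4096, bias_attr=False)
down = nn.Linear(4096, 1024, bias_attr=False)

# With an mp degree of 2, each mp rank holds a [1024, 2048] column slice of up.weight
# and a [2048, 1024] row slice of down.weight; dp ranks hold full replicas.
up.weight = dist.shard_tensor(up.weight, mesh, colwise_placements)
down.weight = dist.shard_tensor(down.weight, mesh, rowise_placement)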