@@ -191,6 +191,7 @@ def scaled_dot_product_attention(
 
 colwise_placements = [dist.Replicate(), dist.Shard(1)]
 rowise_placement = [dist.Replicate(), dist.Shard(0)]
+replicate_placements = [dist.Replicate(), dist.Replicate()]
 
 
 class LlamaRMSNormAuto(nn.Layer):
@@ -241,28 +242,28 @@ def __init__(self, config, ipp: Optional[int] = None):
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
         else:
             self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.gate_proj.weight = dist.shard_tensor(
                 self.gate_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
             self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.up_proj.weight = dist.shard_tensor(
                 self.up_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
         self.down_proj.weight = dist.shard_tensor(
             self.down_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement,
+            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
         )
 
     def forward(self, x):
@@ -322,7 +323,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.qkv_proj.weight = dist.shard_tensor(
                 self.qkv_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
         else:
@@ -334,7 +335,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.q_proj.weight = dist.shard_tensor(
                 self.q_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
             self.k_proj = nn.Linear(
@@ -345,7 +346,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.k_proj.weight = dist.shard_tensor(
                 self.k_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
             self.v_proj = nn.Linear(
@@ -356,7 +357,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.v_proj.weight = dist.shard_tensor(
                 self.v_proj.weight,
                 get_mesh(self.ipp),
-                colwise_placements,
+                colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
             )
 
         self.o_proj = nn.Linear(
@@ -367,7 +368,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.o_proj.weight = dist.shard_tensor(
             self.o_proj.weight,
             get_mesh(self.ipp),
-            rowise_placement,
+            rowise_placement if self.config.tensor_parallel_degree > 1 else replicate_placements,
         )
 
         if config.rope:
@@ -1219,7 +1220,7 @@ def __init__(self, config: LlamaConfig):
         self.weight = dist.shard_tensor(
             self.weight,
             get_mesh(-1),
-            colwise_placements,
+            colwise_placements if self.config.tensor_parallel_degree > 1 else replicate_placements,
        )
 
     def forward(self, hidden_states, tensor_parallel_output=None):
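
Every hunk above applies the same guard: keep the column-wise or row-wise tensor-parallel placements only when tensor_parallel_degree > 1, and fall back to the newly added replicate_placements otherwise. A minimal sketch of that pattern follows, assuming Paddle's auto-parallel API (paddle.distributed.Replicate / Shard / shard_tensor); the select_placements helper is hypothetical and not part of this diff, and get_mesh / colwise_placements refer to the existing helpers in modeling_auto.py.

import paddle.distributed as dist

# Hypothetical helper illustrating the guard repeated in the diff: shard along
# the tensor-parallel mesh axis only when tensor parallelism is enabled,
# otherwise keep the weight fully replicated on the mesh.
def select_placements(tensor_parallel_degree, tp_placements):
    replicate_placements = [dist.Replicate(), dist.Replicate()]
    return tp_placements if tensor_parallel_degree > 1 else replicate_placements

# Usage mirroring the diff (get_mesh and colwise_placements come from modeling_auto.py):
# self.q_proj.weight = dist.shard_tensor(
#     self.q_proj.weight,
#     get_mesh(self.ipp),
#     select_placements(self.config.tensor_parallel_degree, colwise_placements),
# )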