@@ -1730,10 +1730,8 @@ def forward(
                 i,
                 **kwargs,
             )
-            if self.config.mla_config.use_absorb():
-                out_linear_out = fmha_out
-            else:
-                out_linear_out = self.compute_out_linear(fmha_out, i)
+
+            out_linear_out = self.compute_out_linear(fmha_out, i)

             # print(f"{i}: out_linear_out: {out_linear_out}")

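The effect of this hunk: `compute_attn` now returns the raw attention output (`fmha_out`) on the MLA absorb path as well, so `forward` applies the output projection unconditionally. A minimal runnable sketch of that control flow, using stand-in stubs (the stub names, shapes, and the two-layer loop are illustrative assumptions, not the repository's code):

import paddle

def compute_attn_stub(i):
    # Stand-in for self.compute_attn(...): with this change it returns fmha_out
    # (shape [num_tokens, num_heads * v_head_dim]) on every path.
    return paddle.rand([4, 16])

def compute_out_linear_stub(fmha_out, i):
    # Stand-in for self.compute_out_linear(fmha_out, i).
    weight = paddle.rand([16, 32])
    return paddle.matmul(fmha_out, weight)

for i in range(2):  # stand-in for the per-layer loop
    fmha_out = compute_attn_stub(i)
    # No branch on self.config.mla_config.use_absorb() anymore:
    out_linear_out = compute_out_linear_stub(fmha_out, i)
    print(i, out_linear_out.shape)  # [4, 32]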
@@ -3102,7 +3100,9 @@ def compute_mla_absorb(
         ln_out = qkv_out
         latent_cache = caches[i]

-        out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)
+        fmha_out = paddle.zeros(
+            shape=[ln_out.shape[0], self.num_heads * self.config.mla_config.v_head_dim], dtype=ln_out.dtype
+        )

         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
@@ -3159,10 +3159,7 @@ def compute_mla_absorb(

             fmha_out_prefill = fmha_out_prefill * self.mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            out_linear_out_prefill = self.compute_out_linear(fmha_out_prefill, i)
-            out_linear_out = out_linear_out + out_linear_out_prefill
-
-            # print(f"prefill {i}: out_linear_out: {out_linear_out}")
+            fmha_out = fmha_out + fmha_out_prefill

         if kwargs["max_dec_len_this_time"]:  # decode phase
             if self.config.mla_config.q_lora_rank is not None:
@@ -3190,8 +3187,10 @@ def compute_mla_absorb(
                     epsilon=self._epsilon,
                     begin_norm_axis=1,
                 )[0]
-
-            query = paddle.matmul(ln_out_or_q_c, self.q_b_proj_weights[i])
+            if self.config.mla_config.q_lora_rank is not None:
+                query = paddle.matmul(ln_out_or_q_c, self.q_b_proj_weights[i])
+            else:
+                query = paddle.matmul(ln_out_or_q_c, self.q_proj_weights[i])
             query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             query_nope, query_pe = query.split(
                 [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
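For context, the added branch mirrors the MLA query projection with and without a low-rank query: when `q_lora_rank` is set, the query goes through the low-rank path (down-projection, normalization, then `q_b_proj`); without it, the hidden states are projected directly by `q_proj`. A toy sketch with made-up shapes (all tensor sizes, weight names, and the `use_q_lora` flag are illustrative assumptions):

import paddle

num_tokens, hidden, q_lora_rank, num_heads, qk_head_dim = 4, 64, 16, 2, 24
hidden_states = paddle.rand([num_tokens, hidden])

q_a_proj = paddle.rand([hidden, q_lora_rank])                    # low-rank down-projection
q_b_proj = paddle.rand([q_lora_rank, num_heads * qk_head_dim])   # low-rank up-projection
q_proj = paddle.rand([hidden, num_heads * qk_head_dim])          # direct projection

use_q_lora = True  # stands in for self.config.mla_config.q_lora_rank is not None
if use_q_lora:
    q_c = paddle.matmul(hidden_states, q_a_proj)
    # the real decode path normalizes q_c (q_a_layernorm) before the up-projection
    query = paddle.matmul(q_c, q_b_proj)
else:
    query = paddle.matmul(hidden_states, q_proj)

query = query.reshape([-1, num_heads, qk_head_dim])
print(query.shape)  # [4, 2, 24]

The same branch is repeated below for the weight-only and FP8 variants, using their respective GEMM helpers.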
@@ -3282,12 +3281,9 @@ def compute_mla_absorb(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_heads * self.config.mla_config.v_head_dim])
             )
-            out_linear_out_decode = paddle.matmul(fmha_out_decode, self.linear_weights[i])
-            out_linear_out = out_linear_out + out_linear_out_decode
+            fmha_out = fmha_out + fmha_out_decode

-        # print(f"decode {i}: out_linear_out: {out_linear_out}")
-
-        return out_linear_out
+        return fmha_out

     def compute_attn(
         self,
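Taken together, the hunks above change `compute_mla_absorb` to accumulate and return the pre-projection attention output: a zero tensor of shape `[num_tokens, num_heads * v_head_dim]` collects the prefill and decode contributions, and the output linear is applied once by the caller. A toy sketch of that accumulation (all sizes and the final weight shape are illustrative assumptions):

import paddle

num_tokens, num_heads, v_head_dim, embed_dim = 4, 2, 8, 32

# Before: out_linear_out was zero-initialized with width embed_dim and each branch
# applied its own output projection. After: fmha_out is zero-initialized with width
# num_heads * v_head_dim and only the attention outputs are summed here.
fmha_out = paddle.zeros([num_tokens, num_heads * v_head_dim], dtype="float32")

fmha_out_prefill = paddle.rand([num_tokens, num_heads * v_head_dim])  # stand-in for the prefill branch
fmha_out = fmha_out + fmha_out_prefill

fmha_out_decode = paddle.rand([num_tokens, num_heads * v_head_dim])   # stand-in for the decode branch
fmha_out = fmha_out + fmha_out_decode

# The caller then applies the single output projection (compute_out_linear).
linear_weight = paddle.rand([num_heads * v_head_dim, embed_dim])      # assumed shape
out_linear_out = paddle.matmul(fmha_out, linear_weight)
print(out_linear_out.shape)  # [4, 32]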
@@ -3515,7 +3511,9 @@ def compute_mla_absorb(
         ln_out = qkv_out
         latent_cache = caches[i]

-        out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)
+        fmha_out = paddle.zeros(
+            shape=[ln_out.shape[0], self.num_heads * self.config.mla_config.v_head_dim], dtype=ln_out.dtype
+        )

         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
@@ -3572,10 +3570,7 @@ def compute_mla_absorb(

             fmha_out_prefill = fmha_out_prefill * self.mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            out_linear_out_prefill = self.compute_out_linear(fmha_out_prefill, i)
-            out_linear_out = out_linear_out + out_linear_out_prefill
-
-            # print(f"prefill {i}: out_linear_out: {out_linear_out}")
+            fmha_out = fmha_out + fmha_out_prefill

         if kwargs["max_dec_len_this_time"]:  # decode phase
             if self.config.mla_config.q_lora_rank is not None:
@@ -3615,14 +3610,22 @@ def compute_mla_absorb(
                     epsilon=self._epsilon,
                     begin_norm_axis=1,
                 )[0]
-
-            query = weight_only_linear(
-                ln_out_or_q_c,
-                weight=self.q_b_proj_weights[i],
-                weight_scale=self.q_b_proj_weights_scale[i],
-                weight_dtype=self.weight_dtype,
-                group_size=self.weightonly_group_size,
-            )
+            if self.config.mla_config.q_lora_rank is not None:
+                query = weight_only_linear(
+                    ln_out_or_q_c,
+                    weight=self.q_b_proj_weights[i],
+                    weight_scale=self.q_b_proj_weights_scale[i],
+                    weight_dtype=self.weight_dtype,
+                    group_size=self.weightonly_group_size,
+                )
+            else:
+                query = weight_only_linear(
+                    ln_out_or_q_c,
+                    weight=self.q_proj_weights[i],
+                    weight_scale=self.q_proj_weights_scale[i],
+                    weight_dtype=self.weight_dtype,
+                    group_size=self.weightonly_group_size,
+                )
             query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             query_nope, query_pe = query.split(
                 [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
@@ -3713,18 +3716,10 @@ def compute_mla_absorb(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_heads * self.config.mla_config.v_head_dim])
             )
-            out_linear_out_decode = weight_only_linear(
-                fmha_out_decode,
-                weight=self.linear_weights[i],
-                weight_scale=self.linear_weights_scale[i],
-                weight_dtype=self.weight_dtype,
-                group_size=self.weightonly_group_size,
-            )
-            out_linear_out = out_linear_out + out_linear_out_decode

-        # print(f"decode {i}: out_linear_out: {out_linear_out}")
+            fmha_out = fmha_out + fmha_out_decode

-        return out_linear_out
+        return fmha_out


 class FusedBlockMultiTransformerA8W8(FusedBlockMultiTransformer, FusedMultiTransformerA8W8):
@@ -5193,7 +5188,9 @@ def compute_mla_absorb(
         ln_out = qkv_out
         latent_cache = caches[i]

-        out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)
+        fmha_out = paddle.zeros(
+            shape=[ln_out.shape[0], self.num_heads * self.config.mla_config.v_head_dim], dtype=ln_out.dtype
+        )

         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
@@ -5250,8 +5247,7 @@ def compute_mla_absorb(

             fmha_out_prefill = fmha_out_prefill * self.mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            out_linear_out_prefill = self.compute_out_linear(fmha_out_prefill, i)
-            out_linear_out = out_linear_out + out_linear_out_prefill
+            fmha_out = fmha_out + fmha_out_prefill

         if kwargs["max_dec_len_this_time"]:  # decode phase
             if self.config.mla_config.q_lora_rank is not None:
@@ -5297,15 +5293,26 @@ def compute_mla_absorb(
                     epsilon=self._epsilon,
                     begin_norm_axis=1,
                 )[0]
-            query = self.cutlass_fp8_gemm(
-                x=ln_out_or_q_c_fp8,
-                y=self.q_b_proj_weights[i],
-                x_s=ln_out_or_q_c_scale,
-                y_s=self.q_b_proj_weights_scale[i],
-                bias=None,
-                output_dtype=self._dtype,
-                act="identity",
-            )
+            if self.config.mla_config.q_lora_rank is not None:
+                query = self.cutlass_fp8_gemm(
+                    x=ln_out_or_q_c_fp8,
+                    y=self.q_b_proj_weights[i],
+                    x_s=ln_out_or_q_c_scale,
+                    y_s=self.q_b_proj_weights_scale[i],
+                    bias=None,
+                    output_dtype=self._dtype,
+                    act="identity",
+                )
+            else:
+                query = self.cutlass_fp8_gemm(
+                    x=ln_out_or_q_c_fp8,
+                    y=self.q_proj_weights[i],
+                    x_s=ln_out_or_q_c_scale,
+                    y_s=self.q_proj_weights_scale[i],
+                    bias=None,
+                    output_dtype=self._dtype,
+                    act="identity",
+                )
             query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             query_nope, query_pe = query.split(
                 [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
@@ -5396,19 +5403,10 @@ def compute_mla_absorb(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_heads * self.config.mla_config.v_head_dim])
             )
-            fmha_out_decode_fp8, fmha_out_decode_scale = self.dynamic_quant(fmha_out_decode)
-            out_linear_out_decode = self.cutlass_fp8_gemm(
-                x=fmha_out_decode_fp8,
-                y=self.linear_weights[i],
-                x_s=fmha_out_decode_scale,
-                y_s=self.linear_weights_scale[i],
-                bias=None,
-                output_dtype=self._dtype,
-                act="identity",
-            )
-            out_linear_out = out_linear_out + out_linear_out_decode

-        return out_linear_out
+            fmha_out = fmha_out + fmha_out_decode
+
+        return fmha_out

     def compute_ffn1(self, tmp_out, i):
         out = self.cutlass_fp8_gemm(