Commit 712495c

[MLA] move compute_out_linear out and fix bug when q_lora_rank is None (#10275)
1 parent 2d40d82 commit 712495c
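
The change: compute_mla_absorb no longer applies the attention output projection itself. Each variant (bf16, weight-only, FP8) now accumulates the attention output fmha_out over the prefill and decode phases and returns it, and forward calls compute_out_linear once for both the absorb and non-absorb paths. The decode-phase query projection also falls back to the full-rank q_proj weights when q_lora_rank is None instead of always indexing q_b_proj_weights.

A minimal, self-contained sketch of the resulting control flow (illustrative only; the helper names, argument lists, and shapes below are simplified stand-ins, not the actual class methods):

import paddle


def mla_absorb_attention(ln_out, num_heads, v_head_dim, prefill_out=None, decode_out=None):
    # Accumulate the raw attention output; the output projection is no longer applied here.
    fmha_out = paddle.zeros(shape=[ln_out.shape[0], num_heads * v_head_dim], dtype=ln_out.dtype)
    if prefill_out is not None:  # prefill-phase contribution
        fmha_out = fmha_out + prefill_out
    if decode_out is not None:  # decode-phase contribution
        fmha_out = fmha_out + decode_out
    return fmha_out


def attention_block(ln_out, out_proj_weight, num_heads, v_head_dim, prefill_out=None, decode_out=None):
    fmha_out = mla_absorb_attention(ln_out, num_heads, v_head_dim, prefill_out, decode_out)
    # The caller applies the output projection exactly once, for every path.
    return paddle.matmul(fmha_out, out_proj_weight)


# Hypothetical sizes: batch 2, hidden 64, 4 heads, v_head_dim 8.
ln_out = paddle.ones([2, 64])
prefill = paddle.ones([2, 4 * 8])
w_o = paddle.ones([4 * 8, 64])
out = attention_block(ln_out, w_o, 4, 8, prefill_out=prefill)  # shape [2, 64]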

File tree

1 file changed: +61 -63 lines changed


paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 61 additions & 63 deletions
@@ -1730,10 +1730,8 @@ def forward(
                 i,
                 **kwargs,
             )
-            if self.config.mla_config.use_absorb():
-                out_linear_out = fmha_out
-            else:
-                out_linear_out = self.compute_out_linear(fmha_out, i)
+
+            out_linear_out = self.compute_out_linear(fmha_out, i)

             # print(f"{i}: out_linear_out: {out_linear_out}")

@@ -3102,7 +3100,9 @@ def compute_mla_absorb(
         ln_out = qkv_out
         latent_cache = caches[i]

-        out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)
+        fmha_out = paddle.zeros(
+            shape=[ln_out.shape[0], self.num_heads * self.config.mla_config.v_head_dim], dtype=ln_out.dtype
+        )

         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
@@ -3159,10 +3159,7 @@ def compute_mla_absorb(

             fmha_out_prefill = fmha_out_prefill * self.mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            out_linear_out_prefill = self.compute_out_linear(fmha_out_prefill, i)
-            out_linear_out = out_linear_out + out_linear_out_prefill
-
-            # print(f"prefill {i}: out_linear_out: {out_linear_out}")
+            fmha_out = fmha_out + fmha_out_prefill

         if kwargs["max_dec_len_this_time"]:  # decode phase
             if self.config.mla_config.q_lora_rank is not None:
@@ -3190,8 +3187,10 @@ def compute_mla_absorb(
                     epsilon=self._epsilon,
                     begin_norm_axis=1,
                 )[0]
-
-            query = paddle.matmul(ln_out_or_q_c, self.q_b_proj_weights[i])
+            if self.config.mla_config.q_lora_rank is not None:
+                query = paddle.matmul(ln_out_or_q_c, self.q_b_proj_weights[i])
+            else:
+                query = paddle.matmul(ln_out_or_q_c, self.q_proj_weights[i])
             query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             query_nope, query_pe = query.split(
                 [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
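
For this non-quantized path, the fix matters when q_lora_rank is None: the model then has no low-rank query A/B decomposition and stores only a full-rank q_proj weight per layer, so the old unconditional use of q_b_proj_weights[i] broke that configuration. A hedged standalone sketch of the selection (names mirror the diff; the weight lists are hypothetical stand-ins for the per-layer weight attributes):

import paddle


def project_query(ln_out_or_q_c, i, q_lora_rank, q_b_proj_weights, q_proj_weights):
    # Use the low-rank B projection only when q_lora_rank is configured;
    # otherwise fall back to the full-rank q projection weight.
    if q_lora_rank is not None:
        return paddle.matmul(ln_out_or_q_c, q_b_proj_weights[i])
    return paddle.matmul(ln_out_or_q_c, q_proj_weights[i])

The weight-only and FP8 hunks below apply the same branch to weight_only_linear and cutlass_fp8_gemm respectively.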
@@ -3282,12 +3281,9 @@ def compute_mla_absorb(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_heads * self.config.mla_config.v_head_dim])
             )
-            out_linear_out_decode = paddle.matmul(fmha_out_decode, self.linear_weights[i])
-            out_linear_out = out_linear_out + out_linear_out_decode
+            fmha_out = fmha_out + fmha_out_decode

-            # print(f"decode {i}: out_linear_out: {out_linear_out}")
-
-        return out_linear_out
+        return fmha_out

     def compute_attn(
         self,
@@ -3515,7 +3511,9 @@ def compute_mla_absorb(
         ln_out = qkv_out
         latent_cache = caches[i]

-        out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)
+        fmha_out = paddle.zeros(
+            shape=[ln_out.shape[0], self.num_heads * self.config.mla_config.v_head_dim], dtype=ln_out.dtype
+        )

         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
@@ -3572,10 +3570,7 @@ def compute_mla_absorb(

             fmha_out_prefill = fmha_out_prefill * self.mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            out_linear_out_prefill = self.compute_out_linear(fmha_out_prefill, i)
-            out_linear_out = out_linear_out + out_linear_out_prefill
-
-            # print(f"prefill {i}: out_linear_out: {out_linear_out}")
+            fmha_out = fmha_out + fmha_out_prefill

         if kwargs["max_dec_len_this_time"]:  # decode phase
             if self.config.mla_config.q_lora_rank is not None:
@@ -3615,14 +3610,22 @@ def compute_mla_absorb(
                     epsilon=self._epsilon,
                     begin_norm_axis=1,
                 )[0]
-
-            query = weight_only_linear(
-                ln_out_or_q_c,
-                weight=self.q_b_proj_weights[i],
-                weight_scale=self.q_b_proj_weights_scale[i],
-                weight_dtype=self.weight_dtype,
-                group_size=self.weightonly_group_size,
-            )
+            if self.config.mla_config.q_lora_rank is not None:
+                query = weight_only_linear(
+                    ln_out_or_q_c,
+                    weight=self.q_b_proj_weights[i],
+                    weight_scale=self.q_b_proj_weights_scale[i],
+                    weight_dtype=self.weight_dtype,
+                    group_size=self.weightonly_group_size,
+                )
+            else:
+                query = weight_only_linear(
+                    ln_out_or_q_c,
+                    weight=self.q_proj_weights[i],
+                    weight_scale=self.q_proj_weights_scale[i],
+                    weight_dtype=self.weight_dtype,
+                    group_size=self.weightonly_group_size,
+                )
             query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             query_nope, query_pe = query.split(
                 [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
@@ -3713,18 +3716,10 @@ def compute_mla_absorb(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_heads * self.config.mla_config.v_head_dim])
             )
-            out_linear_out_decode = weight_only_linear(
-                fmha_out_decode,
-                weight=self.linear_weights[i],
-                weight_scale=self.linear_weights_scale[i],
-                weight_dtype=self.weight_dtype,
-                group_size=self.weightonly_group_size,
-            )
-            out_linear_out = out_linear_out + out_linear_out_decode

-            # print(f"decode {i}: out_linear_out: {out_linear_out}")
+            fmha_out = fmha_out + fmha_out_decode

-        return out_linear_out
+        return fmha_out


 class FusedBlockMultiTransformerA8W8(FusedBlockMultiTransformer, FusedMultiTransformerA8W8):
@@ -5193,7 +5188,9 @@ def compute_mla_absorb(
         ln_out = qkv_out
         latent_cache = caches[i]

-        out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)
+        fmha_out = paddle.zeros(
+            shape=[ln_out.shape[0], self.num_heads * self.config.mla_config.v_head_dim], dtype=ln_out.dtype
+        )

         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
@@ -5250,8 +5247,7 @@ def compute_mla_absorb(

             fmha_out_prefill = fmha_out_prefill * self.mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            out_linear_out_prefill = self.compute_out_linear(fmha_out_prefill, i)
-            out_linear_out = out_linear_out + out_linear_out_prefill
+            fmha_out = fmha_out + fmha_out_prefill

         if kwargs["max_dec_len_this_time"]:  # decode phase
             if self.config.mla_config.q_lora_rank is not None:
@@ -5297,15 +5293,26 @@ def compute_mla_absorb(
                     epsilon=self._epsilon,
                     begin_norm_axis=1,
                 )[0]
-            query = self.cutlass_fp8_gemm(
-                x=ln_out_or_q_c_fp8,
-                y=self.q_b_proj_weights[i],
-                x_s=ln_out_or_q_c_scale,
-                y_s=self.q_b_proj_weights_scale[i],
-                bias=None,
-                output_dtype=self._dtype,
-                act="identity",
-            )
+            if self.config.mla_config.q_lora_rank is not None:
+                query = self.cutlass_fp8_gemm(
+                    x=ln_out_or_q_c_fp8,
+                    y=self.q_b_proj_weights[i],
+                    x_s=ln_out_or_q_c_scale,
+                    y_s=self.q_b_proj_weights_scale[i],
+                    bias=None,
+                    output_dtype=self._dtype,
+                    act="identity",
+                )
+            else:
+                query = self.cutlass_fp8_gemm(
+                    x=ln_out_or_q_c_fp8,
+                    y=self.q_proj_weights[i],
+                    x_s=ln_out_or_q_c_scale,
+                    y_s=self.q_proj_weights_scale[i],
+                    bias=None,
+                    output_dtype=self._dtype,
+                    act="identity",
+                )
             query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             query_nope, query_pe = query.split(
                 [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
@@ -5396,19 +5403,10 @@ def compute_mla_absorb(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_heads * self.config.mla_config.v_head_dim])
             )
-            fmha_out_decode_fp8, fmha_out_decode_scale = self.dynamic_quant(fmha_out_decode)
-            out_linear_out_decode = self.cutlass_fp8_gemm(
-                x=fmha_out_decode_fp8,
-                y=self.linear_weights[i],
-                x_s=fmha_out_decode_scale,
-                y_s=self.linear_weights_scale[i],
-                bias=None,
-                output_dtype=self._dtype,
-                act="identity",
-            )
-            out_linear_out = out_linear_out + out_linear_out_decode

-        return out_linear_out
+            fmha_out = fmha_out + fmha_out_decode
+
+        return fmha_out

     def compute_ffn1(self, tmp_out, i):
         out = self.cutlass_fp8_gemm(
