Commit 22e4f9a

Authored on Mar 19, 2025
Revert "optimize ds3 attention impl (#10200)" (#10208)
This reverts commit 58edb00.
1 parent 642869e · commit 22e4f9a

File tree

1 file changed, +14 -0 lines changed


paddlenlp/transformers/deepseek_v2/modeling.py (+14)
@@ -245,6 +245,11 @@ def scaled_dot_product_attention(
         q_head_dim = query_states.shape[-1]
         softmax_scale = softmax_scale * (q_head_dim**0.5)
         query_states = query_states * softmax_scale
+        value_padding = paddle.zeros(
+            [bsz, kv_seq_len, v_num_heads, head_dim - v_head_dim],
+            dtype=value_states.dtype,
+        )
+        value_states = paddle.concat([value_states, value_padding], axis=-1)

         outputs = fusion_ops.fusion_flash_attention(
             query_states,
@@ -257,6 +262,15 @@ def scaled_dot_product_attention(
             sequence_parallel=sequence_parallel,
         )

+        if isinstance(outputs, tuple):
+            outputs[0] = outputs[0].reshape([bsz, q_len, v_num_heads, head_dim])
+            outputs[0] = outputs[0][..., :v_head_dim]
+            outputs[0] = outputs[0].reshape([bsz, q_len, -1])
+        else:
+            outputs = outputs.reshape([bsz, q_len, v_num_heads, head_dim])
+            outputs = outputs[..., :v_head_dim]
+            outputs = outputs.reshape([bsz, q_len, -1])
+
         if sequence_parallel:
             outputs = outputs.reshape([bsz * q_len, v_head_dim * num_heads])
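
The restored lines implement a pad-then-slice workaround: DeepSeek-V2's value head dimension (v_head_dim) is smaller than the query/key head dimension (head_dim), so the values are zero-padded up to head_dim before the fusion_ops.fusion_flash_attention call, evidently so the fused kernel sees one uniform head dimension, and the padded channels are sliced back off the output afterwards. Below is a minimal sketch of that pattern under toy shapes; the shapes are illustrative assumptions, and a random tensor stands in for the fused kernel's output (fusion_ops.fusion_flash_attention itself is not exercised here).

import paddle

# Toy shapes; in the model these come from the config.
bsz, q_len, kv_seq_len = 2, 8, 8
v_num_heads, head_dim, v_head_dim = 4, 64, 48  # v_head_dim < head_dim

value_states = paddle.randn([bsz, kv_seq_len, v_num_heads, v_head_dim])

# Before the fused call: zero-pad the value head dim up to head_dim,
# matching the paddle.zeros / paddle.concat lines in the diff.
value_padding = paddle.zeros(
    [bsz, kv_seq_len, v_num_heads, head_dim - v_head_dim],
    dtype=value_states.dtype,
)
value_states = paddle.concat([value_states, value_padding], axis=-1)

# Stand-in for the fused kernel's output; its layout follows the
# reshape the diff applies to `outputs`.
outputs = paddle.randn([bsz, q_len, v_num_heads * head_dim])

# After the fused call: expose the head dim, drop the zero padding,
# and flatten the heads back out, as in the diff's slicing block.
outputs = outputs.reshape([bsz, q_len, v_num_heads, head_dim])
outputs = outputs[..., :v_head_dim]
outputs = outputs.reshape([bsz, q_len, -1])

print(outputs.shape)  # [2, 8, 192], i.e. [bsz, q_len, v_num_heads * v_head_dim]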

0 commit comments