@@ -245,6 +245,11 @@ def scaled_dot_product_attention(
     q_head_dim = query_states.shape[-1]
     softmax_scale = softmax_scale * (q_head_dim**0.5)
     query_states = query_states * softmax_scale
+    value_padding = paddle.zeros(
+        [bsz, kv_seq_len, v_num_heads, head_dim - v_head_dim],
+        dtype=value_states.dtype,
+    )
+    value_states = paddle.concat([value_states, value_padding], axis=-1)
 
     outputs = fusion_ops.fusion_flash_attention(
         query_states,
@@ -257,6 +262,15 @@ def scaled_dot_product_attention(
         sequence_parallel=sequence_parallel,
     )
 
+    if isinstance(outputs, tuple):
+        attn_out = outputs[0].reshape([bsz, q_len, v_num_heads, head_dim])
+        attn_out = attn_out[..., :v_head_dim]
+        outputs = (attn_out.reshape([bsz, q_len, -1]),) + outputs[1:]
+    else:
+        outputs = outputs.reshape([bsz, q_len, v_num_heads, head_dim])
+        outputs = outputs[..., :v_head_dim]
+        outputs = outputs.reshape([bsz, q_len, -1])
+
     if sequence_parallel:
         outputs = outputs.reshape([bsz * q_len, v_head_dim * num_heads])
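For context, the padding added above works around fused flash-attention kernels that expect query, key, and value to share a single head dim. With an MLA-style cache the value head dim (`v_head_dim`) is smaller than the query/key head dim (`head_dim`), so the values are zero-padded up to `head_dim` before the kernel and the padded columns are sliced off the output afterwards. Below is a minimal, self-contained sketch of that shape bookkeeping; the dimensions are made up, and a plain scaled-dot-product in paddle stands in for `fusion_ops.fusion_flash_attention`.

```python
import paddle

# Hypothetical MLA-style shapes; only the head_dim vs. v_head_dim mismatch matters here.
bsz, q_len, kv_seq_len = 2, 8, 8
num_heads, v_num_heads = 4, 4
head_dim, v_head_dim = 192, 128  # query/key head dim is larger than value head dim

query_states = paddle.randn([bsz, q_len, num_heads, head_dim])
key_states = paddle.randn([bsz, kv_seq_len, num_heads, head_dim])
value_states = paddle.randn([bsz, kv_seq_len, v_num_heads, v_head_dim])

# Zero-pad the values so q, k, and v all share head_dim.
value_padding = paddle.zeros(
    [bsz, kv_seq_len, v_num_heads, head_dim - v_head_dim],
    dtype=value_states.dtype,
)
value_states = paddle.concat([value_states, value_padding], axis=-1)

# Stand-in for the fused kernel: plain scaled-dot-product attention.
q = query_states.transpose([0, 2, 1, 3])  # [bsz, num_heads, q_len, head_dim]
k = key_states.transpose([0, 2, 1, 3])
v = value_states.transpose([0, 2, 1, 3])
scores = paddle.matmul(q, k, transpose_y=True) / head_dim**0.5
probs = paddle.nn.functional.softmax(scores, axis=-1)
outputs = paddle.matmul(probs, v).transpose([0, 2, 1, 3])  # [bsz, q_len, num_heads, head_dim]

# Slice the padded columns back off and flatten the head dims.
outputs = outputs[..., :v_head_dim].reshape([bsz, q_len, -1])
print(outputs.shape)  # [2, 8, 512] == [bsz, q_len, v_num_heads * v_head_dim]
```

Because the padded value columns are all zeros, they contribute nothing to the attention-weighted sum, so slicing the output back to `v_head_dim` recovers the same result the un-padded attention would have produced.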