diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 561793b6a377..963f1c5abf2a 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -1211,13 +1211,18 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata,
                                    q, k, v, return_softmax_lse):
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
-        return prefill.prefill_main.run(
+        ret = prefill.prefill_main.run(
             q=q,
             k=k,
             v=v,
             return_lse=return_softmax_lse,
         )
+        if isinstance(ret, tuple):
+            # Convert from (q_len, num_heads) to (num_heads, q_len)
+            return ret[0], ret[1].transpose(0, 1).contiguous()
+        return ret
+
 
     def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
                                       q, k, v, return_softmax_lse):
         assert isinstance(prefill, CudnnPrefillMetadata)
@@ -1260,12 +1265,14 @@ def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
     def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
                                       chunk_idx: int, q, k, v):
         assert isinstance(prefill, FlashInferPrefillMetadata)
-        return prefill.prefill_chunks[chunk_idx].run(
+        attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
             v=v,
             return_lse=True,
         )
+        # Convert from (q_len, num_heads) to (num_heads, q_len)
+        return attn_out, lse.transpose(0, 1).contiguous()
 
     def _run_prefill_context_chunk_cudnn(self, prefill: MLACommonPrefillMetadata,
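
For illustration only, a minimal standalone sketch (not part of the patch) of the LSE layout conversion both hunks apply, assuming the FlashInfer prefill wrappers return the softmax log-sum-exp as (q_len, num_heads) while the rest of the MLA prefill path consumes it as (num_heads, q_len); shapes and names below are made up for the example:

import torch

# Dummy sizes; the real tensors come from the prefill kernel run() calls above.
q_len, num_heads = 8, 16
lse_kernel = torch.randn(q_len, num_heads)            # layout produced by the kernel
lse_merged = lse_kernel.transpose(0, 1).contiguous()  # layout expected downstream
assert lse_merged.shape == (num_heads, q_len)
assert lse_merged.is_contiguous()  # transpose() alone returns a strided view

The explicit .contiguous() matters because transpose() only swaps strides; making the result densely packed avoids handing a non-contiguous view to whatever consumes the LSE next.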