From b2cfc3190efb16ea1f26fcfbebbafae6ed59f103 Mon Sep 17 00:00:00 2001
From: ignaciosica
Date: Thu, 11 Sep 2025 16:45:10 -0300
Subject: [PATCH] move repeat_interleave after movedim

Signed-off-by: ignaciosica
---
 vllm/v1/attention/backends/cpu_attn.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index ab87f3bb4e3c..6627164c9879 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -641,10 +641,6 @@ def _run_sdpa_forward(
         attn_metadata: TorchSDPAMetadata,
         attn_type: str = AttentionType.DECODER,
     ) -> None:
-        if self.num_kv_heads != self.num_heads:
-            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
-            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
-
         attn_masks = attn_metadata.get_attn_bias(attn_type)
         if attn_masks is None:
             if self.alibi_slopes is not None:
@@ -665,6 +661,10 @@ def _run_sdpa_forward(
         key = key.movedim(0, key.dim() - 2)
         value = value.movedim(0, value.dim() - 2)
 
+        if self.num_kv_heads != self.num_heads:
+            key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
+            value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
+
         causal_attn = (attn_type == AttentionType.DECODER)
 
         seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)