diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index ab87f3bb4e3c..6627164c9879 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -641,10 +641,6 @@ def _run_sdpa_forward( attn_metadata: TorchSDPAMetadata, attn_type: str = AttentionType.DECODER, ) -> None: - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - attn_masks = attn_metadata.get_attn_bias(attn_type) if attn_masks is None: if self.alibi_slopes is not None: @@ -665,6 +661,10 @@ def _run_sdpa_forward( key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) + value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) + causal_attn = (attn_type == AttentionType.DECODER) seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)