From f557af987ca5a57d71efdfee3091b28e15c4f1f2 Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Thu, 18 Sep 2025 15:54:44 +0000
Subject: [PATCH 1/4] Refactor reorder_batch_threshold for spec

Co-authored-by: lhsjohn
Signed-off-by: Benjamin Chislett
---
 vllm/v1/attention/backends/flashinfer.py      |  2 +-
 vllm/v1/attention/backends/gdn_attn.py        |  6 +-
 vllm/v1/attention/backends/linear_attn.py     |  3 +-
 vllm/v1/attention/backends/mamba_attn.py      |  2 +-
 vllm/v1/attention/backends/mla/common.py      |  4 +-
 .../attention/backends/mla/flashattn_mla.py   |  4 +-
 vllm/v1/attention/backends/short_conv_attn.py |  4 +-
 vllm/v1/attention/backends/utils.py           | 80 +++++++++++++++++--
 vllm/v1/attention/backends/xformers.py        |  4 +-
 9 files changed, 87 insertions(+), 22 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index dda6dd4fbea7..2f18a3721fd6 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -240,7 +240,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index ba89f93e8b56..ebfd34bfaab5 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Backend for GatedDeltaNet attention."""
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional

 import torch

@@ -56,7 +56,7 @@ class GDNAttentionMetadataBuilder(

     cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -70,7 +70,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.num_spec = 0
         self.use_spec_decode = self.num_spec > 0
-        self.reorder_batch_threshold = self.num_spec + 1  # type: ignore[misc]
+        self._init_reorder_batch_threshold(1, self.use_spec_decode)

         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
index 3ff201d83a79..0dc62d668020 100644
--- a/vllm/v1/attention/backends/linear_attn.py
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar

 import torch

@@ -35,7 +34,7 @@ class LinearAttentionMetadata:
 class LinearAttentionMetadataBuilder(
         AttentionMetadataBuilder[LinearAttentionMetadata]):

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 9970331a6042..ef342ce421ae 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -16,7 +16,7 @@
 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index a990cb2f1a97..3e23afb56aec 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -190,7 +190,7 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import ClassVar, Generic, Optional, TypeVar, Union
+from typing import Generic, Optional, TypeVar, Union

 import torch
 from tqdm import tqdm
@@ -433,7 +433,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self,
                  kv_cache_spec: AttentionSpec,
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 472095e13615..7ef869d2626b 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -67,7 +67,7 @@ class FlashAttnMLAMetadataBuilder(
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_BATCH

-    reorder_batch_threshold: ClassVar[int] = 512
+    reorder_batch_threshold: int = 512

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -101,7 +101,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],

         # TODO(lucas): Until we add support for the DCP custom masking we need
         # to restrict decodes to q_len == 1 when DCP is enabled.
-        self.__class__.reorder_batch_threshold = 1 \
+        self.reorder_batch_threshold = 1 \
             if get_dcp_group().world_size > 1 else self.reorder_batch_threshold

     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py
index f5ad65b02b4d..f79bc87c34c6 100644
--- a/vllm/v1/attention/backends/short_conv_attn.py
+++ b/vllm/v1/attention/backends/short_conv_attn.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional

 import torch

@@ -41,7 +41,7 @@ class ShortConvAttentionMetadata:
 class ShortConvAttentionMetadataBuilder(
         AttentionMetadataBuilder[ShortConvAttentionMetadata]):

-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 63326d19194f..e96fc5ace6e4 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -195,7 +195,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
-    reorder_batch_threshold: ClassVar[Optional[int]] = None
+    reorder_batch_threshold: Optional[int] = None
+    # Does this backend/builder support issuing decodes with (uniform) qlen > 1?
+    supports_spec_as_decode: bool = False

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@@ -205,6 +207,23 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.vllm_config = vllm_config
         self.device = device

+    def _init_reorder_batch_threshold(
+            self,
+            reorder_batch_threshold: int = 1,
+            supports_spec_as_decode: bool = False) -> None:
+        self.reorder_batch_threshold = reorder_batch_threshold
+        self.supports_spec_as_decode = supports_spec_as_decode
+        if self.reorder_batch_threshold is not None \
+                and self.supports_spec_as_decode:
+            # If the backend supports spec-as-decode kernels, then we can set
+            # the reorder_batch_threshold based on the number of speculative
+            # tokens from the config.
+            speculative_config = self.vllm_config.speculative_config
+            if (speculative_config is not None
+                    and speculative_config.num_speculative_tokens is not None):
+                self.reorder_batch_threshold = \
+                    1 + speculative_config.num_speculative_tokens
+
     @abstractmethod
     def build(self,
               common_prefix_len: int,
@@ -662,9 +681,9 @@ def subclass_attention_backend(


 def split_decodes_and_prefills(
-    common_attn_metadata: CommonAttentionMetadata,
-    decode_threshold: int = 1,
-) -> tuple[int, int, int, int]:
+        common_attn_metadata: CommonAttentionMetadata,
+        decode_threshold: int = 1,
+        require_uniform: bool = False) -> tuple[int, int, int, int]:
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
@@ -673,6 +692,9 @@ def split_decodes_and_prefills(
         common_attn_metadata: CommonAttentionMetadata object containing the
             batch metadata.
         decode_threshold: The maximum query length to be considered a decode.
+        require_uniform: If True, requires that all decode requests have the
+            same query length. When set, some queries may be considered prefills
+            even if they are <= decode_threshold, in order to ensure uniformity.

     Returns:
         num_decodes: The number of decode requests.
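For illustration only (not part of the patch): a minimal, self-contained sketch of how the new _init_reorder_batch_threshold helper above widens the decode threshold when spec-as-decode is supported. SpecConfig here is a hypothetical stand-in for vLLM's speculative config object.

    # Sketch, not patch content: mirrors the threshold derivation in
    # _init_reorder_batch_threshold using a stand-in config object.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class SpecConfig:
        num_speculative_tokens: Optional[int] = None


    def derive_threshold(base: int, supports_spec_as_decode: bool,
                         spec_config: Optional[SpecConfig]) -> int:
        # Backends that can issue uniform multi-token decodes widen the
        # threshold to 1 target token plus the configured speculative tokens;
        # all other backends keep the base value (typically 1).
        if (supports_spec_as_decode and spec_config is not None
                and spec_config.num_speculative_tokens is not None):
            return 1 + spec_config.num_speculative_tokens
        return base


    assert derive_threshold(1, False, SpecConfig(3)) == 1  # spec-as-decode off
    assert derive_threshold(1, True, SpecConfig(3)) == 4   # 1 + 3 spec tokens
    assert derive_threshold(1, True, None) == 1            # no spec config
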
@@ -685,16 +707,26 @@ def split_decodes_and_prefills(
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu

-    if max_query_len <= decode_threshold:
+    if max_query_len <= decode_threshold and \
+            (not require_uniform or decode_threshold <= 1):
         return num_reqs, 0, num_tokens, 0

     query_lens = query_start_loc[1:] - query_start_loc[:-1]
-    is_prefill = query_lens > decode_threshold
+    if query_lens[0].item() > decode_threshold:
+        # first request is not decode, so no decode requests
+        return 0, num_reqs, 0, num_tokens
+
+    if require_uniform:
+        is_prefill = query_lens != query_lens[0]
+    else:
+        is_prefill = query_lens > decode_threshold
+
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0

     first_prefill = is_prefill.int().argmax(dim=-1).item()
-    assert torch.all(query_lens[first_prefill:] > decode_threshold)
+    if not require_uniform:
+        assert torch.all(query_lens[first_prefill:] > decode_threshold)
     assert torch.all(query_lens[:first_prefill] <= decode_threshold)
     num_decodes = first_prefill
     num_prefills = num_reqs - num_decodes
@@ -766,6 +798,40 @@ def reorder_batch_to_split_decodes_and_prefills(

     return modified_batch


+def reshape_query_for_spec_decode(query: torch.Tensor,
+                                  batch_size: int) -> torch.Tensor:
+    """
+    Reshapes the query tensor for the specified batch size, so that
+    it has shape (batch_size, seq_len, num_heads, head_dim).
+    """
+    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
+    total_tokens = query.shape[0]
+    num_heads = query.shape[1]
+    head_dim = query.shape[2]
+    assert total_tokens % batch_size == 0, (
+        f"{total_tokens=} is not divisible by {batch_size=}")
+    seq_len = total_tokens // batch_size
+    return query.view(batch_size, seq_len, num_heads, head_dim)
+
+
+def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor,
+                                        batch_size: int) -> torch.Tensor:
+    """
+    Reshapes the attention output tensor for the specified batch size, so that
+    the batch_size and seq_len dimensions are combined.
+ """ + if attn_output.dim() == 3: + # Already in the correct shape + return attn_output + assert attn_output.dim( + ) == 4, f"attn_output must be 4D, got {attn_output.dim()}D" + total_tokens = attn_output.shape[0] * attn_output.shape[1] + assert total_tokens % batch_size == 0, ( + f"{total_tokens=} is not divisible by {batch_size=}") + return attn_output.view(total_tokens, attn_output.shape[2], + attn_output.shape[3]) + + KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ ('logits_indices_padded', Optional[torch.Tensor], None), ('num_logits_indices', int, 0), diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index a6ca33491235..d5a6c4c1db52 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, Optional import torch @@ -197,7 +197,7 @@ def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]: class XFormersAttentionMetadataBuilder( AttentionMetadataBuilder[XFormersAttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 def __init__( self, From ab6e3e97931339536bd0981a3db0efe7679d833c Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 18 Sep 2025 16:24:56 +0000 Subject: [PATCH 2/4] unit test for batch splitting Signed-off-by: Benjamin Chislett --- .../v1/attention/test_attention_splitting.py | 111 +++++++++++++++++- 1 file changed, 109 insertions(+), 2 deletions(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index c74dbb3ebb17..d7f63b49c706 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -5,11 +5,12 @@ import torch from tests.v1.attention.test_attention_backends import BATCH_SPECS -from tests.v1.attention.utils import create_common_attn_metadata +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm.v1.attention.backends.utils import (UBatchSlice, _make_metadata_with_slice, slice_query_start_locs, - split_attn_metadata) + split_attn_metadata, + split_decodes_and_prefills) @pytest.fixture @@ -155,3 +156,109 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): assert results[1].num_reqs == mid_point assert results[1].num_actual_tokens == mid_point assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) + + +def apply_split_decodes_and_prefills(query_lens: list[int], + decode_threshold: int, + require_uniform: bool): + """Helper function to apply split_decodes_and_prefills and return + the results.""" + device = torch.device("cpu") + seq_lens = [10 * (i + 1) for i in range(len(query_lens))] + common_metadata = create_common_attn_metadata(BatchSpec( + seq_lens=seq_lens, query_lens=query_lens), + block_size=16, + device=device) + return split_decodes_and_prefills(common_metadata, + decode_threshold=decode_threshold, + require_uniform=require_uniform) + + +def test_split_decodes_and_prefills_nonuniform_all_ones(): + query_lens = [1, 1, 1] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 1, False)) + assert num_decodes == 3 + assert num_prefills == 0 + assert num_decode_tokens == 3 + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_nonuniform_all_short_decodes(): + query_lens = [1, 2, 1, 3, 
2, 1, 2] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, False)) + assert num_decodes == 7 + assert num_prefills == 0 + assert num_decode_tokens == sum(query_lens) + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_nonuniform_all_prefills(): + query_lens = [4, 5, 6, 7] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, False)) + assert num_decodes == 0 + assert num_prefills == 4 + assert num_decode_tokens == 0 + assert num_prefill_tokens == sum(query_lens) + + +def test_split_decodes_and_prefills_nonuniform_mixed_batch(): + query_lens = [2, 1, 3, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 4, False)) + assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4 + assert num_prefills == 4 # 5, 6, 7, 8 are all > 4 + assert num_decode_tokens == 10 # 2 + 1 + 3 + 4 + assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8 + + +def test_split_decodes_and_prefills_uniform_all_ones(): + query_lens = [1, 1, 1] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 1, True)) + assert num_decodes == 3 + assert num_prefills == 0 + assert num_decode_tokens == 3 + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_uniform_all_short_decodes(): + query_lens = [2, 2, 1, 3, 2, 1, 2] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, True)) + assert num_decodes == 2 + assert num_prefills == 5 + assert num_decode_tokens == 4 + assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2) + + +def test_split_decodes_and_prefills_uniform_all_prefills(): + query_lens = [4, 5, 6, 7] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, True)) + assert num_decodes == 0 + assert num_prefills == 4 + assert num_decode_tokens == 0 + assert num_prefill_tokens == sum(query_lens) + + +def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes(): + query_lens = [2, 2, 2, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 4, True)) + assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform + assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4 + assert num_decode_tokens == 6 # 2 + 2 + 2 + assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8 + + +def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes(): + query_lens = [2, 1, 2, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 4, True)) + assert num_decodes == 1 # only the first 2 is taken as decode + assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform + assert num_decode_tokens == 2 # only the first 2 + assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens From cfa3273badabd710e18e8e04f9c0db3d8c141e47 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 18 Sep 2025 18:16:10 +0000 Subject: [PATCH 3/4] minor tweak to helper func Signed-off-by: Benjamin Chislett --- vllm/v1/attention/backends/utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index e96fc5ace6e4..d8d5b4ea24ca 100644 
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -814,20 +814,18 @@ def reshape_query_for_spec_decode(query: torch.Tensor,
     return query.view(batch_size, seq_len, num_heads, head_dim)


-def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor,
-                                        batch_size: int) -> torch.Tensor:
+def reshape_attn_output_for_spec_decode(
+        attn_output: torch.Tensor) -> torch.Tensor:
     """
-    Reshapes the attention output tensor for the specified batch size, so that
+    Reshapes the attention output tensor, so that
     the batch_size and seq_len dimensions are combined.
     """
     if attn_output.dim() == 3:
         # Already in the correct shape
         return attn_output
-    assert attn_output.dim(
-    ) == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
+    assert attn_output.dim() == 4, \
+        f"attn_output must be 4D, got {attn_output.dim()}D"
     total_tokens = attn_output.shape[0] * attn_output.shape[1]
-    assert total_tokens % batch_size == 0, (
-        f"{total_tokens=} is not divisible by {batch_size=}")
     return attn_output.view(total_tokens, attn_output.shape[2],
                             attn_output.shape[3])

From 587a6f609d9333d35ac2d4144be2dfc3f9d94f1e Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Mon, 22 Sep 2025 20:56:50 +0000
Subject: [PATCH 4/4] remove attention backend requirement for eagle

Signed-off-by: Benjamin Chislett
---
 vllm/v1/spec_decode/eagle.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 2a178ddf4877..47b17af3d66b 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -118,7 +118,7 @@ def __init__(
                                                  with_numpy=True)

         # Determine allowed attention backends once during initialization.
-        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        self.allowed_attn_types: Optional[tuple] = None
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
             # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@@ -127,9 +127,6 @@ def __init__(
                 AiterFlashAttentionMetadata)
             rocm_types.append(AiterFlashAttentionMetadata)
             self.allowed_attn_types = tuple(rocm_types)
-        else:
-            self.allowed_attn_types = (FlashAttentionMetadata,
-                                       TreeAttentionMetadata)

         # Parse the speculative token tree.
         spec_token_tree = self.speculative_config.speculative_token_tree
@@ -263,7 +260,8 @@ def propose(

         draft_token_ids = logits.argmax(dim=-1)

-        if not isinstance(attn_metadata, self.allowed_attn_types):
+        if self.allowed_attn_types is not None and \
+            not isinstance(attn_metadata, self.allowed_attn_types):
             raise ValueError(
                 f"Unsupported attention metadata type for speculative "
                 "decoding with num_speculative_tokens > 1: "
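
As a closing illustration (not part of the patch series): a minimal pure-Python sketch of the require_uniform splitting rule exercised by the tests in PATCH 2/4, operating on plain query lengths rather than a CommonAttentionMetadata object. The split_uniform helper below is hypothetical and returns only the request counts.

    # Sketch, not patch content: a request counts as a decode only if it is
    # within the threshold AND has the same query length as the first request;
    # the first mismatch ends the decode prefix, since the batch is assumed to
    # already be reordered with decodes at the front.
    def split_uniform(query_lens: list[int], decode_threshold: int):
        if not query_lens or query_lens[0] > decode_threshold:
            return 0, len(query_lens)
        num_decodes = 0
        for q in query_lens:
            if q != query_lens[0]:
                break
            num_decodes += 1
        return num_decodes, len(query_lens) - num_decodes


    # Same expectations as the uniform mixed-batch tests above.
    assert split_uniform([2, 2, 2, 4, 5, 6, 7, 8], decode_threshold=4) == (3, 5)
    assert split_uniform([2, 1, 2, 4, 5, 6, 7, 8], decode_threshold=4) == (1, 7)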