diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py
index 7d7a46910be8..2fd11415d490 100644
--- a/tests/v1/attention/test_attention_splitting.py
+++ b/tests/v1/attention/test_attention_splitting.py
@@ -9,7 +9,8 @@
 from vllm.v1.attention.backends.utils import (UBatchSlice,
                                               _make_metadata_with_slice,
                                               slice_query_start_locs,
-                                              split_attn_metadata)
+                                              split_attn_metadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.worker.ubatch_utils import create_ubatch_slices
 
 
@@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
     assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
 
 
+def apply_split_decodes_and_prefills(query_lens: list[int],
+                                     decode_threshold: int,
+                                     require_uniform: bool):
+    """Helper function to apply split_decodes_and_prefills and return
+    the results."""
+    device = torch.device("cpu")
+    seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
+    common_metadata = create_common_attn_metadata(BatchSpec(
+        seq_lens=seq_lens, query_lens=query_lens),
+                                                  block_size=16,
+                                                  device=device)
+    return split_decodes_and_prefills(common_metadata,
+                                      decode_threshold=decode_threshold,
+                                      require_uniform=require_uniform)
+
+
+def test_split_decodes_and_prefills_nonuniform_all_ones():
+    query_lens = [1, 1, 1]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 1, False))
+    assert num_decodes == 3
+    assert num_prefills == 0
+    assert num_decode_tokens == 3
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
+    query_lens = [1, 2, 1, 3, 2, 1, 2]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, False))
+    assert num_decodes == 7
+    assert num_prefills == 0
+    assert num_decode_tokens == sum(query_lens)
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_nonuniform_all_prefills():
+    query_lens = [4, 5, 6, 7]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, False))
+    assert num_decodes == 0
+    assert num_prefills == 4
+    assert num_decode_tokens == 0
+    assert num_prefill_tokens == sum(query_lens)
+
+
+def test_split_decodes_and_prefills_nonuniform_mixed_batch():
+    query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, False))
+    assert num_decodes == 4  # 2, 1, 3, 4 are all <= 4
+    assert num_prefills == 4  # 5, 6, 7, 8 are all > 4
+    assert num_decode_tokens == 10  # 2 + 1 + 3 + 4
+    assert num_prefill_tokens == 26  # 5 + 6 + 7 + 8
+
+
+def test_split_decodes_and_prefills_uniform_all_ones():
+    query_lens = [1, 1, 1]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 1, True))
+    assert num_decodes == 3
+    assert num_prefills == 0
+    assert num_decode_tokens == 3
+    assert num_prefill_tokens == 0
+
+
+def test_split_decodes_and_prefills_uniform_all_short_decodes():
+    query_lens = [2, 2, 1, 3, 2, 1, 2]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, True))
+    assert num_decodes == 2
+    assert num_prefills == 5
+    assert num_decode_tokens == 4
+    assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)
+
+
+def test_split_decodes_and_prefills_uniform_all_prefills():
+    query_lens = [4, 5, 6, 7]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 3, True))
+    assert num_decodes == 0
+    assert num_prefills == 4
+    assert num_decode_tokens == 0
+    assert num_prefill_tokens == sum(query_lens)
+
+
+def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
+    query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, True))
+    assert num_decodes == 3  # 2, 2, 2 are all <= 4 and uniform
+    assert num_prefills == 5  # 4, 5, 6, 7, 8 are all > 4
+    assert num_decode_tokens == 6  # 2 + 2 + 2
+    assert num_prefill_tokens == 30  # 4 + 5 + 6 + 7 + 8
+
+
+def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
+    query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
+    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+        apply_split_decodes_and_prefills(query_lens, 4, True))
+    assert num_decodes == 1  # only the first 2 is taken as decode
+    assert num_prefills == 7  # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
+    assert num_decode_tokens == 2  # only the first 2
+    assert num_prefill_tokens == (sum(query_lens) - 2)  # rest of the tokens
+
+
 @pytest.mark.parametrize(
     "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
     [
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 2179bddae243..ebc7a56ff906 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -181,6 +181,12 @@ def force_use_trtllm_attention() -> Optional[bool]:
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
+def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
+    """Check if the current configuration supports TRTLLM attention."""
+    has_trtllm = supports_trtllm_attention()
+    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
+
+
 def use_trtllm_attention(
     num_qo_heads: int,
     num_kv_heads: int,
@@ -188,7 +194,9 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
+    is_prefill: bool,
     has_sinks: bool = False,
+    has_spec: bool = False,
 ) -> bool:
     """Return ``True`` if TRTLLM attention is used."""
     force_use_trtllm = force_use_trtllm_attention()
@@ -214,6 +222,12 @@ def use_trtllm_attention(
         )
         return False
 
+    if has_spec and not is_prefill:
+        # Speculative decoding requires TRTLLM attention for decodes
+        logger.info_once(
+            "Using TRTLLM attention (enabled for speculative decoding).")
+        return True
+
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
         if has_sinks:
@@ -391,6 +405,7 @@ def flashinfer_disable_q_quantization() -> bool:
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
+    "can_use_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index cb092aa74e7f..891108f961b5 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -25,7 +25,8 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+from vllm.utils.flashinfer import (can_use_trtllm_attention,
+                                   flashinfer_disable_q_quantization,
                                    supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -48,6 +49,16 @@
 
 logger = init_logger(__name__)
 
+trtllm_gen_workspace_buffer = None
+
+
+def _get_trtllm_gen_workspace_buffer():
+    global trtllm_gen_workspace_buffer
+    if trtllm_gen_workspace_buffer is None:
+        trtllm_gen_workspace_buffer = torch.zeros(
+            FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
+    return trtllm_gen_workspace_buffer
+
 
 @triton.jit
 def _trtllm_prefill_attn_kvfp8_dequant(
@@ -213,6 +224,7 @@ class FlashInferMetadata:
 
     # For flashinfer trtllm batch decode
     max_q_len: int
+    max_q_len_prefill: int
     max_seq_len: int
     seq_lens: torch.Tensor
     block_table_tensor: torch.Tensor
@@ -240,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -292,6 +304,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.q_data_type = self.model_config.dtype
 
+        supports_spec_as_decode = \
+            can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
@@ -406,7 +422,8 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata,
-                                       decode_threshold=self.reorder_batch_threshold)
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)
 
         page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
@@ -481,20 +498,25 @@ def build(self,
             paged_kv_last_page_len_np,
         )
 
+        uses_spec_reorder = self.reorder_batch_threshold > 1
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  has_sinks=self.has_sinks)
+                                                  is_prefill=True,
+                                                  has_sinks=self.has_sinks,
+                                                  has_spec=uses_spec_reorder)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 has_sinks=self.has_sinks)
+                                                 is_prefill=False,
+                                                 has_sinks=self.has_sinks,
+                                                 has_spec=uses_spec_reorder)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
@@ -511,6 +533,7 @@ def build(self,
             q_data_type=self.q_data_type,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
+            max_q_len_prefill=max_q_len,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
@@ -567,6 +590,15 @@ def build(self,
             qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                 prefill_start]
             paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
+
+            # Recompute max_q_len for the slice of requests we are using
+            # for prefills. This can be different from max_q_len when
+            # we have a non-uniform batch with some short decodes offloaded
+            # to the prefill pathway
+            query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
+            attn_metadata.max_q_len_prefill = \
+                int(query_lens_prefill.max().item())
+
             if not attn_metadata.prefill_use_trtllm:
                 attn_metadata.prefill_wrapper.plan(
                     qo_indptr_cpu,
@@ -597,7 +629,7 @@ def build(self,
                     num_decodes <= self._decode_cudagraph_max_bs)
             if use_cudagraph:
                 num_input_tokens = (
-                    self.vllm_config.pad_for_cudagraph(num_decodes))
+                    self.vllm_config.pad_for_cudagraph(num_decode_tokens))
                 # Carefully fulfill the padding region with reasonable value
                 # on cpu.
                 # Make sure paged_kv_indptr_cpu is not decreasing
@@ -611,7 +643,7 @@ def build(self,
                     num_decodes:num_input_tokens].fill_(1)
 
             else:
-                num_input_tokens = num_decodes
+                num_input_tokens = num_decode_tokens
 
             attn_metadata.decode_wrapper = self._get_decode_wrapper(
                 num_input_tokens, use_cudagraph)
@@ -832,6 +864,9 @@ def forward(
             output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
             return output
 
+        # When using spec decoding, num_decodes can be < num_decode_tokens
+        # because some decode requests may have more than one query token.
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
@@ -862,10 +897,10 @@ def forward(
             else:
                 # prefill_query may be non-contiguous
                 prefill_query = prefill_query.contiguous()
-                workspace_buffer = prefill_wrapper._float_workspace_buffer
+                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_prefill = attn_metadata.block_table_tensor[
-                    num_decode_tokens:]
-                seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
+                    num_decodes:]
+                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
 
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 assert get_kv_cache_layout() == "HND"
@@ -909,7 +944,7 @@ def forward(
                     workspace_buffer=workspace_buffer,
                     block_tables=mock_block_table,
                     seq_lens=seq_lens_prefill,
-                    max_q_len=attn_metadata.max_q_len,
+                    max_q_len=attn_metadata.max_q_len_prefill,
                     max_kv_len=attn_metadata.max_seq_len,
                     bmm1_scale=self.bmm1_scale,
                     bmm2_scale=self.bmm2_scale,
@@ -943,7 +978,7 @@ def forward(
             else:
                 # decode_query may be non-contiguous
                 decode_query = decode_query.contiguous()
-                workspace_buffer = decode_wrapper._float_workspace_buffer
+                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                 block_tables_decode = attn_metadata.\
                     block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
@@ -966,6 +1001,14 @@ def forward(
                     assert self.o_sf_scale is None
                     out = output[:num_decode_tokens]
 
+                if num_decode_tokens % attn_metadata.num_decodes != 0:
+                    # This gets triggered when the dummy_run forces
+                    # attention to be initialized with q_len = 0
+                    q_len_per_req = 1
+                else:
+                    q_len_per_req = \
+                        num_decode_tokens // attn_metadata.num_decodes
+
                 trtllm_batch_decode_with_kv_cache(
                     query=decode_query,
                     kv_cache=kv_cache_permute,
@@ -979,7 +1022,7 @@ def forward(
                     sinks=self.sinks,
                     o_sf_scale=self.o_sf_scale,
                     out=out,
-                )
+                    q_len_per_req=q_len_per_req)
 
         return output_padded
 
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 06a87a4a3c8b..843958bc79de 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Backend for GatedDeltaNet attention."""
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
 
@@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder(
 
     cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -76,7 +76,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             self.num_spec = 0
         self.use_spec_decode = self.num_spec > 0
-        self.reorder_batch_threshold = self.num_spec + 1  # type: ignore[misc]
+        self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
index 3ff201d83a79..0dc62d668020 100644
--- a/vllm/v1/attention/backends/linear_attn.py
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar
 
 import torch
 
@@ -35,7 +34,7 @@ class LinearAttentionMetadata:
 class LinearAttentionMetadataBuilder(
         AttentionMetadataBuilder[LinearAttentionMetadata]):
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 9970331a6042..ef342ce421ae 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -16,7 +16,7 @@
 
 
 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
 
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index a177117a50bd..11713356c0b7 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -190,7 +190,7 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import ClassVar, Generic, Optional, TypeVar, Union
+from typing import Generic, Optional, TypeVar, Union
 
 import torch
 from tqdm import tqdm
@@ -434,7 +434,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 4ad9a13b61d8..652b1cdb6b76 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -64,7 +64,7 @@ class FlashAttnMLAMetadataBuilder(
 
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_BATCH
-    reorder_batch_threshold: ClassVar[int] = 512
+    reorder_batch_threshold: int = 512
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -99,7 +99,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
 
         # TODO(lucas): Until we add support for the DCP custom masking we need
         # to restrict decodes to q_len == 1 when DCP is enabled.
-        self.__class__.reorder_batch_threshold = 1 \
+        self.reorder_batch_threshold = 1 \
             if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
 
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py
index 428e40965979..df7f0d2310ab 100644
--- a/vllm/v1/attention/backends/short_conv_attn.py
+++ b/vllm/v1/attention/backends/short_conv_attn.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
 
@@ -41,7 +41,7 @@ class ShortConvAttentionMetadata:
 class ShortConvAttentionMetadataBuilder(
         AttentionMetadataBuilder[ShortConvAttentionMetadata]):
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index f837439f953e..0c6e0dfefd8a 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -236,7 +236,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
-    reorder_batch_threshold: ClassVar[Optional[int]] = None
+    reorder_batch_threshold: Optional[int] = None
 
     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@@ -246,6 +246,22 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.vllm_config = vllm_config
         self.device = device
 
+    def _init_reorder_batch_threshold(
+            self,
+            reorder_batch_threshold: int = 1,
+            supports_spec_as_decode: bool = False) -> None:
+        self.reorder_batch_threshold = reorder_batch_threshold
+        if self.reorder_batch_threshold is not None \
+                and supports_spec_as_decode:
+            # If the backend supports spec-as-decode kernels, then we can set
+            # the reorder_batch_threshold based on the number of speculative
+            # tokens from the config.
+            speculative_config = self.vllm_config.speculative_config
+            if (speculative_config is not None
+                    and speculative_config.num_speculative_tokens is not None):
+                self.reorder_batch_threshold = \
+                    1 + speculative_config.num_speculative_tokens
+
     @abstractmethod
     def build(self,
               common_prefix_len: int,
@@ -703,9 +719,9 @@ def subclass_attention_backend(
 
 
 def split_decodes_and_prefills(
-    common_attn_metadata: CommonAttentionMetadata,
-    decode_threshold: int = 1,
-) -> tuple[int, int, int, int]:
+        common_attn_metadata: CommonAttentionMetadata,
+        decode_threshold: int = 1,
+        require_uniform: bool = False) -> tuple[int, int, int, int]:
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
@@ -714,6 +730,9 @@ def split_decodes_and_prefills(
         common_attn_metadata: CommonAttentionMetadata object containing the
             batch metadata.
         decode_threshold: The maximum query length to be considered a decode.
+        require_uniform: If True, requires that all decode requests have the
+            same query length. When set, some queries may be considered prefills
+            even if they are <= decode_threshold, in order to ensure uniformity.
 
     Returns:
         num_decodes: The number of decode requests.
@@ -726,11 +745,20 @@ def split_decodes_and_prefills(
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
 
-    if max_query_len <= decode_threshold:
+    if max_query_len <= decode_threshold and \
+            (not require_uniform or decode_threshold <= 1):
         return num_reqs, 0, num_tokens, 0
 
     query_lens = query_start_loc[1:] - query_start_loc[:-1]
-    is_prefill = query_lens > decode_threshold
+    if query_lens[0].item() > decode_threshold:
+        # first request is not decode, so no decode requests
+        return 0, num_reqs, 0, num_tokens
+
+    if require_uniform:
+        is_prefill = query_lens != query_lens[0]
+    else:
+        is_prefill = query_lens > decode_threshold
+
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
 
@@ -806,6 +834,38 @@ def reorder_batch_to_split_decodes_and_prefills(
     return modified_batch
 
 
+def reshape_query_for_spec_decode(query: torch.Tensor,
+                                  batch_size: int) -> torch.Tensor:
+    """
+    Reshapes the query tensor for the specified batch size, so that
+    it has shape (batch_size, seq_len, num_heads, head_dim).
+    """
+    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
+    total_tokens = query.shape[0]
+    num_heads = query.shape[1]
+    head_dim = query.shape[2]
+    assert total_tokens % batch_size == 0, (
+        f"{total_tokens=} is not divisible by {batch_size=}")
+    seq_len = total_tokens // batch_size
+    return query.view(batch_size, seq_len, num_heads, head_dim)
+
+
+def reshape_attn_output_for_spec_decode(
+        attn_output: torch.Tensor) -> torch.Tensor:
+    """
+    Reshapes the attention output tensor, so that
+    the batch_size and seq_len dimensions are combined.
+    """
+    if attn_output.dim() == 3:
+        # Already in the correct shape
+        return attn_output
+    assert attn_output.dim() == 4, \
+        f"attn_output must be 4D, got {attn_output.dim()}D"
+    total_tokens = attn_output.shape[0] * attn_output.shape[1]
+    return attn_output.view(total_tokens, attn_output.shape[2],
+                            attn_output.shape[3])
+
+
 KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
     ('logits_indices_padded', Optional[torch.Tensor], None),
     ('num_logits_indices', int, 0),
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index a6ca33491235..d5a6c4c1db52 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -3,7 +3,7 @@
 """Attention layer with XFormersAttention."""
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
@@ -197,7 +197,7 @@ def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
 class XFormersAttentionMetadataBuilder(
         AttentionMetadataBuilder[XFormersAttentionMetadata]):
 
-    reorder_batch_threshold: ClassVar[int] = 1
+    reorder_batch_threshold: int = 1
 
     def __init__(
         self,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index a9e0a38fe341..5cae7df70470 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -3,7 +3,7 @@
 import ast
 from dataclasses import replace
 from importlib.util import find_spec
-from typing import Optional, Protocol
+from typing import Optional
 
 import numpy as np
 import torch
@@ -37,17 +37,6 @@
 
 PADDING_SLOT_ID = -1
 
-
-class EagleAttentionMetadata(Protocol):
-    # Required attributes
-    num_actual_tokens: int
-    max_query_len: int
-    query_start_loc: torch.Tensor
-    max_seq_len: int
-    seq_lens: torch.Tensor
-    block_table: torch.Tensor
-    slot_mapping: torch.Tensor
-
-
 class EagleProposer:
 
     def __init__(
@@ -120,7 +109,7 @@ def __init__(
             with_numpy=True)
 
         # Determine allowed attention backends once during initialization.
-        self.allowed_attn_types: tuple[type, ...]
+        self.allowed_attn_types: Optional[tuple] = None
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
             # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@@ -129,9 +118,6 @@ def __init__(
                 AiterFlashAttentionMetadata)
             rocm_types.append(AiterFlashAttentionMetadata)
             self.allowed_attn_types = tuple(rocm_types)
-        else:
-            self.allowed_attn_types = (FlashAttentionMetadata,
-                                       TreeAttentionMetadata)
 
         # Parse the speculative token tree.
         spec_token_tree = self.speculative_config.speculative_token_tree
@@ -266,7 +252,8 @@ def propose(
 
         draft_token_ids = logits.argmax(dim=-1)
 
-        if not isinstance(attn_metadata, self.allowed_attn_types):
+        if self.allowed_attn_types is not None and \
+            not isinstance(attn_metadata, self.allowed_attn_types):
             raise ValueError(
                 f"Unsupported attention metadata type for speculative "
                 "decoding with num_speculative_tokens > 1: "