111 changes: 109 additions & 2 deletions tests/v1/attention/test_attention_splitting.py
@@ -5,11 +5,12 @@
import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
split_attn_metadata,
split_decodes_and_prefills)


@pytest.fixture
@@ -155,3 +156,109 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert results[1].num_reqs == mid_point
assert results[1].num_actual_tokens == mid_point
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))


def apply_split_decodes_and_prefills(query_lens: list[int],
decode_threshold: int,
require_uniform: bool):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
device = torch.device("cpu")
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
common_metadata = create_common_attn_metadata(BatchSpec(
seq_lens=seq_lens, query_lens=query_lens),
block_size=16,
device=device)
return split_decodes_and_prefills(common_metadata,
decode_threshold=decode_threshold,
require_uniform=require_uniform)


def test_split_decodes_and_prefills_nonuniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, False))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
query_lens = [1, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 7
assert num_prefills == 0
assert num_decode_tokens == sum(query_lens)
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_nonuniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)


def test_split_decodes_and_prefills_nonuniform_mixed_batch():
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, False))
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8


def test_split_decodes_and_prefills_uniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, True))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_uniform_all_short_decodes():
query_lens = [2, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 2
assert num_prefills == 5
assert num_decode_tokens == 4
assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)


def test_split_decodes_and_prefills_uniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)


def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
assert num_prefills == 5 # 4, 5, 6, 7, 8 all differ from the uniform decode length (2)
assert num_decode_tokens == 6 # 2 + 2 + 2
assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8


def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 1 # only the first 2 is taken as decode
assert num_prefills == 7 # everything after the first non-uniform query length (the 1) is treated as prefill
assert num_decode_tokens == 2 # only the first 2
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
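
For reference, a minimal standalone sketch of the split these tests exercise, built the same way as the helper above (CPU device, block_size=16); the printed tuples follow the (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) return order:

import torch

from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import split_decodes_and_prefills

query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
metadata = create_common_attn_metadata(
    BatchSpec(seq_lens=seq_lens, query_lens=query_lens),
    block_size=16,
    device=torch.device("cpu"))

# Every query <= 4 counts as a decode:
print(split_decodes_and_prefills(metadata, decode_threshold=4,
                                 require_uniform=False))  # (4, 4, 10, 26)
# Only the leading run matching the first query length (2) stays a decode:
print(split_decodes_and_prefills(metadata, decode_threshold=4,
                                 require_uniform=True))   # (1, 7, 2, 34)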
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flashinfer.py
@@ -240,7 +240,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/gdn_attn.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

@@ -56,7 +56,7 @@ class GDNAttentionMetadataBuilder(

cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -70,7 +70,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc]
self._init_reorder_batch_threshold(1, self.use_spec_decode)

self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
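
For context, a rough sketch (with a hypothetical function name, not the helper's actual body) of what self._init_reorder_batch_threshold(1, self.use_spec_decode) resolves to for this builder, based on the helper added to vllm/v1/attention/backends/utils.py further down:

def effective_reorder_batch_threshold(vllm_config, use_spec_decode: bool) -> int:
    # Sketch: mirrors _init_reorder_batch_threshold(1, use_spec_decode).
    threshold = 1
    spec_config = vllm_config.speculative_config
    if use_spec_decode and spec_config is not None \
            and spec_config.num_speculative_tokens is not None:
        # Spec-as-decode: each decode may carry the bonus token plus the
        # speculative tokens, so widen the reorder threshold accordingly.
        threshold = 1 + spec_config.num_speculative_tokens
    return threshold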
3 changes: 1 addition & 2 deletions vllm/v1/attention/backends/linear_attn.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar

import torch

@@ -35,7 +34,7 @@ class LinearAttentionMetadata:
class LinearAttentionMetadataBuilder(
AttentionMetadataBuilder[LinearAttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mamba_attn.py
@@ -16,7 +16,7 @@


class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/mla/common.py
@@ -190,7 +190,7 @@
import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union

import torch
from tqdm import tqdm
@@ -433,7 +433,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self,
kv_cache_spec: AttentionSpec,
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -67,7 +67,7 @@ class FlashAttnMLAMetadataBuilder(
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: ClassVar[int] = 512
reorder_batch_threshold: int = 512

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -101,7 +101,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],

# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.__class__.reorder_batch_threshold = 1 \
self.reorder_batch_threshold = 1 \
if get_dcp_group().world_size > 1 else self.reorder_batch_threshold

def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
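
A minimal illustration (hypothetical Builder class) of why reorder_batch_threshold is no longer a ClassVar: the old self.__class__ assignment mutated state shared by every builder instance, whereas the new instance assignment only overrides the default for the builder being constructed.

class Builder:
    reorder_batch_threshold: int = 512  # class-level default

a, b = Builder(), Builder()

a.__class__.reorder_batch_threshold = 1  # old pattern: also changes b
print(b.reorder_batch_threshold)         # 1

Builder.reorder_batch_threshold = 512    # restore the default
a.reorder_batch_threshold = 1            # new pattern: only affects a
print(b.reorder_batch_threshold)         # 512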
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/short_conv_attn.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

@@ -41,7 +41,7 @@ class ShortConvAttentionMetadata:
class ShortConvAttentionMetadataBuilder(
AttentionMetadataBuilder[ShortConvAttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
78 changes: 71 additions & 7 deletions vllm/v1/attention/backends/utils.py
@@ -195,7 +195,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[Optional[int]] = None
reorder_batch_threshold: Optional[int] = None
# Does this backend/builder support issuing decodes with (uniform) qlen > 1?
supports_spec_as_decode: bool = False

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@@ -205,6 +207,23 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.vllm_config = vllm_config
self.device = device

def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
supports_spec_as_decode: bool = False) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
self.supports_spec_as_decode = supports_spec_as_decode
if self.reorder_batch_threshold is not None \
and self.supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (speculative_config is not None
and speculative_config.num_speculative_tokens is not None):
self.reorder_batch_threshold = \
1 + speculative_config.num_speculative_tokens

@abstractmethod
def build(self,
common_prefix_len: int,
@@ -662,9 +681,9 @@ def subclass_attention_backend(


def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
require_uniform: bool = False) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
@@ -673,6 +692,9 @@
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity.

Returns:
num_decodes: The number of decode requests.
@@ -685,16 +707,26 @@
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu

if max_query_len <= decode_threshold:
if max_query_len <= decode_threshold and \
(not require_uniform or decode_threshold <= 1):
return num_reqs, 0, num_tokens, 0

query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill = query_lens > decode_threshold
if query_lens[0].item() > decode_threshold:
# first request is not decode, so no decode requests
return 0, num_reqs, 0, num_tokens

if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold

if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0

first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > decode_threshold)
if not require_uniform:
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
@@ -766,6 +798,38 @@ def reorder_batch_to_split_decodes_and_prefills(
return modified_batch


def reshape_query_for_spec_decode(query: torch.Tensor,
Review comment (Collaborator):
Are these not used yet? Should we just include them in the follow-up once they are actually used? Or maybe we should add FlashMLA support in this PR, just so everything is used (and tested, since we can do a FlashMLA + MTP lm_eval run)?

Reply (Collaborator, Author):
I think it is worth committing now and using in subsequent PRs, mostly because it will be used by FlashMLA and also FlashInfer-MLA, and maybe more. Merging here as a helper means that all the downstream PRs can reuse the same code from main instead of duplicating it in each.

But I don't feel particularly strongly about this, and can remove it if you think it's better to add separately.

batch_size: int) -> torch.Tensor:
"""
Reshapes the query tensor for the specified batch size, so that
it has shape (batch_size, seq_len, num_heads, head_dim).
"""
assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
total_tokens = query.shape[0]
num_heads = query.shape[1]
head_dim = query.shape[2]
assert total_tokens % batch_size == 0, (
f"{total_tokens=} is not divisible by {batch_size=}")
seq_len = total_tokens // batch_size
return query.view(batch_size, seq_len, num_heads, head_dim)


def reshape_attn_output_for_spec_decode(
attn_output: torch.Tensor) -> torch.Tensor:
"""
Reshapes the attention output tensor, so that
the batch_size and seq_len dimensions are combined.
"""
if attn_output.dim() == 3:
# Already in the correct shape
return attn_output
assert attn_output.dim() == 4, \
f"attn_output must be 4D, got {attn_output.dim()}D"
total_tokens = attn_output.shape[0] * attn_output.shape[1]
return attn_output.view(total_tokens, attn_output.shape[2],
attn_output.shape[3])


KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),
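
A short usage sketch for the two spec-decode reshape helpers above; the shapes are arbitrary and chosen only for illustration:

import torch

from vllm.v1.attention.backends.utils import (
    reshape_attn_output_for_spec_decode, reshape_query_for_spec_decode)

batch_size, q_len, num_heads, head_dim = 4, 3, 8, 64

# Packed query for 4 uniform decode requests of 3 tokens each.
query = torch.randn(batch_size * q_len, num_heads, head_dim)
q4d = reshape_query_for_spec_decode(query, batch_size)
assert q4d.shape == (batch_size, q_len, num_heads, head_dim)

# Flatten a (batch, seq, heads, head_dim) attention output back to packed form.
attn_out = torch.randn(batch_size, q_len, num_heads, head_dim)
flat = reshape_attn_output_for_spec_decode(attn_out)
assert flat.shape == (batch_size * q_len, num_heads, head_dim)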
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/xformers.py
@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
from typing import TYPE_CHECKING, Optional

import torch

@@ -197,7 +197,7 @@ def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(
self,