109 changes: 108 additions & 1 deletion tests/v1/attention/test_attention_splitting.py
@@ -9,7 +9,8 @@
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
split_attn_metadata,
split_decodes_and_prefills)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices


@@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))


def apply_split_decodes_and_prefills(query_lens: list[int],
decode_threshold: int,
require_uniform: bool):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
device = torch.device("cpu")
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
common_metadata = create_common_attn_metadata(BatchSpec(
seq_lens=seq_lens, query_lens=query_lens),
block_size=16,
device=device)
return split_decodes_and_prefills(common_metadata,
decode_threshold=decode_threshold,
require_uniform=require_uniform)


def test_split_decodes_and_prefills_nonuniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, False))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
query_lens = [1, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 7
assert num_prefills == 0
assert num_decode_tokens == sum(query_lens)
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_nonuniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)


def test_split_decodes_and_prefills_nonuniform_mixed_batch():
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, False))
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8


def test_split_decodes_and_prefills_uniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, True))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_uniform_all_short_decodes():
query_lens = [2, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 2
assert num_prefills == 5
assert num_decode_tokens == 4
assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)


def test_split_decodes_and_prefills_uniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)


def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 6 # 2 + 2 + 2
assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8


def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 1 # only the first 2 is taken as decode
assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
assert num_decode_tokens == 2 # only the first 2
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens


@pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[
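As a rough reference for the uniform cases above, the following standalone sketch approximates the require_uniform=True behavior; it is illustrative only, and the actual logic lives in vllm.v1.attention.backends.utils.split_decodes_and_prefills.

# Rough sketch of the require_uniform=True split exercised by the tests above;
# illustrative only, not the actual vLLM implementation.
def uniform_split_sketch(query_lens: list[int], decode_threshold: int):
    num_decodes = 0
    if query_lens and query_lens[0] <= decode_threshold:
        first = query_lens[0]
        # Leading requests stay in the decode batch only while their query length
        # matches the first request's, keeping the decode batch uniform.
        for q in query_lens:
            if q == first:
                num_decodes += 1
            else:
                break
    num_decode_tokens = sum(query_lens[:num_decodes])
    return (num_decodes, len(query_lens) - num_decodes,
            num_decode_tokens, sum(query_lens) - num_decode_tokens)

# Mirrors test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes.
assert uniform_split_sketch([2, 1, 2, 4, 5, 6, 7, 8], 4) == (1, 7, 2, 33)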
15 changes: 15 additions & 0 deletions vllm/utils/flashinfer.py
@@ -181,14 +181,22 @@ def force_use_trtllm_attention() -> Optional[bool]:
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention."""
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)


def use_trtllm_attention(
num_qo_heads: int,
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False,
has_spec: bool = False,
) -> bool:
"""Return ``True`` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()
@@ -214,6 +222,12 @@ def use_trtllm_attention(
)
return False

if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes
logger.info_once(
"Using TRTLLM attention (enabled for speculative decoding).")
return True

# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
if has_sinks:
@@ -391,6 +405,7 @@ def flashinfer_disable_q_quantization() -> bool:
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
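For context, a rough usage sketch of how a caller might combine the two helpers above; the argument values are illustrative placeholders, and the real call sites are in the flashinfer backend changes below.

import torch

from vllm.utils.flashinfer import (can_use_trtllm_attention,
                                   use_trtllm_attention)

num_qo_heads, num_kv_heads = 32, 8

# Static capability check: TRTLLM attention also requires that the number of
# query heads is divisible by the number of KV heads.
supports_spec_as_decode = can_use_trtllm_attention(num_qo_heads, num_kv_heads)

# Per-batch decision: with has_spec=True and is_prefill=False, the new early
# return selects TRTLLM attention for the speculative-decode pathway.
decode_use_trtllm = use_trtllm_attention(
    num_qo_heads,
    num_kv_heads,
    num_tokens=256,         # illustrative decode-token count
    max_seq_len=4096,       # illustrative
    kv_cache_dtype="auto",  # illustrative
    q_dtype=torch.bfloat16,
    is_prefill=False,
    has_sinks=False,
    has_spec=supports_spec_as_decode,
)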
69 changes: 56 additions & 13 deletions vllm/v1/attention/backends/flashinfer.py
@@ -25,7 +25,8 @@
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
from vllm.utils.flashinfer import (can_use_trtllm_attention,
flashinfer_disable_q_quantization,
supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -48,6 +49,16 @@

logger = init_logger(__name__)

trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
global trtllm_gen_workspace_buffer
if trtllm_gen_workspace_buffer is None:
trtllm_gen_workspace_buffer = torch.zeros(
FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda')
return trtllm_gen_workspace_buffer


@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
@@ -213,6 +224,7 @@ class FlashInferMetadata:

# For flashinfer trtllm batch decode
max_q_len: int
max_q_len_prefill: int
max_seq_len: int
seq_lens: torch.Tensor
block_table_tensor: torch.Tensor
@@ -240,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -292,6 +304,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
else:
self.q_data_type = self.model_config.dtype

supports_spec_as_decode = \
can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
self._init_reorder_batch_threshold(1, supports_spec_as_decode)

self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers
@@ -406,7 +422,8 @@ def build(self,
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
decode_threshold=self.reorder_batch_threshold,
require_uniform=True)

page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
@@ -481,20 +498,25 @@ def build(self,
paged_kv_last_page_len_np,
)

uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_prefill_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
has_sinks=self.has_sinks)
is_prefill=True,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
has_sinks=self.has_sinks)
is_prefill=False,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
@@ -511,6 +533,7 @@ def build(self,
q_data_type=self.q_data_type,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_q_len_prefill=max_q_len,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
@@ -567,6 +590,15 @@ def build(self,
qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
prefill_start]
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]

# Recompute max_q_len for the slice of requests we are using
# for prefills. This can be different from max_q_len when
# we have a non-uniform batch with some short decodes offloaded
# to the prefill pathway
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
attn_metadata.max_q_len_prefill = \
int(query_lens_prefill.max().item())

if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
@@ -597,7 +629,7 @@ def build(self,
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decodes))
self.vllm_config.pad_for_cudagraph(num_decode_tokens))
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
@@ -611,7 +643,7 @@ def build(self,
num_decodes:num_input_tokens].fill_(1)

else:
num_input_tokens = num_decodes
num_input_tokens = num_decode_tokens

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph)
@@ -832,6 +864,9 @@ def forward(
output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
return output

# When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens

@@ -862,10 +897,10 @@ def forward(
else:
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
workspace_buffer = prefill_wrapper._float_workspace_buffer
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[
num_decode_tokens:]
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
num_decodes:]
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
Comment on lines +902 to +903

Contributor:

Probably naive q: Can there be cases in normal decode where num_decodes < num_decode_tokens?

Collaborator (author):

Usually, reorder_batch_threshold == 1, so num_decodes == num_decode_tokens.

However, we're using a padded-batch speculative decoding implementation where we can use the trtllm-gen batch_decode kernel for a batch of requests as long as they all have the same q_len, which can be larger than 1.

So we need to fix a bunch of cases like this one, where we can have max_q_len * num_decodes tokens in the decode pathway.
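As a concrete illustration of the distinction (numbers invented for this sketch, not taken from the PR):

import torch

# Under padded speculative decoding, every decode request carries the same q_len
# (one target token plus k draft tokens), so request count and token count diverge.
num_decodes = 4                                   # requests in the decode pathway
q_len_per_req = 3                                 # e.g. 1 target + 2 draft tokens
num_decode_tokens = num_decodes * q_len_per_req   # 12 tokens

block_table_tensor = torch.zeros(6, 8, dtype=torch.int32)  # one row per request
query = torch.zeros(20, 32, 128)                           # one row per token

# Per-request tensors (block tables, seq_lens) are sliced with num_decodes, while
# per-token tensors (query, output) are sliced with num_decode_tokens, which is
# exactly the distinction this hunk fixes.
block_tables_prefill = block_table_tensor[num_decodes:]
prefill_query = query[num_decode_tokens:]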


# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
@@ -909,7 +944,7 @@ def forward(
workspace_buffer=workspace_buffer,
block_tables=mock_block_table,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_q_len=attn_metadata.max_q_len_prefill,
max_kv_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
@@ -943,7 +978,7 @@ def forward(
else:
# decode_query may be non-contiguous
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
@@ -966,6 +1001,14 @@ def forward(
assert self.o_sf_scale is None
out = output[:num_decode_tokens]

if num_decode_tokens % attn_metadata.num_decodes != 0:
# This gets triggered when the dummy_run forces
# attention to be initialized with q_len = 0
q_len_per_req = 1
else:
q_len_per_req = \
num_decode_tokens // attn_metadata.num_decodes

trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
Expand All @@ -979,7 +1022,7 @@ def forward(
sinks=self.sinks,
o_sf_scale=self.o_sf_scale,
out=out,
)
q_len_per_req=q_len_per_req)
return output_padded


6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/gdn_attn.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

@@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder(

cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -76,7 +76,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc]
self._init_reorder_batch_threshold(1, self.use_spec_decode)

self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()