3 changes: 2 additions & 1 deletion docs/design/cuda_graphs.md
@@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
| FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good |
| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
| AITER FlashAttention | `UNIFORM_BATCH`| |
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
| FlashMLA | `UNIFORM_BATCH` | |
| FlashInferMLA | `UNIFORM_BATCH` | |
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
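To make the support levels in this table concrete, here is a rough, self-contained sketch of how each level constrains which batches can be captured as a full CUDA graph. The enum and helper below are simplified stand-ins rather than vLLM's actual classes, and the per-level conditions are my reading of the level names and of this table.

```python
import enum


class AttentionCGSupport(enum.Enum):
    """Simplified stand-in for vLLM's enum; higher value = more permissive."""

    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


def batch_capturable(support: AttentionCGSupport, query_lens: list[int]) -> bool:
    """Rough illustration: can this batch run under a full CUDA graph?"""
    if support is AttentionCGSupport.ALWAYS:
        return True  # unified routine handles prefill, mixed, and decode batches
    if support is AttentionCGSupport.UNIFORM_BATCH:
        # all requests must share the same query length (e.g. spec decode)
        return len(set(query_lens)) <= 1
    if support is AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE:
        return all(q == 1 for q in query_lens)  # pure single-token decode only
    return False  # NEVER


print(batch_capturable(AttentionCGSupport.UNIFORM_BATCH, [4, 4, 4]))             # True
print(batch_capturable(AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE, [4, 4]))  # False
```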
2 changes: 1 addition & 1 deletion vllm/attention/layers/chunked_local_attention.py
@@ -32,7 +32,7 @@ def create_chunked_local_attention_backend(
underlying_builder = underlying_attn_backend.get_builder_cls()

class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

def build(
self,
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -207,7 +207,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support = (
_cudagraph_support = (
AttentionCGSupport.ALWAYS
if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH
38 changes: 23 additions & 15 deletions vllm/v1/attention/backends/flashinfer.py
@@ -15,6 +15,7 @@
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override

from vllm.attention.backends.abstract import (
AttentionBackend,
@@ -274,10 +275,6 @@ class FlashInferMetadata:


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)

reorder_batch_threshold: int = 1

def __init__(
@@ -355,6 +352,9 @@ def __init__(
else:
self.q_data_type = self.model_config.dtype

# Prefer TRTLLM attention for decoding in all cases.
# This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
self.use_trtllm_decode_attention = can_use_trtllm
self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

self._cascade_wrapper = None # Wrapper for cascade attention
@@ -412,6 +412,24 @@ def __init__(
"passing --block-size 32 or --block-size 64."
)

@classmethod
@override
def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
has_trtllm_support = can_use_trtllm_attention(
num_qo_heads=vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
),
num_kv_heads=kv_cache_spec.num_kv_heads,
)
if has_trtllm_support:
return AttentionCGSupport.UNIFORM_BATCH
else:
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
@@ -573,17 +591,7 @@ def build(
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
decode_use_trtllm = use_trtllm_attention(
self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
is_prefill=False,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
decode_use_trtllm = self.use_trtllm_decode_attention

if not (prefill_use_trtllm and decode_use_trtllm):
if self.has_sinks:
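One way to read the FlashInfer change above: the TRTLLM decode decision is now made once in `__init__` (cached as `use_trtllm_decode_attention`) rather than recomputed per batch in `build()`, so every captured decode graph replays the same kernel path, which is what allows reporting `UNIFORM_BATCH` support. A minimal stand-in sketch of that decide-once pattern (the class and method names below are illustrative, not vLLM's):

```python
class ExampleDecodeBuilder:
    """Stand-in builder: the decode kernel choice is fixed at construction time."""

    def __init__(self, can_use_trtllm: bool) -> None:
        # Decided once; every captured decode graph replays the same kernel path.
        self.use_trtllm_decode_attention = can_use_trtllm

    def build_decode(self, num_decode_tokens: int) -> str:
        # No per-batch heuristic here: re-deciding per batch could pick a
        # different kernel than the one baked into the captured graph.
        return "trtllm" if self.use_trtllm_decode_attention else "flashinfer"


builder = ExampleDecodeBuilder(can_use_trtllm=True)
print(builder.build_decode(num_decode_tokens=8))  # "trtllm"
```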
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/gdn_attn.py
@@ -59,7 +59,7 @@ class GDNAttentionMetadata:


class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: int = 1

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mamba_attn.py
@@ -20,7 +20,7 @@

class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -29,7 +29,7 @@

class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -92,7 +92,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):


class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
reorder_batch_threshold: int = 512 # process small prefills with decode pathway

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -29,7 +29,7 @@


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM


2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -96,7 +96,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
# ^ TODO(matt): tune this
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -248,7 +248,7 @@ def triton_convert_req_index_to_global_index(

@dataclass
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

def __init__(
self,
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/indexer.py
@@ -206,7 +206,7 @@ def split_prefill_chunks(


class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -55,7 +55,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support: ClassVar[AttentionCGSupport] = (
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -251,7 +251,7 @@ class AiterFlashAttentionMetadata:
class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
):
cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: int = 1

def __init__(
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/rocm_attn.py
@@ -63,7 +63,7 @@ class RocmAttentionMetadata:


class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

def __init__(
self,
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/triton_attn.py
@@ -67,7 +67,7 @@ class TritonAttentionMetadata:


class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

def __init__(
self,
12 changes: 11 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -244,7 +244,8 @@ class AttentionCGSupport(enum.Enum):

class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention (default: no).
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
# Do not access directly. Call get_cudagraph_support() instead.
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
@@ -263,6 +264,15 @@ def __init__(
self.vllm_config = vllm_config
self.device = device

@classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
"""Get the cudagraph support level of this builder class."""
return cls._cudagraph_support

def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
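A self-contained sketch of the pattern this base-class change enables, using stand-in types: the static default lives in the private `_cudagraph_support` class attribute, callers go through `get_cudagraph_support()`, and a backend may override that classmethod to make the answer depend on the configuration, as the FlashInfer builder does above. The `config` dict here is a placeholder for the real `VllmConfig`/`AttentionSpec` arguments.

```python
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):  # stand-in for vLLM's enum
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


class BaseBuilder:
    # Do not read directly; call get_cudagraph_support() instead.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    @classmethod
    def get_cudagraph_support(cls, config: dict) -> AttentionCGSupport:
        # Default behavior: report the static class-level value.
        return cls._cudagraph_support


class ConfigDependentBuilder(BaseBuilder):
    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    @classmethod
    def get_cudagraph_support(cls, config: dict) -> AttentionCGSupport:
        # Hypothetical condition, mirroring the FlashInfer/TRTLLM check above.
        if config.get("trtllm_attention", False):
            return AttentionCGSupport.UNIFORM_BATCH
        return cls._cudagraph_support


print(ConfigDependentBuilder.get_cudagraph_support({"trtllm_attention": True}))
# AttentionCGSupport.UNIFORM_BATCH
```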
31 changes: 21 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
@@ -4136,14 +4136,16 @@ def create_attn_groups(
return attn_groups

attention_backend_maps = []
attention_backend_set: set[type[AttentionBackend]] = set()
attention_backend_list = []
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
attention_backend_maps.append(attn_backends[0])
attention_backend_set.update(attn_backends[1])
attention_backend_list.append(attn_backends[1])

# Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set)
self._check_and_update_cudagraph_mode(
attention_backend_list, kv_cache_config.kv_cache_groups
)

for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
@@ -4172,22 +4174,31 @@ def initialize_metadata_builders(
self.calculate_reorder_batch_threshold()

def _check_and_update_cudagraph_mode(
self, attention_backends: set[type[AttentionBackend]]
self,
attention_backends: list[set[type[AttentionBackend]]],
kv_cache_groups: list[KVCacheGroupSpec],
) -> None:
"""
Resolve the cudagraph_mode when there are multiple attention
backends with potential conflicting CUDA graph support.
groups with potential conflicting CUDA graph support.
Then initialize the cudagraph_dispatcher based on the resolved
cudagraph_mode.
"""
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_backend_name = None

for attn_backend in attention_backends:
builder_cls = attn_backend.get_builder_cls()
if builder_cls.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder_cls.cudagraph_support
min_cg_backend_name = attn_backend.__name__
for attn_backend_set, kv_cache_group in zip(
attention_backends, kv_cache_groups
):
for attn_backend in attn_backend_set:
builder_cls = attn_backend.get_builder_cls()

cg_support = builder_cls.get_cudagraph_support(
self.vllm_config, kv_cache_group.kv_cache_spec
)
if cg_support.value < min_cg_support.value:
min_cg_support = cg_support
min_cg_backend_name = attn_backend.__name__
# Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
# check cudagraph for mixed batch is supported
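A standalone sketch of the resolution step that `_check_and_update_cudagraph_mode` now performs: walk every KV-cache group's backend set, query each builder's per-config support level, and keep the least permissive one. The types below are simplified stand-ins; the real code also records the limiting backend's name for error messages and then reconciles the result with `compilation_config.cudagraph_mode`.

```python
import enum


class AttentionCGSupport(enum.Enum):  # stand-in for vLLM's enum
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


def resolve_min_support(
    support_per_group: list[list[AttentionCGSupport]],
) -> AttentionCGSupport:
    """Least permissive support level across all backends in all KV-cache groups."""
    min_support = AttentionCGSupport.ALWAYS
    for group in support_per_group:
        for support in group:
            if support.value < min_support.value:
                min_support = support
    return min_support


# Two groups, e.g. full attention (ALWAYS) and a Mamba-style backend
# (UNIFORM_SINGLE_TOKEN_DECODE): the weakest level wins.
print(
    resolve_min_support(
        [
            [AttentionCGSupport.ALWAYS],
            [AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE],
        ]
    )
)  # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
```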