Merged
tests/compile/piecewise/test_full_cudagraph.py: 9 additions & 2 deletions
@@ -46,7 +46,10 @@ class BackendConfig:
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
@@ -66,6 +69,7 @@ class BackendConfig:
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
@@ -89,7 +93,10 @@ class BackendConfig:
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
}),
tests/v1/cudagraph/test_cudagraph_mode.py: 9 additions & 2 deletions
@@ -47,7 +47,10 @@ class BackendConfig:
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
@@ -67,6 +70,7 @@ class BackendConfig:
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
@@ -75,7 +79,10 @@ class BackendConfig:
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
vllm/envs.py: 8 additions & 0 deletions
@@ -118,6 +118,7 @@
VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False
+VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
Review comment (Contributor, severity: high):

For consistency with other optional integer environment variables like VLLM_FLASH_ATTN_VERSION, it's better to define this as Optional[int] and handle the default value in the consumer module (flashattn_mla.py). This makes the intent clearer that the variable is optional and has a fallback. This change is related to a corresponding suggested change to the lambda for this environment variable.

Suggested change:
-VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
+VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: Optional[int] = None

VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
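To make the reviewer's Optional[int] suggestion above concrete, here is a minimal, self-contained sketch (not the merged code): the helper name is hypothetical, and the default of 16 mirrors the constant removed from the backends.

```python
import os
from typing import Optional


def _max_num_splits_from_env() -> Optional[int]:
    """Stand-in for the suggested Optional[int] lambda in vllm/envs.py."""
    raw = os.environ.get("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH")
    return int(raw) if raw is not None else None


# Consumer-side fallback, e.g. what flashattn_mla.py would do under this
# suggestion: keep the historical default and only override it when set.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

max_num_splits = (_max_num_splits_from_env()
                  or _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
print(max_num_splits)  # 16 unless the environment variable overrides it
```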
@@ -949,6 +950,12 @@ def get_vllm_port() -> Optional[int]:
"VLLM_MLA_DISABLE":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),

+# If set, vLLM uses this value as the maximum number of flash-attention
+# splits for CUDA graph decode (FlashAttention and FlashAttention MLA).
+"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
+lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
+"16")),

# Number of GPUs per worker in Ray, if it is set to be a fraction,
# it allows ray to schedule multiple actors on a single GPU,
# so that users can colocate other actors on the same GPUs as vLLM.
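Usage note: the override must be in the process environment before the attention backend reads it. A minimal sketch, assuming vLLM is installed and nothing has read the variable earlier in the process; the value 32 is only an example.

```python
import os

# Set the override before any vLLM component reads it.
os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] = "32"

from vllm import envs

# vllm.envs resolves names through the lambda table above, so both
# flash_attn.py and flashattn_mla.py will now pre-allocate buffers for
# up to 32 splits when capturing CUDA graphs.
print(envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)  # 32
```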
@@ -1382,6 +1389,7 @@ def compute_hash() -> str:
environment_variables_to_hash = [
"VLLM_PP_LAYER_PARTITION",
"VLLM_MLA_DISABLE",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK",
vllm/v1/attention/backends/flash_attn.py: 3 additions & 4 deletions
@@ -8,6 +8,7 @@
import torch

from vllm import _custom_ops as ops
+from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
@@ -33,9 +34,6 @@

logger = init_logger(__name__)

-# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-

class FlashAttentionBackend(AttentionBackend):

@@ -215,7 +213,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
-self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+self.max_num_splits = (
+envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
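For context on the comment about pre-allocating intermediate buffers: split-K decode writes per-split partial outputs and log-sum-exp values that are reduced afterwards, and buffers captured in a CUDA graph cannot grow later, so the builder must size them for the largest split count the kernel may ever choose. The snippet below is a toy illustration only; the shapes and names are assumptions, not vLLM's actual buffers.

```python
import torch


def preallocate_split_buffers(max_num_splits: int, max_num_tokens: int,
                              num_heads: int, head_dim: int,
                              device: str = "cpu"):
    """Worst-case buffers for split-K attention partial results.

    A kernel running inside a captured CUDA graph may pick any number of
    splits up to max_num_splits, so the buffers are sized for that maximum.
    """
    partial_out = torch.empty(max_num_splits, max_num_tokens, num_heads,
                              head_dim, device=device, dtype=torch.float32)
    partial_lse = torch.empty(max_num_splits, max_num_tokens, num_heads,
                              device=device, dtype=torch.float32)
    return partial_out, partial_lse


# A larger VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH therefore trades
# extra memory for more decode parallelism inside the captured graph.
out, lse = preallocate_split_buffers(max_num_splits=16, max_num_tokens=1024,
                                     num_heads=8, head_dim=128)
print(out.shape, lse.shape)
```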
vllm/v1/attention/backends/mla/flashattn_mla.py: 3 additions & 5 deletions
@@ -6,6 +6,7 @@

import torch

+from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
@@ -24,10 +25,6 @@

logger = init_logger(__name__)

-# NOTE(matt): This is an arbitrary number, copied from
-# woosuk's implementation in standard FlashAttention backend
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-

class FlashAttnMLABackend(MLACommonBackend):

@@ -97,7 +94,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
-self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+self.max_num_splits = (
+envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.