diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py
index 2454f85342eb..780a0d6b5c0e 100644
--- a/tests/compile/piecewise/test_full_cudagraph.py
+++ b/tests/compile/piecewise/test_full_cudagraph.py
@@ -46,7 +46,10 @@ class BackendConfig:
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },
@@ -66,6 +69,7 @@ class BackendConfig:
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -89,7 +93,10 @@ class BackendConfig:
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   }),
diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py
index 25e01806f495..1ae9185fafbd 100644
--- a/tests/v1/cudagraph/test_cudagraph_mode.py
+++ b/tests/v1/cudagraph/test_cudagraph_mode.py
@@ -47,7 +47,10 @@ class BackendConfig:
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },
@@ -67,6 +70,7 @@ class BackendConfig:
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -75,7 +79,10 @@ class BackendConfig:
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   }),
diff --git a/vllm/envs.py b/vllm/envs.py
index 19e2f8635275..1f392160aa44 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -118,6 +118,7 @@
     VLLM_SERVER_DEV_MODE: bool = False
     VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
     VLLM_MLA_DISABLE: bool = False
+    VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
@@ -949,6 +950,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_MLA_DISABLE":
     lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
 
+    # Max number of splits pre-allocated by the FlashAttention and
+    # FlashAttention MLA backends for full cuda graph decode
+    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
+    lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
+                          "16")),
+
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.
@@ -1382,6 +1389,7 @@ def compute_hash() -> str:
     environment_variables_to_hash = [
         "VLLM_PP_LAYER_PARTITION",
         "VLLM_MLA_DISABLE",
+        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
         "VLLM_USE_TRITON_FLASH_ATTN",
         "VLLM_USE_TRITON_AWQ",
         "VLLM_DP_RANK",
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 20f1904b3be6..d564cf9988ea 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -8,6 +8,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
@@ -33,9 +34,6 @@
 
 logger = init_logger(__name__)
 
-# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttentionBackend(AttentionBackend):
 
@@ -215,7 +213,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # When using cuda graph, we need to set the upper bound of the
         # number of splits so that large enough intermediate buffers are
         # pre-allocated during capture.
-        self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+        self.max_num_splits = (
+            envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 472095e13615..4ad9a13b61d8 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -6,6 +6,7 @@
 
 import torch
 
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
@@ -24,10 +25,6 @@
 
 logger = init_logger(__name__)
 
-# NOTE(matt): This is an arbitrary number, copied from
-# woosuk's implementation in standard FlashAttention backend
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttnMLABackend(MLACommonBackend):
 
@@ -97,7 +94,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # When using cuda graph, we need to set the upper bound of the
         # number of splits so that large enough intermediate buffers are
         # pre-allocated during capture.
-        self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+        self.max_num_splits = (
+            envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
         # TODO(lucas): Until we add support for the DCP custom masking we need
         # to restrict decodes to q_len == 1 when DCP is enabled.
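
Usage note: the new VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH knob is read from the
process environment when the attention metadata builders above are constructed, so it has
to be set before the engine is created. A minimal sketch in Python, assuming the standard
LLM entry point; the model name and compilation_config values are illustrative assumptions,
not part of this change:

import os

# Raise the upper bound on the number of flash-attention splits whose
# intermediate buffers are pre-allocated during full CUDA graph capture.
# Must be set before vLLM builds its attention backends (default: 16).
os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] = "32"

from vllm import LLM, SamplingParams

# Illustrative model and compilation config (full cudagraph mode, as in
# the tests above); substitute whatever deployment is being exercised.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          compilation_config={"cudagraph_mode": "FULL"})
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))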