diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py
index bf4d2179af5f..7fb08e5780f5 100644
--- a/tests/kernels/attention/test_triton_unified_attention.py
+++ b/tests/kernels/attention/test_triton_unified_attention.py
@@ -7,6 +7,7 @@
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.platforms import current_platform
+from vllm.utils.math_utils import next_power_of_2
 
 NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
@@ -22,6 +23,10 @@
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
 
+# 0: always use the 2D kernel for decode
+# 8: use the 3D kernel for decode batches of up to 8 sequences
+SEQ_THRESHOLD_3D_VALUES = [0, 8]
+
 
 def ref_paged_attn(
     query: torch.Tensor,
@@ -92,6 +97,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("soft_cap", [None, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("q_dtype", QDTYPES)
+@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
 @torch.inference_mode()
 def test_triton_unified_attn(
     seq_lens: list[tuple[int, int]],
@@ -103,6 +109,7 @@ def test_triton_unified_attn(
     soft_cap: float | None,
     num_blocks: int,
     q_dtype: torch.dtype | None,
+    seq_threshold_3D: int,
 ) -> None:
     torch.set_default_device("cuda")
@@ -152,6 +159,21 @@ def test_triton_unified_attn(
     k_descale = torch.rand(scale_shape, dtype=torch.float32)
     v_descale = torch.rand(scale_shape, dtype=torch.float32)
 
+    num_par_softmax_segments = 16
+    head_size_padded = next_power_of_2(head_size)
+    softmax_segm_output = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
+        dtype=torch.float32,
+    )
+    softmax_segm_max = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
+        dtype=torch.float32,
+    )
+    softmax_segm_expsum = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
+        dtype=torch.float32,
+    )
+
     unified_attention(
         q=maybe_quantized_query,
         k=maybe_quantized_key_cache,
@@ -169,6 +191,11 @@ def test_triton_unified_attn(
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
+        seq_threshold_3D=seq_threshold_3D,
+        num_par_softmax_segments=num_par_softmax_segments,
+        softmax_segm_output=softmax_segm_output,
+        softmax_segm_max=softmax_segm_max,
+        softmax_segm_expsum=softmax_segm_expsum,
     )
 
     ref_output = ref_paged_attn(
diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 565be1c39bec..8a28b2e4bba4 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
 @triton.jit
 def kernel_unified_attention_3d(
     segm_output_ptr,
-    # [num_tokens, num_query_heads, num_segments, head_size]
+    # [num_tokens, num_query_heads, num_segments, head_size_padded]
     segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
     segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
     query_ptr,  # [num_tokens, num_query_heads, head_size]
@@ -749,6 +749,11 @@ def unified_attention(
     q_descale,
     k_descale,
     v_descale,
+    seq_threshold_3D,
+    num_par_softmax_segments,
+    softmax_segm_output,
+    softmax_segm_max,
+    softmax_segm_expsum,
     alibi_slopes=None,
     output_scale=None,
     qq_bias=None,
@@ -793,8 +798,8 @@ def unified_attention(
     TILE_SIZE_PREFILL = 32
     TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
 
-    # if batch contains a prefill
-    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
+    # if the batch contains a prefill or the number of sequences exceeds the threshold
+    if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
         kernel_unified_attention_2d[
             (
                 total_num_q_blocks,
                 num_kv_heads,
             )
         ](
@@ -847,37 +852,10 @@ def unified_attention(
             USE_FP8=output_scale is not None,
         )
     else:
-        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
-        # value that showed good performance in tests
-        NUM_SEGMENTS = 16
-
-        segm_output = torch.empty(
-            q.shape[0],
-            num_query_heads,
-            NUM_SEGMENTS,
-            triton.next_power_of_2(head_size),
-            dtype=torch.float32,
-            device=q.device,
-        )
-        segm_max = torch.empty(
-            q.shape[0],
-            num_query_heads,
-            NUM_SEGMENTS,
-            dtype=torch.float32,
-            device=q.device,
-        )
-        segm_expsum = torch.empty(
-            q.shape[0],
-            num_query_heads,
-            NUM_SEGMENTS,
-            dtype=torch.float32,
-            device=q.device,
-        )
-
-        kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
-            segm_output_ptr=segm_output,
-            segm_max_ptr=segm_max,
-            segm_expsum_ptr=segm_expsum,
+        kernel_unified_attention_3d[(num_seqs, num_kv_heads, num_par_softmax_segments)](
+            segm_output_ptr=softmax_segm_output,
+            segm_max_ptr=softmax_segm_max,
+            segm_expsum_ptr=softmax_segm_expsum,
             query_ptr=q,
             key_cache_ptr=k,
             value_cache_ptr=v,
@@ -917,13 +895,13 @@ def unified_attention(
             BLOCK_Q=BLOCK_Q,
             num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
         )
-        reduce_segments[(q.shape[0], num_query_heads)](
+        reduce_segments[(num_seqs, num_query_heads)](
             output_ptr=out,
-            segm_output_ptr=segm_output,
-            segm_max_ptr=segm_max,
-            segm_expsum_ptr=segm_expsum,
+            segm_output_ptr=softmax_segm_output,
+            segm_max_ptr=softmax_segm_max,
+            segm_expsum_ptr=softmax_segm_expsum,
             seq_lens_ptr=seqused_k,
             num_seqs=num_seqs,
             num_query_heads=num_query_heads,
@@ -936,6 +914,6 @@ def unified_attention(
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             query_start_len_ptr=cu_seqlens_q,
             BLOCK_Q=BLOCK_Q,
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
             USE_FP8=output_scale is not None,
         )
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 09c36043c8c8..24d148c674fc 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -17,7 +17,7 @@
     triton_reshape_and_cache_flash,
 )
 from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -26,6 +26,7 @@
 )
 from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
+from vllm.utils.math_utils import next_power_of_2
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
@@ -36,6 +37,11 @@
 
 logger = init_logger(__name__)
 
+# constants
+MIN_LAUNCH_GRID_SIZE_2D = 128  # Minimum launch grid size for the 2D kernel
+NUM_PAR_SOFTMAX_SEGMENTS = 16  # Number of parallel tiled softmax segments
+
+
 @dataclass
 class TritonAttentionMetadata:
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
 
+    seq_threshold_3D: int
+    num_par_softmax_segments: int
+    softmax_segm_output: torch.Tensor
+    softmax_segm_max: torch.Tensor
+    softmax_segm_expsum: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
 
@@ -87,6 +99,69 @@ def __init__(
         self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()
 
+        # Check if CUDA Graphs are enabled for decode
+        self.decode_cudagraph_enabled = (
+            self.vllm_config.compilation_config.cudagraph_mode
+            in (
+                CUDAGraphMode.FULL_AND_PIECEWISE,
+                CUDAGraphMode.FULL_DECODE_ONLY,
+                CUDAGraphMode.FULL,
+            )
+        )
+
+        # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
+        # A lower bound for num_q_blocks is the number of sequences.
+        # To ensure the minimum launch grid size is achieved, the number of sequences
+        # must be at least equal to the threshold below.
+        # If this threshold is not reached (i.e., the batch size is not large enough),
+        # the 3D kernel will be selected instead.
+        self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
+
+        # Modify the threshold if needed.
+        if self.decode_cudagraph_enabled:
+            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+            assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
+
+            # Select the CUDA Graph capture size closest to self.seq_threshold_3D
+            # as threshold. This ensures that each captured graph covers the
+            # correct execution path.
+            upd_seq_threshold_3D = min(
+                capture_sizes,
+                key=lambda x: abs(x - self.seq_threshold_3D),
+            )
+
+            # If the updated threshold becomes significantly larger than the
+            # initial value, it is reset to zero. This enforces the use of the
+            # 2D kernel only and ensures that the size of the allocated
+            # intermediate structures remains bounded.
+            if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
+                self.seq_threshold_3D = upd_seq_threshold_3D
+            else:
+                self.seq_threshold_3D = 0
+
+        self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
+        headdim_padded = next_power_of_2(self.headdim)
+        self.softmax_segm_output = torch.empty(
+            (
+                self.seq_threshold_3D,
+                self.num_heads_q,
+                self.num_par_softmax_segments,
+                headdim_padded,
+            ),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_max = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_expsum = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> TritonAttentionMetadata:
@@ -143,6 +218,11 @@ def build(
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            seq_threshold_3D=self.seq_threshold_3D,
+            num_par_softmax_segments=self.num_par_softmax_segments,
+            softmax_segm_output=self.softmax_segm_output,
+            softmax_segm_max=self.softmax_segm_max,
+            softmax_segm_expsum=self.softmax_segm_expsum,
         )
         return attn_metadata
 
@@ -347,6 +427,12 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
+        seq_threshold_3D = attn_metadata.seq_threshold_3D
+        num_par_softmax_segments = attn_metadata.num_par_softmax_segments
+        softmax_segm_output = attn_metadata.softmax_segm_output
+        softmax_segm_max = attn_metadata.softmax_segm_max
+        softmax_segm_expsum = attn_metadata.softmax_segm_expsum
+
         descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
 
         unified_attention(
@@ -367,6 +453,11 @@ def forward(
             q_descale=None,  # Not supported
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
+            seq_threshold_3D=seq_threshold_3D,
+            num_par_softmax_segments=num_par_softmax_segments,
+            softmax_segm_output=softmax_segm_output,
+            softmax_segm_max=softmax_segm_max,
+            softmax_segm_expsum=softmax_segm_expsum,
             sinks=self.sinks,
             output_scale=output_scale,
         )
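
As a reading aid (not part of the patch), the kernel-selection rule introduced above can be summarized in a few lines of plain Python. The sketch below reuses the MIN_LAUNCH_GRID_SIZE_2D constant and the 4x capture-size guard from triton_attn.py; the helper names are illustrative only and are not vLLM APIs.

MIN_LAUNCH_GRID_SIZE_2D = 128  # same constant as in triton_attn.py above


def pick_seq_threshold_3D(num_heads_kv: int, capture_sizes: list[int] | None) -> int:
    # The 2D kernel launches a (num_q_blocks, num_heads_kv) grid, and num_seqs is a
    # lower bound for num_q_blocks, so small decode batches underfill that grid.
    threshold = MIN_LAUNCH_GRID_SIZE_2D // num_heads_kv
    if capture_sizes:
        # With full CUDA Graphs for decode, snap the threshold to the closest
        # capture size so every captured batch size replays a single kernel path.
        snapped = min(capture_sizes, key=lambda s: abs(s - threshold))
        # If snapping would inflate the threshold (and the preallocated
        # softmax_segm_* workspaces) too much, fall back to the 2D kernel only.
        threshold = snapped if snapped <= 4 * threshold else 0
    return threshold


def uses_2d_kernel(max_seqlen_q: int, num_seqs: int, seq_threshold_3D: int) -> bool:
    # Same dispatch condition as in unified_attention() above.
    return max_seqlen_q > 1 or num_seqs > seq_threshold_3D


# Example: 2 KV heads give a threshold of 128 // 2 = 64, so a pure-decode batch
# of 16 sequences takes the 3D (segmented-softmax) path.
assert pick_seq_threshold_3D(num_heads_kv=2, capture_sizes=None) == 64
assert not uses_2d_kernel(max_seqlen_q=1, num_seqs=16, seq_threshold_3D=64)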