diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index b5fc2edea130..8f1718e493b1 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import patch
+from collections.abc import Callable, Iterable
+from typing import Any
 
 import numpy as np
 import torch
@@ -32,6 +33,7 @@ def __init__(
 
         self.max_model_len = vllm_config.model_config.max_model_len
         self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
@@ -40,102 +42,60 @@ def __init__(
             self.cudagraph_mode = CUDAGraphMode.NONE
         else:
             self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        if self.compilation_config.cudagraph_capture_sizes is not None:
-            cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
-            # Limit the cudagraph sizes to the max decode batch size.
-            self.cudagraph_sizes = [
-                x for x in cudagraph_sizes if x <= self.max_num_reqs
-            ]
-        else:
-            self.cudagraph_sizes = []
-        self.padded_sizes = self._init_padded_sizes()
+        self.cudagraph_sizes = get_cudagraph_sizes(
+            self.compilation_config.cudagraph_capture_sizes,
+            self.max_num_reqs,
+            self.max_num_tokens,
+            self.cudagraph_mode,
+        )
 
         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
         self.pool = torch.cuda.graph_pool_handle()
         self.hidden_states: torch.Tensor | None = None
 
-    def _init_padded_sizes(self) -> dict[int, int]:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            # Full cuda graphs are not used.
-            return {}
-        if not self.cudagraph_sizes:
-            return {}
-
-        padded_sizes: dict[int, int] = {}
-        for i in range(1, self.cudagraph_sizes[-1] + 1):
-            for x in self.cudagraph_sizes:
-                if i <= x:
-                    padded_sizes[i] = x
-                    break
-        return padded_sizes
-
     def needs_capture(self) -> bool:
-        return len(self.padded_sizes) > 0
+        return len(self.cudagraph_sizes) > 0
 
     def get_cudagraph_size(
         self,
        scheduler_output: SchedulerOutput,
        num_tokens_after_padding: int,
    ) -> int | None:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            return None
-        if self.cudagraph_mode != CUDAGraphMode.FULL:
-            # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
-            all_decode = all(
-                x == 1 for x in scheduler_output.num_scheduled_tokens.values()
-            )
-            if not all_decode:
-                # Prefill is included.
-                return None
-        return self.padded_sizes.get(num_tokens_after_padding)
+        return get_cudagraph_size(
+            num_tokens_after_padding,
+            scheduler_output.num_scheduled_tokens.values(),
+            self.cudagraph_sizes,
+            self.cudagraph_mode,
+        )
 
     def capture_graph(
         self,
-        batch_size: int,
+        num_tokens: int,
         model: nn.Module,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert batch_size not in self.graphs
-
-        # Prepare dummy inputs.
-        input_ids = input_buffers.input_ids.gpu[:batch_size]
-        positions = input_buffers.positions[:batch_size]
-
-        input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
-        input_buffers.query_start_loc.np[batch_size:] = batch_size
-        input_buffers.query_start_loc.copy_to_gpu()
-        # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
-        # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
-        # seq_lens_np (CPU), which might cause issues in some attention backends.
-        input_buffers.seq_lens[:batch_size] = 1
-        input_buffers.seq_lens[batch_size:] = 0
-
-        input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
-        slot_mappings = block_tables.slot_mappings[:, :batch_size]
-
-        attn_metadata = build_attn_metadata(
-            attn_metadata_builders=attn_metadata_builders,
-            num_reqs=batch_size,
-            num_tokens=batch_size,
-            query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
-            query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
-            seq_lens=input_buffers.seq_lens,
-            seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
-            num_computed_tokens_cpu=None,  # FIXME
-            block_tables=input_block_tables,
-            slot_mappings=slot_mappings,
-            kv_cache_config=kv_cache_config,
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_ids = input_buffers.input_ids.gpu[:num_tokens]
+        positions = input_buffers.positions[:num_tokens]
+        attn_metadata = prepare_inputs_to_capture(
+            num_reqs,
+            num_tokens,
+            input_buffers,
+            block_tables,
+            attn_metadata_builders,
+            self.max_model_len,
+            kv_cache_config,
         )
-        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
 
         # Warm up.
         with set_forward_context(
             attn_metadata,
             self.vllm_config,
-            num_tokens=batch_size,
+            num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
         ):
@@ -147,13 +107,13 @@ def capture_graph(
         self.hidden_states = torch.empty_like(hidden_states)
 
         # Capture the graph.
+        assert num_tokens not in self.graphs
         graph = torch.cuda.CUDAGraph()
         with (
-            patch("torch.cuda.empty_cache", lambda: None),
             set_forward_context(
                 attn_metadata,
                 self.vllm_config,
-                num_tokens=batch_size,
+                num_tokens=num_tokens,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ),
@@ -163,8 +123,8 @@ def capture_graph(
                 input_ids=input_ids,
                 positions=positions,
             )
-        self.hidden_states[:batch_size] = hidden_states
-        self.graphs[batch_size] = graph
+        self.hidden_states[:num_tokens] = hidden_states
+        self.graphs[num_tokens] = graph
 
     @torch.inference_mode()
     def capture(
@@ -175,25 +135,124 @@
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert self.needs_capture()
-        # Capture larger graphs first.
-        sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
-        if is_global_first_rank():
-            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
-
-        with graph_capture(device=self.device):
-            for batch_size in sizes_to_capture:
-                self.capture_graph(
-                    batch_size,
-                    model,
-                    input_buffers,
-                    block_tables,
-                    attn_metadata_builders,
-                    kv_cache_config,
-                )
-
-    def run(self, batch_size: int) -> torch.Tensor:
-        assert batch_size in self.graphs
-        self.graphs[batch_size].replay()
+        capture_graphs(
+            self.cudagraph_sizes,
+            self.device,
+            self.capture_graph,
+            model=model,
+            input_buffers=input_buffers,
+            block_tables=block_tables,
+            attn_metadata_builders=attn_metadata_builders,
+            kv_cache_config=kv_cache_config,
+        )
+
+    def run(self, num_tokens: int) -> torch.Tensor:
+        assert num_tokens in self.graphs
+        self.graphs[num_tokens].replay()
         assert self.hidden_states is not None
-        return self.hidden_states[:batch_size]
+        return self.hidden_states[:num_tokens]
+
+
+def get_cudagraph_sizes(
+    capture_sizes: list[int] | None,
+    max_num_reqs: int,
+    max_num_tokens: int,
+    cudagraph_mode: CUDAGraphMode,
+) -> dict[int, int]:
+    if not cudagraph_mode.has_full_cudagraphs():
+        return {}
+    if not capture_sizes:
+        return {}
+
+    capture_sizes = sorted(capture_sizes)
+    # Limit the capture sizes to the max number of requests or tokens.
+    upper_bound = (
+        max_num_reqs
+        if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+        else max_num_tokens
+    )
+    capture_sizes = [x for x in capture_sizes if x <= upper_bound]
+    if not capture_sizes:
+        return {}
+
+    cudagraph_sizes: dict[int, int] = {}
+    for i in range(1, capture_sizes[-1] + 1):
+        for x in capture_sizes:
+            if i <= x:
+                cudagraph_sizes[i] = x
+                break
+    return cudagraph_sizes
+
+
+def get_cudagraph_size(
+    num_tokens_after_dp_padding: int,
+    num_tokens_per_request: Iterable[int],
+    cudagraph_sizes: dict[int, int],
+    cudagraph_mode: CUDAGraphMode,
+) -> int | None:
+    size = cudagraph_sizes.get(num_tokens_after_dp_padding)
+    if size is None:
+        # No CUDA graph for this size.
+        return None
+    if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        all_decode = all(x == 1 for x in num_tokens_per_request)
+        if not all_decode:
+            # Prefill is included.
+            return None
+    return size
+
+
+def capture_graphs(
+    cudagraph_sizes: dict[int, int],
+    device: torch.device,
+    capture_fn: Callable,
+    **capture_kwargs,
+) -> None:
+    # Capture larger graphs first.
+    sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
+    if is_global_first_rank():
+        sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+
+    with graph_capture(device=device):
+        for size in sizes_to_capture:
+            capture_fn(size, **capture_kwargs)
+
+
+def prepare_inputs_to_capture(
+    num_reqs: int,
+    num_tokens: int,
+    input_buffers: InputBuffers,
+    block_tables: BlockTables,
+    attn_metadata_builders: list[AttentionMetadataBuilder],
+    max_model_len: int,
+    kv_cache_config: KVCacheConfig,
+) -> dict[str, Any]:
+    num_tokens_per_req = num_tokens // num_reqs
+    query_start_loc = input_buffers.query_start_loc
+    query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
+    query_start_loc.np[num_reqs:] = num_tokens
+    query_start_loc.copy_to_gpu()
+    seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
+    # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+    # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+    # seq_lens_np (CPU), which might cause issues in some attention backends.
+    input_buffers.seq_lens[:num_reqs] = 1
+    input_buffers.seq_lens[num_reqs:] = 0
+
+    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
+    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+
+    attn_metadata = build_attn_metadata(
+        attn_metadata_builders=attn_metadata_builders,
+        num_reqs=num_reqs,
+        num_tokens=num_tokens,
+        query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
+        query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
+        seq_lens=input_buffers.seq_lens,
+        seq_lens_np=seq_lens_np,
+        num_computed_tokens_cpu=None,  # FIXME
+        block_tables=input_block_tables,
+        slot_mappings=slot_mappings,
+        kv_cache_config=kv_cache_config,
+    )
+    return attn_metadata
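
Note (reviewer sketch, not part of the patch): the new get_cudagraph_sizes() builds a dense map from every token count up to the largest capture size to the smallest captured size that fits, and get_cudagraph_size() then reduces to a dict lookup plus an all-decode check under FULL_DECODE_ONLY. The snippet below condenses that logic to show the padding behavior; the CUDAGraphMode enum here is a simplified stand-in for vLLM's real one, kept only to make the sketch self-contained.

from enum import Enum

class CUDAGraphMode(Enum):
    # Simplified stand-in for vLLM's CUDAGraphMode; only what this sketch needs.
    NONE = 0
    FULL = 1
    FULL_DECODE_ONLY = 2

    def has_full_cudagraphs(self) -> bool:
        return self != CUDAGraphMode.NONE

def get_cudagraph_sizes(capture_sizes, max_num_reqs, max_num_tokens, mode):
    # Condensed version of the patch's get_cudagraph_sizes(): map each token
    # count in [1, largest capture size] to the smallest capture size >= it.
    if not mode.has_full_cudagraphs() or not capture_sizes:
        return {}
    upper_bound = (
        max_num_reqs if mode == CUDAGraphMode.FULL_DECODE_ONLY else max_num_tokens
    )
    capture_sizes = sorted(x for x in capture_sizes if x <= upper_bound)
    if not capture_sizes:
        return {}
    return {
        i: next(x for x in capture_sizes if i <= x)
        for i in range(1, capture_sizes[-1] + 1)
    }

sizes = get_cudagraph_sizes(
    [1, 2, 4, 8], max_num_reqs=8, max_num_tokens=16, mode=CUDAGraphMode.FULL
)
assert sizes[3] == 4   # 3 tokens replay the graph captured for 4
assert sizes[8] == 8   # exact hit
assert 9 not in sizes  # above the largest capture size -> eager fallback

With capture sizes [1, 2, 4, 8], any padded batch of 5-8 tokens replays the size-8 graph, while anything larger returns None and runs eagerly.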
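
A second sketch, also illustrative only: the query_start_loc layout that prepare_inputs_to_capture() sets up when a captured batch spreads num_tokens over fewer requests (num_reqs = min(num_tokens, max_num_reqs) in capture_graph). Plain numpy arrays stand in for the InputBuffers CPU buffers, and buf_len is a hypothetical buffer capacity.

import numpy as np

num_tokens, num_reqs = 8, 3
buf_len = 16  # hypothetical capacity of the query_start_loc buffer

query_start_loc = np.zeros(buf_len, dtype=np.int32)
num_tokens_per_req = num_tokens // num_reqs  # 2
query_start_loc[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
# Overwriting from index num_reqs pads the tail and folds the integer-division
# remainder into the last request, so the boundaries end exactly at num_tokens.
query_start_loc[num_reqs:] = num_tokens

assert query_start_loc[: num_reqs + 1].tolist() == [0, 2, 4, 8]
# Requests receive 2, 2, and 4 tokens; all padded tail entries read 8.

The overwrite at index num_reqs is what keeps the dummy batch well-formed for any num_tokens that is not an exact multiple of num_reqs.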