249 changes: 154 additions & 95 deletions vllm/v1/worker/gpu/cudagraph_utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch
from collections.abc import Callable, Iterable
from typing import Any

import numpy as np
import torch
@@ -32,6 +33,7 @@ def __init__(

self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
@@ -40,102 +42,60 @@ def __init__(
self.cudagraph_mode = CUDAGraphMode.NONE
else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode
if self.compilation_config.cudagraph_capture_sizes is not None:
cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
# Limit the cudagraph sizes to the max decode batch size.
self.cudagraph_sizes = [
x for x in cudagraph_sizes if x <= self.max_num_reqs
]
else:
self.cudagraph_sizes = []
self.padded_sizes = self._init_padded_sizes()
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
)

self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None

def _init_padded_sizes(self) -> dict[int, int]:
if not self.cudagraph_mode.has_full_cudagraphs():
# Full cuda graphs are not used.
return {}
if not self.cudagraph_sizes:
return {}

padded_sizes: dict[int, int] = {}
for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes:
if i <= x:
padded_sizes[i] = x
break
return padded_sizes

def needs_capture(self) -> bool:
return len(self.padded_sizes) > 0
return len(self.cudagraph_sizes) > 0

def get_cudagraph_size(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
) -> int | None:
if not self.cudagraph_mode.has_full_cudagraphs():
return None
if self.cudagraph_mode != CUDAGraphMode.FULL:
# TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
all_decode = all(
x == 1 for x in scheduler_output.num_scheduled_tokens.values()
)
if not all_decode:
# Prefill is included.
return None
return self.padded_sizes.get(num_tokens_after_padding)
return get_cudagraph_size(
num_tokens_after_padding,
scheduler_output.num_scheduled_tokens.values(),
self.cudagraph_sizes,
self.cudagraph_mode,
)

def capture_graph(
self,
batch_size: int,
num_tokens: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert batch_size not in self.graphs

# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions[:batch_size]

input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
# HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
# for seq_lens. This leads to a mismatch between seq_lens (GPU) and
# seq_lens_np (CPU), which might cause issues in some attention backends.
input_buffers.seq_lens[:batch_size] = 1
input_buffers.seq_lens[batch_size:] = 0

input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]

attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=batch_size,
num_tokens=batch_size,
query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
seq_lens=input_buffers.seq_lens,
seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids.gpu[:num_tokens]
positions = input_buffers.positions[:num_tokens]
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,
input_buffers,
block_tables,
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
@@ -147,13 +107,13 @@ def capture_graph(
self.hidden_states = torch.empty_like(hidden_states)

# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with (
patch("torch.cuda.empty_cache", lambda: None),
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
),
@@ -163,8 +123,8 @@ def capture_graph(
input_ids=input_ids,
positions=positions,
)
self.hidden_states[:batch_size] = hidden_states
self.graphs[batch_size] = graph
self.hidden_states[:num_tokens] = hidden_states
self.graphs[num_tokens] = graph

@torch.inference_mode()
def capture(
@@ -175,25 +135,124 @@ def capture(
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert self.needs_capture()
# Capture larger graphs first.
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")

with graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
model,
input_buffers,
block_tables,
attn_metadata_builders,
kv_cache_config,
)

def run(self, batch_size: int) -> torch.Tensor:
assert batch_size in self.graphs
self.graphs[batch_size].replay()
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
model=model,
input_buffers=input_buffers,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
)

def run(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
return self.hidden_states[:batch_size]
return self.hidden_states[:num_tokens]


def get_cudagraph_sizes(
capture_sizes: list[int] | None,
max_num_reqs: int,
max_num_tokens: int,
cudagraph_mode: CUDAGraphMode,
) -> dict[int, int]:
if not cudagraph_mode.has_full_cudagraphs():
return {}
if not capture_sizes:
return {}

capture_sizes = sorted(capture_sizes)
# Limit the capture sizes to the max number of requests or tokens.
upper_bound = (
max_num_reqs
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
else max_num_tokens
)
capture_sizes = [x for x in capture_sizes if x <= upper_bound]
if not capture_sizes:
return {}

cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1):
for x in capture_sizes:
if i <= x:
cudagraph_sizes[i] = x
break
return cudagraph_sizes
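
For illustration only (not part of the diff): with hypothetical capture sizes of [1, 2, 4, 8] and full CUDA graphs enabled, the lookup built above maps every batch size up to the largest capture size onto the next captured size. A minimal sketch, assuming the helper above and CUDAGraphMode are importable from this module:

sizes = get_cudagraph_sizes(
    capture_sizes=[1, 2, 4, 8],  # hypothetical values
    max_num_reqs=8,
    max_num_tokens=8,
    cudagraph_mode=CUDAGraphMode.FULL,
)
# sizes == {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}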


def get_cudagraph_size(
num_tokens_after_dp_padding: int,
num_tokens_per_request: Iterable[int],
cudagraph_sizes: dict[int, int],
cudagraph_mode: CUDAGraphMode,
) -> int | None:
size = cudagraph_sizes.get(num_tokens_after_dp_padding)
if size is None:
# No CUDA graph for this size.
return None
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
all_decode = all(x == 1 for x in num_tokens_per_request)
if not all_decode:
# Prefill is included.
return None
return size
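
Again purely illustrative (not from the diff): at runtime the same lookup is consulted, and in FULL_DECODE_ONLY mode any batch that contains a prefill request falls back to eager execution. A minimal sketch with hypothetical values:

size = get_cudagraph_size(
    num_tokens_after_dp_padding=4,
    num_tokens_per_request=[1, 1, 1, 1],       # pure decode batch
    cudagraph_sizes={1: 2, 2: 2, 3: 4, 4: 4},  # hypothetical mapping
    cudagraph_mode=CUDAGraphMode.FULL_DECODE_ONLY,
)
# size == 4: replay the graph captured for 4 tokens.
# With num_tokens_per_request=[1, 3] (a prefill in the batch), the same call
# returns None and the caller runs the model without a CUDA graph.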


def capture_graphs(
cudagraph_sizes: dict[int, int],
device: torch.device,
capture_fn: Callable,
**capture_kwargs,
) -> None:
# Capture larger graphs first.
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")

with graph_capture(device=device):
for size in sizes_to_capture:
capture_fn(size, **capture_kwargs)
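
A quick illustration (not from the diff) of how duplicate padded sizes collapse before capture, using a hypothetical mapping:

sizes_to_capture = sorted(set({1: 2, 2: 2, 3: 4, 4: 4}.values()), reverse=True)
# sizes_to_capture == [4, 2]: each unique graph size is captured once, largest first.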


def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc = input_buffers.query_start_loc
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
query_start_loc.np[num_reqs:] = num_tokens
query_start_loc.copy_to_gpu()
seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
# HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
# for seq_lens. This leads to a mismatch between seq_lens (GPU) and
# seq_lens_np (CPU), which might cause issues in some attention backends.
input_buffers.seq_lens[:num_reqs] = 1
input_buffers.seq_lens[num_reqs:] = 0

input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]

attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
seq_lens=input_buffers.seq_lens,
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
return attn_metadata
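
As a final illustration (not part of the diff), the dummy capture batch splits num_tokens evenly across num_reqs requests. A small self-contained sketch of the query_start_loc layout that prepare_inputs_to_capture writes, assuming num_reqs=2, num_tokens=8, and a stand-in CPU buffer of length num_reqs + 2:

import numpy as np

num_reqs, num_tokens = 2, 8
num_tokens_per_req = num_tokens // num_reqs
query_start_loc = np.zeros(num_reqs + 2, dtype=np.int32)
query_start_loc[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
query_start_loc[num_reqs:] = num_tokens
print(query_start_loc)  # [0 4 8 8]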