Proposal to improve performance
Currently, fully overlapping input-prep with the model forward pass is blocked in the spec-decode case by the following Host<>GPU syncs:

1. `_get_valid_sampled_token_count` (ultimately needed to compute `seq_lens_cpu`): `self.valid_sampled_token_count_event.synchronize()` in vllm/vllm/v1/worker/gpu_model_runner.py (line 3109 in fe25772).
2. In the `num_speculated_tokens > 1` case, needing to update `seq_lens_cpu` for attention metadata building (specifically for all the speculated tokens after the first one, hence only impacting `num_speculated_tokens > 1`; context: [Bugfix] Invalidate positions when using padded speculative decoding #26498).
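To make the distinction concrete, here is a minimal, hypothetical sketch (the function and variable names are made up for illustration, not vLLM identifiers) of why a host-side `seq_lens_cpu` update forces a sync while a device-side update simply queues behind the sampler:

```python
import torch

def update_seq_lens_blocking(seq_lens_cpu: torch.Tensor,
                             sampled_counts_cpu: torch.Tensor,
                             copy_done: torch.cuda.Event) -> torch.Tensor:
    # sampled_counts_cpu is filled by an async D2H copy; the host must wait for that
    # copy to finish before it can safely read it, stalling CPU-side input-prep.
    copy_done.synchronize()
    return seq_lens_cpu + sampled_counts_cpu

def update_seq_lens_on_device(seq_lens: torch.Tensor,
                              sampled_counts: torch.Tensor) -> torch.Tensor:
    # Both tensors live on the GPU: this add is just queued on the CUDA stream after
    # the kernel that produced sampled_counts, so the host never blocks.
    return seq_lens + sampled_counts
```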
Ultimately, in order to realize fully async spec decoding, we need to build attention metadata without knowing `seq_lens_cpu` (using the device `seq_lens` tensor is fine, since any metadata-building GPU kernels will be queued after it has been updated by the CUDA driver).
This is currently not possible for all backends (namely FlashInfer, due to a D2H or H2D copy, depending on whether a host or device tensor is provided, inside the plan function).
However, there are many important/default backends, such as TRTLLM (inside the FlashInfer backend), FlashAttn, FlashAttn-MLA (for pure decode), and FlashMLA (for pure decode), that could achieve this full overlap using only the device `seq_lens` tensor.
The proposal would be to update common attention metadata from:
@dataclass
class CommonAttentionMetadata:
    ...
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
To something like:
@dataclass
class CommonAttentionMetadata:
    ...
    seq_lens: torch.Tensor
    _seq_lens_cpu: torch.Tensor | None

    # WARNING: using this property means spec-decode with async-scheduling will not
    # achieve full overlap due to a Host<>GPU sync
    @property
    def seq_lens_cpu(self) -> torch.Tensor:
        # Potentially log a warning here to encourage developers to avoid this property
        if self._seq_lens_cpu is not None:
            return self._seq_lens_cpu
        return self.seq_lens.to("cpu")
Where `_seq_lens_cpu` would be provided if available/known (e.g. async-scheduling without spec-decode, or no async-scheduling at all).
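As a rough construction-side sketch (the helper below and the `have_host_seq_lens` flag are hypothetical, standing in for "async-scheduling without spec-decode, or no async-scheduling"):

```python
def build_common_attn_metadata(seq_lens_gpu: torch.Tensor,
                               seq_lens_cpu: torch.Tensor | None,
                               have_host_seq_lens: bool) -> CommonAttentionMetadata:
    return CommonAttentionMetadata(
        seq_lens=seq_lens_gpu,
        # Only hand over the CPU copy when it is already known on the host; backends
        # that still read .seq_lens_cpu then hit the synchronizing .to("cpu") fallback.
        _seq_lens_cpu=seq_lens_cpu if have_host_seq_lens else None,
        # (remaining CommonAttentionMetadata fields omitted from this sketch)
    )
```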
This will allow backends that do not need `seq_lens_cpu` to benefit more from async-scheduling, while maintaining support for backends that need `seq_lens_cpu` with little to no perf regression (we could potentially add warnings to encourage developers to migrate away from using this).
NOTE: other changes to `GPUModelRunner._update_states` would be needed to fully realize this (or model runner v2 #25266); but the `num_speculated_tokens > 1` case could benefit immediately by overlapping the drafter's metadata prep with the target model's forward pass.
NOTE: this would also mean `max_seq_len` would no longer represent the true max but an upper bound that could be off by `num_speculated_tokens`; this shouldn't be a problem given most backends (e.g. FA) use it simply for heuristics. We may want to consider renaming it to `seq_len_upper_bound` in order to make it clear to backend developers that it may not be exact.
NOTE: `num_computed_tokens_cpu` would also need to be avoided, but a similar approach could be used considering it's a derivative of `query_start_loc` and `seq_lens`:
@dataclass
class CommonAttentionMetadata:
    ...
    # WARNING: using this property means spec-decode with async-scheduling will not
    # achieve full overlap due to a Host<>GPU sync
    @property
    def num_computed_tokens_cpu(self) -> torch.Tensor:
        # Potentially log a warning here to encourage developers to avoid this property
        # num_computed_tokens = seq_lens - query_lens
        return self.seq_lens_cpu - (
            self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1])
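As a quick sanity check of that identity (made-up numbers, two requests with query lengths 3 and 2 and total sequence lengths 10 and 7):

```python
import torch

query_start_loc_cpu = torch.tensor([0, 3, 5])
seq_lens_cpu = torch.tensor([10, 7])

query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # [3, 2]
num_computed_tokens_cpu = seq_lens_cpu - query_lens               # [7, 5]
```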