[Performance]: Fully Async Spec-Decoding | Make seq_lens_cpu in CommonAttentionMetadata optional #29134

@LucasWilkinson

Proposal to improve performance

Currently, fully overlapping input prep with the model forward pass is blocked in the spec-decode case by the following Host<>GPU syncs:

  1. _get_valid_sampled_token_count (ultimately needed to compute seq_lens_cpu; see the sketch after this list):
    self.valid_sampled_token_count_event.synchronize()
  2. In the num_speculated_tokens > 1 case, the need to update seq_lens_cpu for attention-metadata building (specifically for all speculated tokens after the first one, hence this only impacts num_speculated_tokens > 1; context: [Bugfix] Invalidate positions when using padded speculative decoding #26498)
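
To make the first sync concrete, here is a minimal sketch of the blocking pattern (illustrative names and values; requires a CUDA device): the host stalls on the event instead of prepping the next step's inputs.

import torch

prev_seq_lens_cpu = torch.full((8,), 128, dtype=torch.int32)

# Produced by sampling/verification kernels on the GPU:
valid_sampled_token_count = torch.randint(1, 4, (8,), dtype=torch.int32, device="cuda")
valid_sampled_token_count_event = torch.cuda.Event()
valid_sampled_token_count_event.record()

# Host<>GPU sync: the CPU blocks here instead of overlapping with the forward pass...
valid_sampled_token_count_event.synchronize()
# ...and only now can it do the D2H copy and compute seq_lens_cpu.
seq_lens_cpu = prev_seq_lens_cpu + valid_sampled_token_count.cpu()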

Ultimately, in order to realize fully async spec decoding we need to build attention metadata without knowing seq_lens_cpu (using the device-resident seq_lens tensor is fine, since any metadata-building GPU kernels will be queued after it has been updated by the CUDA driver).
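
For intuition, a device-only sketch (illustrative; block size 16 is arbitrary): because kernels execute in stream order, a metadata-building kernel queued after the seq_lens update always sees the updated values, with no host involvement.

import torch

seq_lens = torch.tensor([128, 256], device="cuda")    # device tensor
valid_counts = torch.tensor([2, 1], device="cuda")    # from sampling kernels

seq_lens += valid_counts             # update kernel, queued on the stream
num_blocks = (seq_lens + 15) // 16   # metadata-building kernel, queued after the update
# No .synchronize() or .cpu() anywhere: the host never waits on the GPU.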

This is currently not possible for all backends (namely FlashInfer, due to a D2H or H2D copy, depending on whether a host or device tensor is provided, inside the plan function).

However, there are many (important/default) backends, such as TRTLLM (inside the FlashInfer backend), FlashAttn, FlashAttn-MLA (for pure decode), and FlashMLA (for pure decode), that could achieve this full overlap using only the device seq_lens tensor.

The proposal would be to update common attention metadata from:

@dataclass
class CommonAttentionMetadata:
    ...
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor

To something like:

@dataclass
class CommonAttentionMetadata:
    ...
    seq_lens: torch.Tensor
    _seq_lens_cpu: torch.Tensor | None

    # WARNING: using this property will mean spec-decode with async-scheduling will not achieve
    # full overlap due to Host<>GPU sync
    @property
    def seq_lens_cpu(self) -> torch.Tensor:
        # Potentially log a warning here to encourage developers to avoid this property
        if self._seq_lens_cpu is not None:
            return self._seq_lens_cpu
        return self.seq_lens.to("cpu")  # blocking D2H copy

Here _seq_lens_cpu would be provided when it is available/known (e.g. async scheduling without spec decode, or no async scheduling at all).
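
For concreteness, a hypothetical sketch of how the model runner might populate the field (the names know_seq_lens_cpu, async_scheduling, spec_decode_enabled, and seq_lens_gpu are illustrative, not existing attributes):

# Hypothetical wiring in the model runner (illustrative flags and names):
know_seq_lens_cpu = not (async_scheduling and spec_decode_enabled)

common_attn_metadata = CommonAttentionMetadata(
    # ... other fields elided ...
    seq_lens=seq_lens_gpu,
    # Only pass the host copy when it is known without a sync; otherwise leave
    # it None so device-only backends (e.g. FlashAttn) stay fully async.
    _seq_lens_cpu=seq_lens_cpu if know_seq_lens_cpu else None,
)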

This would allow backends that do not need seq_lens_cpu to benefit more from async scheduling, while maintaining support for backends that do need it with little to no perf regression (we could potentially add warnings to encourage developers to migrate away from using this property).

NOTE: other changes to GPUModelRunner._update_states would be needed to fully realize this (or model runner v2 #25266); but the num_speculated_tokens > 1 case could benefit immediately by overlapping the drafter's metadata prep with the target model's forward pass.

NOTE: this would also mean max_seq_len would no longer represent the true max but rather an upper bound that could be off by num_speculated_tokens; this shouldn't be a problem given most backends (e.g. FA) use it only for heuristics. We may want to consider renaming it to seq_len_upper_bound in order to make it clear to backend developers that it may not be exact.
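
As a sketch of why such an upper bound is cheap to maintain host-side (illustrative numbers): the scheduler already knows how many tokens it scheduled per request, so the bound needs no knowledge of how many draft tokens were actually accepted.

# Illustrative: per-request (num_computed_tokens, num_scheduled_tokens) pairs
# are known host-side from the scheduler; num_speculated_tokens = 3 here.
schedule = [(128, 4), (256, 4), (64, 4)]

# Sync-free upper bound: assume every scheduled (speculated) token is accepted.
seq_len_upper_bound = max(c + s for c, s in schedule)  # 260
# The true max depends on GPU-side acceptance and can be lower by up to
# num_speculated_tokens (e.g. 257 if the longest request accepts only 1 token).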

NOTE: num_computed_tokens_cpu would also need to be avoided, but a similar approach could be used, considering it is derivable from query_start_loc and seq_lens:

@dataclass
class CommonAttentionMetadata:
    ...

    # WARNING: using this property will mean spec-decode with async-scheduling will not achieve
    # full overlap due to Host<>GPU sync
    @property
    def num_computed_tokens_cpu(self) -> torch.Tensor:
        # Potentially log a warning here to encourage developers to avoid this property
        query_lens = self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
        return self.seq_lens_cpu - query_lens
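
A quick numeric check of that identity (illustrative values):

import torch

query_start_loc_cpu = torch.tensor([0, 3, 5, 9])                 # 3 requests
seq_lens_cpu = torch.tensor([10, 7, 12])

query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # [3, 2, 4]
num_computed_tokens_cpu = seq_lens_cpu - query_lens              # [7, 5, 8]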

cc @benchislett @WoosukKwon @MatthewBonanni
