
Conversation

lhtin (Contributor) commented Sep 29, 2025

Use seq_lens_cpu instead of seq_lens to reduce GPU->CPU sync.

Purpose

Reduce an unnecessary GPU->CPU sync, which hurts the performance of async scheduling + MTP.
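A minimal sketch of the failure mode, assuming standard PyTorch semantics (the tensor names mirror the PR; the surrounding values are made up):

```python
import torch

# Stand-in values; in vLLM these come from the attention metadata, and the
# host-side copy (seq_lens_cpu) is already maintained, so using it costs no
# extra transfer.
seq_lens = torch.randint(1, 2048, (256,), device="cuda")
seq_lens_cpu = seq_lens.cpu()
num_draft_tokens, max_model_len = 4, 4096  # hypothetical values

# Before: seq_lens.max() is a GPU tensor, so evaluating the comparison as a
# Python bool (inside an `if` or an `and` chain) calls Tensor.__bool__ and
# blocks the host until the GPU stream drains.
if seq_lens.max() + num_draft_tokens <= max_model_len:  # GPU->CPU sync
    pass

# After: the same check on the host copy never touches the GPU stream,
# which is what async scheduling + MTP needs to stay ahead of the GPU.
if seq_lens_cpu.max() + num_draft_tokens <= max_model_len:  # no sync
    pass
```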

Test Plan

Test Result



Signed-off-by: Lehua Ding <lehuading@tencent.com>
gemini-code-assist bot left a comment

Code Review

This pull request introduces a performance optimization to reduce GPU-to-CPU synchronization during speculative decoding. The change replaces a call to .max() on a GPU tensor (seq_lens) with its CPU counterpart (seq_lens_cpu) within a conditional check. This avoids a blocking operation, which is particularly beneficial for asynchronous scheduling. The change is correct and aligns with the stated goal of improving performance. I have no further comments.

lhtin (Contributor, Author) commented Sep 29, 2025

This was introduced by PR #24662. @AlonKejzman, can you help review this too? Thanks.

benchislett (Collaborator) left a comment

LGTM, thanks!

            self.speculative_config.draft_model_config.max_model_len)
        input_fits_in_drafter = spec_decode_common_attn_metadata and (
-           spec_decode_common_attn_metadata.seq_lens.max() +
+           spec_decode_common_attn_metadata.seq_lens_cpu.max() +
benchislett (Collaborator) commented:
Is it possible to use .max_seq_len here?

lhtin (Contributor, Author) commented Sep 29, 2025
Indeed, and it is simpler! spec_decode_common_attn_metadata.max_seq_len is computed from self.seq_lens.np[:num_reqs].max().item(), so they are equivalent.
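For reference, a sketch of the check in its final shape after this suggestion (the trailing "+ ..." term is cut off in the snippet above, so num_draft_tokens is an assumed name, not a verbatim excerpt):

```python
# Sketch only: max_seq_len is already a host-side int, so the check
# involves no tensor reduction and no GPU->CPU sync at all.
input_fits_in_drafter = spec_decode_common_attn_metadata and (
    spec_decode_common_attn_metadata.max_seq_len +
    num_draft_tokens  # assumed name for the elided term
    <= self.speculative_config.draft_model_config.max_model_len)
```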

benchislett added the ready label Sep 29, 2025
Signed-off-by: Lehua Ding <lehuading@tencent.com>
DarkLight1337 merged commit e184c9c into vllm-project:main Sep 30, 2025
42 checks passed
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
Signed-off-by: Lehua Ding <lehuading@tencent.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Lehua Ding <lehuading@tencent.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>