[Spec Decode] Utilities and refactor to support qlen>1 decode kernels for spec decode #25183
Conversation
Code Review
This pull request introduces several utilities and refactors to support speculative decoding with query lengths greater than one. The changes include making `reorder_batch_threshold` an instance variable for dynamic configuration, adding a 'uniform' mode for batch splitting, and providing helper functions for tensor reshaping. The new logic for batch splitting is well-tested and appears correct. The refactoring improves code clarity and prepares the codebase for new speculative decoding backends. I have one suggestion to improve an assertion in a new helper function for clarity and correctness.
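To make the 'uniform' batch-splitting mode concrete, here is a minimal sketch. It assumes 'uniform' means the leading decode portion of a reordered batch is only kept when every decode request has the same query length (so its tokens can later be viewed as a dense [batch, qlen, ...] tensor); that interpretation, and all names below, are illustrative rather than taken from the PR.

```python
import torch


def split_decodes_uniform(query_lens: torch.Tensor, threshold: int) -> int:
    """Return how many leading requests to treat as decodes.

    Illustrative sketch only, not the PR's implementation. Assumes requests
    are already reordered so that short (decode) requests come first. A
    request counts as a decode if its query length is <= threshold (e.g.
    num_speculative_tokens + 1 for spec-verify). In this "uniform" reading,
    the decode split is only kept when all decode requests share one query
    length; otherwise everything falls back to the prefill path.
    """
    num_decodes = 0
    for qlen in query_lens.tolist():
        if qlen <= threshold:
            num_decodes += 1
        else:
            break
    if num_decodes == 0:
        return 0
    first = int(query_lens[0].item())
    if not bool((query_lens[:num_decodes] == first).all()):
        return 0  # non-uniform query lengths: route them all through prefill
    return num_decodes
```

For example, with num_speculative_tokens = 3 the threshold would be 4, so query lengths [4, 4, 4, 17] would yield three decode requests and one prefill.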
@@ -766,6 +798,40 @@ def reorder_batch_to_split_decodes_and_prefills(
    return modified_batch


def reshape_query_for_spec_decode(query: torch.Tensor,
Are these not used yet? Should we just include them in the follow-up once they are actually used? Or maybe we should add FlashMLA support in this PR, just so everything is used (and tested, since we can do a FlashMLA + MTP lm_eval run).
I think it is worth committing now and using it in subsequent PRs, mostly because it will be used by FlashMLA and also FlashInfer-MLA, and maybe more. Merging it here as a helper means that all the downstream PRs can reuse the same code from main instead of duplicating it in each.
But I don't feel particularly strongly about this, and can remove it if you think it's better to add separately.
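For reference, here is a hypothetical sketch of what a query-reshaping helper like the one in the hunk above might look like; only the first line of the real signature is visible in the diff, so the second parameter and the body are assumptions.

```python
import torch


def reshape_query_for_spec_decode(query: torch.Tensor,
                                  batch_size: int) -> torch.Tensor:
    """View a flattened [num_tokens, num_heads, head_dim] query as
    [batch_size, qlen, num_heads, head_dim] for qlen>1 decode kernels.

    Hypothetical sketch: assumes every decode request contributed the same
    (uniform) number of query tokens, so num_tokens divides evenly.
    """
    num_tokens, num_heads, head_dim = query.shape
    assert num_tokens % batch_size == 0, (
        f"num_tokens ({num_tokens}) must be divisible by batch_size "
        f"({batch_size}) for a uniform decode batch")
    qlen = num_tokens // batch_size
    # view() keeps the same storage; the query is assumed contiguous here.
    return query.view(batch_size, qlen, num_heads, head_dim)
```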
Co-authored-by: lhsjohn <huashuoli@tencent.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Force-pushed from 6eae35e to cfa3273
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Closing, see #25196
Purpose
This PR makes the common changes needed to enable FlashInfer, FlashInferMLA, FlashMLA, and other new backends for speculative decoding; it does not add explicit support for any one of these.
Included in this PR is:
- A refactor of `reorder_batch_threshold`, making it no longer a `ClassVar`. This is because it can now be specialized at initialization time depending on whether or not speculative decoding is enabled: when it is, we can set it to `num_speculative_tokens + 1` so that all spec-verify steps can be classified as decodes. A helper function is also included to facilitate this.
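A rough illustration of that initialization-time specialization follows; the class name and surrounding structure are assumptions made for the example, not the PR's actual code.

```python
# Hypothetical sketch: reorder_batch_threshold becomes an instance attribute
# chosen at construction time instead of a fixed class-level ClassVar.
class AttentionMetadataBuilderSketch:
    def __init__(self, num_speculative_tokens: int | None = None):
        if num_speculative_tokens is not None:
            # A spec-verify step processes up to num_speculative_tokens draft
            # tokens plus one bonus token, so any request with query length
            # <= this threshold can still be routed down the decode path.
            self.reorder_batch_threshold = num_speculative_tokens + 1
        else:
            # Plain decoding: one query token per request.
            self.reorder_batch_threshold = 1
```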
Test Plan
See `tests/v1/attention/test_attention_splitting.py`
Test Result
All passing locally