
Conversation

Collaborator

@benchislett commented Sep 18, 2025

Purpose

This PR makes the common changes needed to enable FlashInfer, FlashInferMLA, FlashMLA, and other new backends for speculative decoding. It does not add explicit support for any one of these backends.

Included in this PR are:

  • A change to reorder_batch_threshold, making it no longer a ClassVar. It can now be specialized at initialization time depending on whether speculative decoding is enabled: when it is, we can set it to num_speculative_tokens + 1 so that all spec-verify requests can be classified as decodes. A helper function is also included to facilitate this.
  • "uniform" mode for batch splitting, which conservatively splits the batch assuming that all decodes must have the same query length. All others are classified as prefills for safety. Tests are included to validate correctness of this method. It is disabled by default.
  • Helper functions to reshape the attention tensors to add/remove a query_length axis. This is useful for many MLA backends which require the input query to have an explicit dimension for the qlen.

Test Plan

See tests/v1/attention/test_attention_splitting.py

Test Result

All tests passing locally


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces several utilities and refactors to support speculative decoding with query lengths greater than one. The changes include making reorder_batch_threshold an instance variable for dynamic configuration, adding a 'uniform' mode for batch splitting, and providing helper functions for tensor reshaping. The new logic for batch splitting is well-tested and appears correct. The refactoring improves code clarity and prepares the codebase for new speculative decoding backends. I have one suggestion to improve an assertion in a new helper function for clarity and correctness.

@@ -766,6 +798,40 @@ def reorder_batch_to_split_decodes_and_prefills(
return modified_batch


def reshape_query_for_spec_decode(query: torch.Tensor,
Collaborator


Are these not used yet? Should we just include them in the follow-up once they are actually used? Or maybe we should add FlashMLA support in this PR, just so everything is used (and tested, since we can do a FlashMLA + MTP lm_eval run)?

Collaborator Author


I think it is worth committing now and using in subsequent PRs, mostly because it will be used by FlashMLA and also FlashInfer-MLA, and maybe more. Merging it here as a helper means that all the downstream PRs can reuse the same code from main instead of duplicating it in each.

But I don't feel particularly strongly about this, and can remove if you think it's better to add separately.

benchislett and others added 3 commits September 18, 2025 18:06
Co-authored-by: lhsjohn <huashuoli@tencent.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>

mergify bot commented Sep 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 23, 2025
@benchislett
Collaborator Author

Closing, see #25196
