
Conversation


@benchislett benchislett commented Sep 18, 2025

Purpose

This PR enables FlashInfer for speculative decoding. When possible, the decode-optimized trtllm-gen kernel is used; otherwise, the fallback is the prefill kernel, which can handle arbitrary query lengths but is less performant.
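
As a rough sketch of the dispatch rule (illustrative only; the function and variable names below are not the PR's actual identifiers):

```python
# Hypothetical sketch of the kernel selection described above, not the PR's code.
def can_use_trtllm_decode(query_lens: list[int], uniform_decode_q_len: int) -> bool:
    # The decode-optimized trtllm-gen kernel is used only when every request in
    # the (reordered) decode batch has the same query length, e.g.
    # 1 + num_speculative_tokens with padded speculative decoding.
    return bool(query_lens) and len(set(query_lens)) == 1 \
        and query_lens[0] <= uniform_decode_q_len

# Anything else (mixed query lengths, long prefills) falls back to the prefill
# kernel, which handles arbitrary query lengths but is slower for decode.
```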

This PR depends on #25183 for the refactor of the batch reordering threshold variable.

Here's an example launch command for EAGLE3 on 1xB200:

vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}' --max-model-len 2048 --no-enable-prefix-caching

Benchmarking with 200 requests from ShareGPT gives the following TPS numbers:

  • Padding enabled + FlashInfer: 530 TPS (1.14x)
  • Padding disabled + FlashInfer: 465 TPS (1.0x)
  • Padding enabled + FlashAttention: 492 TPS (1.06x)
  • Padding disabled + FlashAttention: 466 TPS (1.0x)

Examination of nsys traces shows that the main model's attention kernels are about 2.5x faster when using the decode-optimized trtllm-gen kernels.

Correctness Testing

Tested on GSM8K (limit 500) with and without speculative decoding, and with and without FlashInfer. All runs were successful.
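
(The tables below are lm-evaluation-harness output. The exact eval command is not included in the PR, but an invocation along the following lines should reproduce the setup; the base_url/port and model_args here are assumptions:)

lm_eval --model local-completions --model_args model=meta-llama/Llama-3.1-8B-Instruct,base_url=http://localhost:8049/v1/completions --tasks gsm8k --num_fewshot 5 --limit 500 --batch_size 128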

FlashInfer + Padded-Batch (Default)

CUDA_VISIBLE_DEVICES=3 vllm serve meta-llama/Llama-3.1-8B-Instruct --max-model-len 2048 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}' 
limit: 500.0, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.796|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.774|±  |0.0187|

FlashInfer + Non-uniform Batch (opt-in / fallback option)

CUDA_VISIBLE_DEVICES=3 vllm serve meta-llama/Llama-3.1-8B-Instruct --max-model-len 2048 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4, "disable_padded_drafter_batch": true}' 
limit: 500.0, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.792|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.774|±  |0.0187|

Baseline FlashInfer

CUDA_VISIBLE_DEVICES=3 vllm serve meta-llama/Llama-3.1-8B-Instruct --max-model-len 2048 --no-enable-prefix-caching --port 8049
limit: 500.0, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.796|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.778|±  |0.0186|

Baseline FlashAttn

CUDA_VISIBLE_DEVICES=3 VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct --max-model-len 2048 --no-enable-prefix-caching --port 8049
limit: 500.0, num_fewshot: 5, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.794|±  |0.0181|
|     |       |strict-match    |     5|exact_match|↑  |0.772|±  |0.0188|

Resolved Issues

When CUDA graphs are enabled and the padded drafter batch is disabled, there is a crash at high concurrency. This PR fixes a couple of issues in FlashInfer where the planning and building of metadata assumed that num_decode_tokens and num_decodes are the same. There is likely another such issue in the planning logic for CUDA graph padding; that remaining issue can be worked around by using enforce-eager or by recording a CUDA graph for each input batch size.
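
As an illustration of the distinction (a minimal sketch with made-up numbers and simplified names, not the actual FlashInfer metadata-builder code):

```python
# Hypothetical sketch: per-request metadata is split at num_decodes, while
# per-token metadata is split/rebased at num_decode_tokens. The two offsets
# only coincide when each decode request has a single query token.
import torch

num_decodes, q_len = 3, 5                    # 3 decode requests, 5 query tokens each
num_decode_tokens = num_decodes * q_len      # 15, not 3

seq_lens = torch.tensor([37, 12, 98, 400, 250])           # one entry per request
query_start_loc = torch.tensor([0, 5, 10, 15, 80, 300])   # cumulative token offsets per request

seq_lens_prefill = seq_lens[num_decodes:]                              # split by request count
qo_indptr_prefill = query_start_loc[num_decodes:] - num_decode_tokens  # rebase by token count
```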

Update after some investigation: there are illegal memory accesses in the trtllm-gen kernels that produce difficult-to-reproduce crashes. They seem unrelated to the logic in this PR, which I have marked as ready-for-review. It is possible that some state corruption or race condition is independently causing issues with these kernels on Blackwell. I will continue investigating.

Update: fixed. The issue was max_q_len > max(query_lens), causing an illegal access for non-uniform batches such as [2, 1], where the prefill max query size was smaller than the total max query size. Fixed by manually calculating max_q_len_prefill after splitting the batch.
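
A minimal sketch of the fix (illustrative names and numbers, not the exact PR code):

```python
# Hypothetical sketch: the prefill plan must use the max query length of the
# prefill portion only, not the max over the whole (reordered) batch.
query_lens_decode = [2]     # decode portion, reordered to the front
query_lens_prefill = [1]    # prefill portion

max_q_len = max(query_lens_decode + query_lens_prefill)  # 2: global max
max_q_len_prefill = max(query_lens_prefill)              # 1: what prefill planning should use

# Planning the prefill kernel with max_q_len (2) instead of max_q_len_prefill (1)
# is what triggered the illegal memory access for batches like [2, 1].
```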

benchislett and others added 4 commits September 18, 2025 18:06

@LucasWilkinson (Collaborator) left a comment

LGTM, might be nice to get some extra eyes on the FlashInfer bits @pavanimajety @mgoin

Comment on lines +877 to +878
    …[num_decodes:]
    seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

Contributor

Probably naive q: Can there be cases in normal decode where num_decodes < num_decode_tokens?

Collaborator Author (@benchislett)

Usually, reorder_batch_size == 1 so num_decodes == num_decode_tokens.

However, we're using a padded-batch speculative decoding implementation where we can use the trtllm-gen batch_decode kernel for a batch of requests as long as they all have the same q_len, which can be larger than 1.

So we need to fix a bunch of cases like this one, where we can have max_q_len * num_decodes tokens in the decode pathway.
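
For example, with the EAGLE3 config from the description (a hedged illustration of the arithmetic, not actual code):

```python
num_speculative_tokens = 4
q_len = 1 + num_speculative_tokens        # uniform query length per padded decode request
num_decodes = 16                          # decode requests in the batch (arbitrary example)
num_decode_tokens = num_decodes * q_len   # 80 != 16, unlike the classic q_len == 1 case
```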

@pavanimajety (Contributor) left a comment

LGTM. Could you share some accuracy results with spec decoding off and with spec decoding on?

@benchislett

@pavanimajety added correctness testing to the description. All successful.

@benchislett

Also fixed the kernel crashes I was seeing by adding max_q_len_prefill to FlashInferMetadata. Open to discussing better ways to handle this; maybe we could use max_q_len throughout instead? This seems to be the only place it is used, but its meaning is ambiguous.

@LucasWilkinson (Collaborator) left a comment

LGTM; left one nit

@benchislett enabled auto-merge (squash) September 23, 2025 15:23
@github-actions bot added the ready label Sep 23, 2025

mergify bot commented Sep 23, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Sep 23, 2025
@mergify bot removed the needs-rebase label Sep 23, 2025
@benchislett merged commit c30b405 into vllm-project:main Sep 24, 2025
43 checks passed
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Labels
ready, speculative-decoding, v1