[Spec Decode] Enable FlashInfer Spec Decoding #25196
Conversation
Co-authored-by: lhsjohn <huashuoli@tencent.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
LGTM, might be nice to get some extra eyes on the FlashInfer bits @pavanimajety @mgoin
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
Probably naive q: Can there be cases in normal decode where num_decodes < num_decode_tokens?
Usually, reorder_batch_size == 1, so num_decodes == num_decode_tokens. However, we're using a padded-batch speculative decoding implementation where we can use the trtllm-gen batch_decode kernel for a batch of requests as long as they all have the same q_len, which can be larger than 1. So we need to fix a bunch of cases like this one, where we can have max_q_len * num_decodes tokens in the decode pathway.
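To make that shape bookkeeping concrete, here is a minimal sketch of the split being discussed (illustrative only, not vLLM's actual metadata-builder code; the helper name is made up, while seq_lens and num_decodes mirror the snippet above):

```python
import torch

def split_decode_prefill(seq_lens: torch.Tensor, num_decodes: int, max_q_len: int):
    """Sketch: split a reordered batch into its decode and prefill parts.

    With padded-batch speculative decoding, every decode request contributes
    max_q_len query tokens, so the decode pathway carries
    num_decodes * max_q_len tokens, which exceeds num_decodes whenever
    max_q_len > 1.
    """
    num_decode_tokens = num_decodes * max_q_len
    seq_lens_decode = seq_lens[:num_decodes]
    seq_lens_prefill = seq_lens[num_decodes:]
    return num_decode_tokens, seq_lens_decode, seq_lens_prefill
```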
LGTM, could you share some accuracy results both with and without spec decoding?
@pavanimajety added correctness testing to the description. All successful.
Also fixed the kernel crashes I was seeing by adding …
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
LGTM; left one nit
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
…nfer-trtllm-spec-kernels
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: lhsjohn <huashuoli@tencent.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
Purpose
This PR enables FlashInfer for speculative decoding. When possible, the trtllm-gen decode-optimized kernel is used for speculative decoding. The fallback case is the prefill kernel, which can handle arbitrary query lengths but is not as performant.

This PR depends on #25183 for the refactor of the batch reordering threshold variable.
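Conceptually, the kernel choice comes down to whether the reordered decode batch has a uniform query length. The sketch below only illustrates that rule under assumed names (the predicate and its arguments are hypothetical; the real dispatch lives in vLLM's FlashInfer attention backend):

```python
def can_use_trtllm_batch_decode(query_lens: list[int], decode_threshold: int) -> bool:
    """Illustrative check: the decode-optimized trtllm-gen kernel can serve the
    batch when every request has the same query length (e.g. 1 + num_speculative_tokens
    with padded-batch spec decoding) and that length is within the decode-reorder
    threshold; otherwise the slower but more general prefill kernel is the fallback."""
    if not query_lens:
        return False
    q_len = query_lens[0]
    return all(q == q_len for q in query_lens) and q_len <= decode_threshold
```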
Here's an example launch command for EAGLE3 on 1xB200:
vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}' --max-model-len 2048 --no-enable-prefix-caching
Benchmarking with 200 requests from ShareGPT gives the following TPS numbers:
Examination of nsys traces shows that the main model's attention kernels are about 2.5x faster when using the decode-optimized trtllm-gen kernels.

Correctness Testing
Tested on GSM8k (limit 500) with/out spec and with/out FlashInfer. All successful.
FlashInfer + Padded-Batch (Default)
FlashInfer + Non-uniform Batch (opt-in / fallback option)
Baseline FlashInfer
Baseline FlashAttn
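For reference, one way to run this kind of GSM8K check against a served model is via lm-eval-harness. This is only a hedged sketch; the exact command used for the results above is not shown in the PR, and argument names may differ across lm-eval versions:

```python
# Assumes a vLLM OpenAI-compatible server is already running (e.g. via the
# `vllm serve` command above) and that lm-eval (lm-evaluation-harness) is installed.
import lm_eval

results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args=(
        "model=meta-llama/Llama-3.1-8B-Instruct,"
        "base_url=http://localhost:8000/v1/completions,"
        "num_concurrent=32,tokenized_requests=False"
    ),
    tasks=["gsm8k"],
    limit=500,  # matches the "limit 500" mentioned above
)
print(results["results"]["gsm8k"])
```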
Resolved Issues

When CUDA graphs are enabled and the padded-drafter-batch is disabled, there is a crash at high concurrency. This PR fixes a couple of issues in FlashInfer where the planning and building of metadata assume that num_decode_tokens and num_decodes are the same. There is likely another such issue in the planning logic for CUDA graph padding. The issue can be patched by using enforce-eager or by recording a CUDA graph for each input batch size.

Update after some investigation: there are illegal memory accesses in the TRTLLM-gen kernels that produce difficult-to-reproduce crashes. This seems unrelated to the logic in this PR, which I have marked as ready-for-review. It is possible that some state corruption or race condition is independently causing issues with these kernels on Blackwell. I will continue investigating.

Update: fixed. The issue was max_q_len > max(query_lens) causing illegal accesses for non-uniform batches such as [2, 1], where the prefill max query size was smaller than the total max query size. Fixed by manually calculating max_q_len_prefill after splitting the batch.
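A minimal sketch of the idea behind that fix (assumed helper name and inputs, not the PR's actual code): recompute the prefill split's max query length from the prefill portion only, rather than reusing the global max_q_len.

```python
import torch

def compute_max_q_len_prefill(query_lens: torch.Tensor, num_decodes: int) -> int:
    """Max query length over the prefill portion of a reordered [decodes..., prefills...] batch.

    When the prefill portion's true max query length is smaller than the global
    max_q_len (as in the [2, 1] case described above), reusing the global value
    for the prefill plan can over-read; recompute it from the prefill split instead.
    """
    prefill_query_lens = query_lens[num_decodes:]
    if prefill_query_lens.numel() == 0:
        return 0
    return int(prefill_query_lens.max().item())
```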