Enable Eagle3 speculative decoding for GPT-OSS model #25246
Conversation
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Code Review
This pull request enables Eagle3 speculative decoding for the GPT-OSS model. The changes are generally well-implemented, including adding the model to the supported list, implementing the `SupportsEagle3` interface, and generalizing the embedding-layer sharing logic. However, I've identified a critical issue in `GptOssModel.forward` where a `TypeError` can occur due to improper handling of a `None` residual on the first layer of a pipeline stage. I have provided a code suggestion to fix this bug.
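The failure mode flagged by the review can be sketched in a few lines. This is a hypothetical simplification, not vLLM's actual `GptOssModel.forward`; the point is that on the first layer of a pipeline stage `residual` starts as `None` and must be initialized rather than added:

```python
def forward_layers(hidden_states, layers):
    """Illustrative decoder loop; `layers` are plain callables here."""
    residual = None  # first layer of a pipeline stage has no residual yet
    for layer in layers:
        if residual is None:
            # Initialize instead of computing `hidden_states + None`,
            # which would raise a TypeError.
            residual = hidden_states
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
        hidden_states = layer(hidden_states)
    return hidden_states, residual

out, res = forward_layers(1.0, [lambda x: x * 2, lambda x: x * 2])
```

Without the `None` branch, the second statement in the loop body would fail on the very first layer.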
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Looks good to me!
LGTM, but let's wait for someone more familiar with Eagle perhaps.
Would be great to have smoke tests for model init, at least with layer-capped versions.
```diff
 f"{self.disable_by_batch_size=}")

-eagle3_target_supported = ["llama", "qwen"]
+eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
```
Unrelated to this PR, but I wonder why we have to list models here instead of relying on `SupportsEagle3` dispatching.
To be honest I'm not sure; I just followed llama.py / qwen.py to see how Eagle is enabled there.
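The interface-based alternative floated above could look roughly like this. `SupportsEagle3` is a real vLLM interface, but the class names and the check below are an illustrative sketch, not the project's actual dispatch code:

```python
class SupportsEagle3:
    """Marker interface: the model can act as an Eagle3 target."""

class GptOssForCausalLM(SupportsEagle3):
    """Hypothetical stand-in for the real model class."""

class SomeOtherModel:
    """A model that does not implement the interface."""

def supports_eagle3(model) -> bool:
    # Dispatch on the interface instead of a hard-coded list of
    # architecture names like ["llama", "qwen", "gpt_oss"].
    return isinstance(model, SupportsEagle3)
```

With this approach, adding a new target model would only require implementing the interface, with no central list to update.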
```python
logger.info(
    "Assuming the EAGLE head shares the same vocab embedding"
    " with the target model.")
del self.model.model.embed_tokens
```
should be picked up by gc
I took this from llama.py and qwen.py. Should I remove it here or leave it for consistency?
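The pattern under discussion, as inherited from llama.py / qwen.py, is roughly the following sketch (names simplified and hypothetical; the idea is that the draft head drops its own vocab embedding and reuses the target model's):

```python
class Stub:
    """Stand-in for an nn.Embedding; only object identity matters here."""

class DraftHead:
    def __init__(self):
        self.embed_tokens = Stub()  # draft's own vocab embedding

def share_embeddings(draft, target_embed):
    # Drop the draft's embedding and point it at the target's, so both
    # models read token embeddings from the same underlying weights.
    del draft.embed_tokens
    draft.embed_tokens = target_embed

target_embed = Stub()
draft = DraftHead()
share_embeddings(draft, target_embed)
```

The explicit `del` frees the draft's copy immediately rather than waiting for it to become unreachable, which matters mostly when the embedding is large.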
Noting #23596, which implements this similarly but includes Blackwell support via FlashInfer. It should be fine to merge this one first, but we should make sure they're consistent, since we will need Blackwell support as well.
I think for models with alternative attention like gpt-oss, you need to find the correct attention builders for the draft model (the draft model uses full attention), as in https://github.com/vllm-project/vllm/pull/23596/files#diff-a4809a837fbf535a8f0999b11087a53ec1c53948b50c0a1fe64396bc86de9461R883-R906. Without this, it will use sliding-window attention, which will cause accuracy issues (the target model's KV cache will be overwritten by the draft model).
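The concern can be made concrete with a small sketch (all names here are hypothetical): gpt-oss interleaves sliding-window and full-attention layers, while an Eagle3 draft head uses full attention everywhere, so the draft must not inherit the target's per-layer attention builders:

```python
# Hypothetical per-layer attention types for a gpt-oss-like target model.
target_layer_types = ["sliding_window", "full", "sliding_window", "full"]

def draft_attn_builders(num_draft_layers: int):
    # The draft model is full attention throughout; reusing the target's
    # sliding-window builders would place draft KV entries into cache
    # slots laid out for the target, corrupting the target's KV cache.
    return ["full"] * num_draft_layers

builders = draft_attn_builders(2)
```

The linked diff in #23596 shows one way to wire the correct builders in for real.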
The goal of this one was to first add support for the simplest Llama-like speculator from Eagle3, and then we can build on top of it for more complex architectures. This unblocks some preliminary experimentation with speculative decoding.
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This PR adds support for EAGLE3 speculative decoding for the GPT-OSS model. The changes were tested with a locally trained speculator model, and reasonable acceptance rates were observed.