[Bugfix][LoRA][Spec Decode] Support LoRA with speculative decoding #21068
Conversation
Code Review
This pull request introduces support for LoRA with speculative decoding in the V1 engine, addressing a dimension mismatch issue that arises when speculative decoding is enabled. The changes involve adjusting the dimension of AdapterMapping.prompt_mapping to store LoRA IDs for each sampled token, ensuring compatibility with both enabled and disabled speculative decoding scenarios. The test plan includes running the V1 server with LoRA and speculative decoding enabled, covering single and multiple requests in parallel.
NickLucche
left a comment
Would you mind turning your scripts into tests, if not already present?
This pull request has merge conflicts that must be resolved before it can be merged.
I have turned the script into a unit test; it succeeded.
@NickLucche Can you help take a look?
This pull request has merge conflicts that must be resolved before it can be merged.
```diff
  # to max_batches * (num_speculative_decoding_tokens + 1).
  self.prompt_mapping_meta = LoRAKernelMeta.make(
-     self.max_loras, max_batches, device=device
+     self.max_loras, max_num_batched_tokens, device=device
```
Hi @xiaohongchen1991. I have seen configurations with max_batches and max_num_batched_tokens set to 1024 and 8192 respectively. In such cases there is a constraint on how big num_speculative_decoding_tokens can be. I think we should add an assert like assert max_num_batched_tokens >= max_batches * (num_speculative_decoding_tokens + 1) so we catch out-of-bounds errors.
What do you think? Thanks.
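The suggested guard could be sketched like this, using the example values quoted in this thread (1024, 8192); in vLLM it would sit next to the LoRAKernelMeta.make call rather than at module level, so treat the surrounding wiring as illustrative:

```python
# Example configuration values from this thread (illustrative only).
max_batches = 1024
max_num_batched_tokens = 8192
num_speculative_decoding_tokens = 5

# Each sequence can contribute up to num_speculative_decoding_tokens + 1
# sampled tokens per step, so the flattened prompt mapping must fit within
# the batched-token budget.
assert max_num_batched_tokens >= max_batches * (num_speculative_decoding_tokens + 1), (
    "num_speculative_decoding_tokens is too large for max_num_batched_tokens"
)
```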
Also, should we just use max_batches if spec_decode is disabled? It might be useful when debugging issues.
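The sizing suggested above could be sketched as a hypothetical helper (the names follow this discussion rather than vLLM's actual code, where the sizing lives inline in the punica wrapper setup):

```python
def prompt_mapping_size(max_batches: int, num_speculative_tokens: int) -> int:
    """Number of LoRA-id slots needed for the prompt mapping."""
    if num_speculative_tokens == 0:
        # Spec decode disabled: one sampled token per sequence, so the
        # original max_batches sizing suffices (and is easier to debug).
        return max_batches
    # Spec decode enabled: up to num_speculative_tokens + 1 sampled
    # tokens per sequence per forward pass.
    return max_batches * (num_speculative_tokens + 1)
```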
Hi @varun-sundar-rabindranath. Assert added.
thanks @li2haipeng !
Thanks @xiaohongchen1991 for the great work - the changes generally look good to me. The preparation of … I am a little out of sync with Spec. Decode - particularly I am a bit confused as to what happens when …
Not sure if I understand it correctly, but this situation would be guarded after adding the assert, right?
```diff
  lora_requests: set[LoRARequest]
  prompt_lora_mapping, token_lora_mapping, lora_requests = (
-     input_batch.make_lora_inputs(num_scheduled_tokens)
+     input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
```
@li2haipeng can you also add an assert after this line, like
assert len(prompt_lora_mapping) <= self.max_num_batched_tokens
My main concern is that, given we are doing
prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
in gpu_input_batch.py :: make_lora_inputs(), I wonder if len(prompt_lora_mapping) could exceed max_num_batched_tokens. If that happens, I think we will catch it here: vllm/lora/punica_wrapper/punica_base.py, line 192 in da786e3 (self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices)).
Also, I think the interaction between max_batches * (num_speculative_decoding_tokens + 1) and max_num_batched_tokens should be captured, and we should raise an error during engine startup when they are incompatible. For example, if the user creates an engine with LoRA + Spec Decode + max_num_seqs=512 + max_num_batched_tokens=1024 + num_speculative_decoding_tokens=5, this will assert deep in the code - but it'll be much better to assert during startup, somewhere near create_engine_config (line 1285 in da786e3), which handles max_num_batched_tokens.
What do you think?
cc @robertgshaw2-redhat
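A startup-time check along the lines described above might look like the following sketch; the function name and placement are assumptions for illustration, and in vLLM the real check would live in the config validation near create_engine_config:

```python
def validate_lora_spec_decode_config(
    max_num_seqs: int,
    max_num_batched_tokens: int,
    num_speculative_tokens: int,
) -> None:
    # With LoRA + spec decode, the flattened per-sampled-token LoRA mapping
    # can hold up to max_num_seqs * (num_speculative_tokens + 1) entries,
    # which must fit within the batched-token budget.
    required = max_num_seqs * (num_speculative_tokens + 1)
    if max_num_batched_tokens < required:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) must be at "
            f"least max_num_seqs * (num_speculative_tokens + 1) = {required}. "
            "Consider increasing max_num_batched_tokens or decreasing "
            "num_speculative_tokens."
        )

# The failing example from this comment: 1024 < 512 * (5 + 1) = 3072,
# so the engine would reject this configuration at startup.
```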
I agree with Varun
@varun-sundar-rabindranath Fixed. Thanks for your suggestion.
@li2haipeng can you record this restriction too, please, i.e. set …
@varun-sundar-rabindranath @li2haipeng we don't need to set it to False because this PR fixed the issue: #28318.
```python
"Qwen/Qwen3-1.7B",
"AngelSlim/Qwen3-1.7B_eagle3",
"premjatin/qwen-linear-algebra-coder",
1,
```
Given the issues around TP, it'd be good to add a TP = 2 test as well. But it can be a fast-follow after #28318 lands. Thanks.
varun-sundar-rabindranath
left a comment
LGTM! Thanks @xiaohongchen1991 @li2haipeng @dcmaddix for working on this ❤️
```python
), (
    "Consider increasing max_num_batched_tokens or "
    "decreasing num_speculative_tokens"
)
```
nit: can you turn it into a ValueError to stay consistent with the error-raising mechanism in this file? Thanks.
Done
robertgshaw2-redhat
left a comment
stamp post varun's review
Head branch was pushed to by a user without write access
…llm-project#21068) Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com> Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Danielle Robinson <dcmaddix@gmail.com> Co-authored-by: Haipeng Li <li2haipeng@gmail.com> Co-authored-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Bump vLLM version to v0.11.2.
What's broken and changed by vLLM:
1. structured_output is broken by vllm-project/vllm#26866
2. get_mrope_input_positions is broken by vllm-project/vllm#28399
3. graph mode is broken by vllm-project/vllm#25110; we'll upgrade torch to 2.8 to fix the problem later
4. embedding is broken by vllm-project/vllm#27583
5. `get_attn_backend_cls` and the attention backend are broken by vllm-project/vllm#28534
6. spec decode is broken by vllm-project/vllm#28771
7. sp feature is broken by vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by vllm-project/vllm#28159
12. kv cache is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110
What's broken and changed by ourselves:
1. qwen vl is broken by vllm-project/vllm#28455; we'll remove model files in the future to avoid this kind of error
2. Engine core is broken by vllm-project/vllm#23691; we'll remove the patch file in the future
3. Ascend scheduler is broken by vllm-project/vllm#28733; we'll remove the Ascend scheduler later
4. qwen3-next is broken by vllm-project/vllm#28083; we'll remove model files in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764; we'll remove model files in the future
Known issues:
1. ray doesn't work
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache + ascend scheduler + deepseek v2 lite is broken
Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: hfadzxy <starmoon_zhang@163.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: 22dimensions <waitingwind@foxmail.com> Co-authored-by: shen-shanshan <467638484@qq.com> - vLLM version: v0.11.2 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: hfadzxy <starmoon_zhang@163.com> Signed-off-by: leo-pony <nengjunma@outlook.com> Co-authored-by: MengqingCao <cmq0113@163.com> Co-authored-by: hfadzxy <starmoon_zhang@163.com> Co-authored-by: leo-pony <nengjunma@outlook.com>
Purpose
This PR adds support for LoRA with speculative decoding for the V1 engine on GPU (TPU is not covered by this PR).
The current V1 LoRA implementation assumes one sampled token per prompt in each forward pass. See the logits computation, where logits.size(0) is the sum of the numbers of sampled tokens across prompts, while lora_logits.size(0) is the number of prompts.
This works when num_sample_tokens is 1, i.e. when running the server without speculative decoding. If speculative decoding is enabled, num_sample_tokens becomes num_speculative_decoding_tokens + 1, and the code fails with a dimension-mismatch error.
This commit adjusts the dimension of AdapterMapping.prompt_mapping so that it stores a LoRA id for each sampled token. If speculative decoding is disabled, its size becomes the number of prompts and it stores a LoRA id for each prompt, the same as the original implementation.
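The adjustment can be sketched in plain Python; this hypothetical helper stands in for the repeat done in gpu_input_batch.py :: make_lora_inputs, which operates on numpy arrays and whose signature may differ:

```python
def expand_prompt_lora_mapping(
    req_lora_mapping: list[int], num_sampled_tokens: int
) -> tuple[int, ...]:
    # Repeat each request's LoRA id once per sampled token so that every
    # sampled logit row has a matching LoRA id.
    return tuple(
        lora_id
        for lora_id in req_lora_mapping
        for _ in range(num_sampled_tokens)
    )

# Without spec decode, num_sampled_tokens == 1 and the mapping stays one id
# per prompt; with k speculative tokens, each id appears k + 1 times.
```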
Test Plan
Run the V1 server with both LoRA and speculative decoding enabled. Test scenarios: a single request, and multiple requests in parallel.
Test Result
The dimensions of logits and lora_logits match in the logits computation.
(Optional) Documentation Update