Conversation

@xiaohongchen1991
Contributor

@xiaohongchen1991 xiaohongchen1991 commented Jul 16, 2025

Purpose

This PR adds support for LoRA with speculative decoding for the V1 engine on GPU (TPU is not covered by this PR).

The current V1 implementation for LoRA assumes one sampled token per prompt in each forward pass. See the logits computation:

logits[:,self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits

where logits.size(0) is the sum of the number of sampled tokens over all prompts, while lora_logits.size(0) is the number of prompts.

This implementation works when num_sample_tokens is 1, i.e., when running the server without speculative decoding. If speculative decoding is enabled, num_sample_tokens becomes num_speculative_decoding_tokens + 1, and the code fails with error messages like

(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527] WorkerProc hit an exception.
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527] Traceback (most recent call last):
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/vllm/v1/executor/multiproc_executor.py", line 522, in worker_busy_loop
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     output = func(*args, **kwargs)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]              ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     return func(*args, **kwargs)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/vllm/v1/worker/gpu_worker.py", line 293, in execute_model
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     output = self.model_runner.execute_model(scheduler_output,
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     return func(*args, **kwargs)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1292, in execute_model
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     logits = self.model.compute_logits(sample_hidden_states, None)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/llama.py", line 590, in compute_logits
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     return forward_call(*args, **kwargs)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/vllm/lora/layers.py", line 1186, in forward
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     return type(self.base_layer).forward(self, *args, **kwargs)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/layers/logits_processor.py", line 71, in forward
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]   File "/usr/local/lib/python3.11/dist-packages/vllm/lora/layers.py", line 1169, in _get_logits
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527]     logits[:,
(VllmWorker rank=1 pid=19338) ERROR 07-08 19:45:36 [multiproc_executor.py:527] RuntimeError: The expanded size of the tensor (7) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [7, 256].  Tensor sizes: [2, 256]
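The failing slice assignment in the last frame can be reproduced in isolation. Below is a minimal NumPy sketch (not vLLM code) with sizes taken from the error message: 7 sampled tokens total across 2 prompts, and a 256-entry padded LoRA vocab region on top of a 128256-entry base vocab.

```python
import numpy as np

# Sizes from the error above: logits has one row per sampled token,
# lora_logits has one row per prompt.
org_vocab_size, lora_extra_vocab = 128256, 256
num_sampled_tokens, num_prompts = 7, 2

logits = np.zeros((num_sampled_tokens, org_vocab_size + lora_extra_vocab))
lora_logits = np.zeros((num_prompts, lora_extra_vocab))  # too few rows

try:
    # Same slice assignment as in lora/layers.py: row counts must match.
    logits[:, org_vocab_size:org_vocab_size + lora_logits.shape[1]] = lora_logits
except ValueError as err:
    print("shape mismatch:", err)  # (2, 256) cannot broadcast into (7, 256)
```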

This commit adjusts the dimension of AdapterMapping.prompt_mapping so that it stores a LoRA id for each sampled token. If speculative decoding is disabled, its size reduces to the number of prompts and it stores one LoRA id per prompt, which matches the original implementation.
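The adjusted mapping can be sketched as follows, a minimal NumPy illustration with made-up LoRA ids that mirrors the repeat-per-sampled-token idea rather than the exact vLLM code:

```python
import numpy as np

# Hypothetical batch of 2 requests: the LoRA id attached to each request,
# and how many tokens were sampled from each in this step. Without
# speculative decoding every entry of num_sampled_tokens is 1; with it,
# each entry can be up to num_speculative_decoding_tokens + 1.
req_lora_mapping = np.array([1, 2])    # LoRA id per prompt
num_sampled_tokens = np.array([6, 1])  # sampled tokens per prompt

# Repeat each LoRA id once per sampled token so the mapping length
# equals logits.size(0) == sum(num_sampled_tokens).
prompt_lora_mapping = req_lora_mapping.repeat(num_sampled_tokens)
print(prompt_lora_mapping.tolist())  # [1, 1, 1, 1, 1, 1, 2]
```

When every entry of num_sampled_tokens is 1, the repeat is a no-op and the mapping collapses back to one id per prompt.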

Test Plan

Run the V1 server with both LoRA and speculative decoding enabled. Test scenarios:

  1. Run single base model request
  2. Run single adapter model request
  3. Run multiple base model requests in parallel
  4. Run multiple adapter model requests in parallel

Test Result

  • Run server
export VLLM_USE_V1=1
export VLLM_ENABLE_V1_MULTIPROCESSING=1

python3 -m vllm.entrypoints.openai.api_server \
     --model /tmp/model/llama-3-1-8b/ \
     --port 8080 \
     --max-num-seqs 4 \
     --tensor-parallel-size 1 \
     --gpu-memory-utilization 0.9 \
     --enable-lora \
     --lora-modules adapter=/tmp/model/llama-3-1-8b/adapter/ \
     --max-lora-rank 32 \
     --speculative_config '{"model": "/tmp/model/llama-3-1-8b/eagle_head/", "num_speculative_tokens": 5, "method": "eagle3"}'
  • Run single base model request
curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "/tmp/model/llama-3-1-8b/",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
        "top_k": 1,
        "top_p": 1.0,
        "temperature": 1.0
    }' | jq

"text": " top tourist destination, and for good reason. The city is known for its iconic Golden Gate Bridge, steep hills, colorful Victorian homes, and vibrant cultural scene."
  • Run single adapter request
curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "/tmp/model/llama-3-1-8b/",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
        "top_k": 1,
        "top_p": 1.0,
        "temperature": 1.0
    }' | jq

"text": " top tourist destination, and for good reason. The city is known for its iconic Golden Gate Bridge, steep hills, colorful Victorian homes, and vibrant cultural scene."
  • Run multiple requests in parallel
curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "adapter",
        "prompt": "San Francisco is a",
        "max_tokens": 600,
        "top_k": 1,
        "top_p": 1.0,
        "temperature": 1.0
    }' | jq & curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "adapter",
        "prompt": "San Francisco is a",
        "max_tokens": 600,
        "top_k": 1,
        "top_p": 1.0,
        "temperature": 1.0
    }' | jq & curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "adapter",
        "prompt": "San Francisco is a",
        "max_tokens": 600,
        "top_k": 1,
        "top_p": 1.0,
        "temperature": 1.0
    }' | jq & curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "adapter",
        "prompt": "San Francisco is a",
        "max_tokens": 600,
        "top_k": 1,
        "top_p": 1.0,
        "temperature": 1.0
    }' | jq &

The dimensions of logits and lora_logits now match in the logits computation:

logits shape: torch.Size([24, 128512])
lora_logits shape: torch.Size([24, 256])


@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for LoRA with speculative decoding in the V1 engine, addressing a dimension mismatch issue that arises when speculative decoding is enabled. The changes involve adjusting the dimension of AdapterMapping.prompt_mapping to store LoRA IDs for each sampled token, ensuring compatibility with both enabled and disabled speculative decoding scenarios. The test plan includes running the V1 server with LoRA and speculative decoding enabled, covering single and multiple requests in parallel.

Collaborator

@NickLucche NickLucche left a comment

Would you mind turning your scripts into tests, if not already present?

Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
… v1 on gpu

Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
@xiaohongchen1991 xiaohongchen1991 force-pushed the speculative-decoding-with-lora branch from 0fe83ca to e90736f Compare September 25, 2025 19:15
@mergify mergify bot added the tpu Related to Google TPUs label Sep 25, 2025
@mergify

mergify bot commented Sep 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xiaohongchen1991.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 25, 2025
@xiaohongchen1991
Contributor Author

xiaohongchen1991 commented Sep 25, 2025

Would you mind turning your scripts into tests, if not already present?

I have turned the script into a unit test.

pytest -s -v tests/v1/e2e/test_lora_with_spec_decode.py

succeeded

Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
@xiaohongchen1991 xiaohongchen1991 force-pushed the speculative-decoding-with-lora branch from 38138b0 to 6d96013 Compare October 7, 2025 17:49
@xiaohongchen1991
Contributor Author

@NickLucche Can you help take a look?

@mergify

mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xiaohongchen1991.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

# to max_batches * (num_speculative_decoding_tokens + 1).
self.prompt_mapping_meta = LoRAKernelMeta.make(
-    self.max_loras, max_batches, device=device
+    self.max_loras, max_num_batched_tokens, device=device
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath Nov 6, 2025

Hi @xiaohongchen1991 . I have seen configurations with max_batches, max_num_batched_tokens set as 1024, 8192. In such cases, it looks like there is a constraint on how big num_speculative_decoding_tokens can be. I think we should add an assert like assert(max_num_batched_tokens >= max_batches * (num_speculative_decoding_tokens + 1)) so we catch out-of-bounds errors.
What do you think ? Thanks.
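The suggested bound can be sketched as a standalone check (a hypothetical helper for illustration, not code from this PR):

```python
def check_lora_spec_decode_capacity(
    max_num_batched_tokens: int,
    max_batches: int,
    num_speculative_decoding_tokens: int,
) -> None:
    """Raise early when per-token LoRA buffers, sized by
    max_num_batched_tokens, cannot hold one entry per sampled token."""
    worst_case = max_batches * (num_speculative_decoding_tokens + 1)
    if max_num_batched_tokens < worst_case:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) must be >= "
            f"max_batches * (num_speculative_decoding_tokens + 1) = "
            f"{worst_case}; consider increasing max_num_batched_tokens "
            "or decreasing num_speculative_tokens"
        )

check_lora_spec_decode_capacity(8192, 1024, 5)  # 1024 * 6 = 6144 <= 8192: ok
```

With the 1024/8192 configuration mentioned above, up to num_speculative_decoding_tokens = 7 fits; larger values would fail fast at startup instead of out-of-bounds deep in the kernels.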

Contributor

also, should we just use max_batches if spec_decode is disabled ? It might be useful when debugging issues.

Contributor

Hi @varun-sundar-rabindranath . Assert added.

Contributor

thanks @li2haipeng !

@varun-sundar-rabindranath
Contributor

varun-sundar-rabindranath commented Nov 7, 2025

Thanks @xiaohongchen1991 for the great work - The changes generally look good to me. The preparation of prompt_lora_mapping with sample_num_tokens to account for the draft tokens in LoRA sampler_ids and LoRA shrink and expand kernels look good.

I am a little out of sync with spec decode. In particular, I am a bit confused about what happens when sum(sample_num_tokens) exceeds max_num_batched_tokens, since most of the data structures in PunicaWrapperBase are maxed out at max_num_batched_tokens. Is this already handled elsewhere and not a possibility?
cc @jeejeelee

@li2haipeng
Contributor

li2haipeng commented Nov 7, 2025

Not sure if I understand it correctly, but this situation would be guarded against after adding the assert, right?

@simon-mo simon-mo added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 7, 2025
lora_requests: set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = (
-    input_batch.make_lora_inputs(num_scheduled_tokens)
+    input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
Contributor

@li2haipeng can you also add an assert after this line like,
assert(len(prompt_lora_mapping) <= self.max_num_batched_tokens)

My main concern is that, given we are doing

        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))

in gpu_input_batch.py::make_lora_inputs(), I wonder if len(prompt_lora_mapping) would exceed max_num_batched_tokens. If this happens, I think we will catch it here

self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices)
, but I am not fully sure. Either way, an assert here would be very useful and would point to the direct cause.

Also, I think the interaction between max_batches * (num_speculative_decoding_tokens + 1) and max_num_batched_tokens should be captured, and we should raise an error during engine startup when they are incompatible. For example, if the user creates an engine with
LoRA + spec decode + max_num_seqs=512 + max_num_batched_tokens=1024 + num_speculative_decoding_tokens=5,
this will assert deep in the code, but it would be much better to fail during startup (somewhere in `create_engine_config()`) and suggest that users increase max_num_batched_tokens.

What do you think ?
cc @robertgshaw2-redhat

Collaborator

I agree with Varun

Contributor

@varun-sundar-rabindranath Fixed. Thanks for your suggestion.

@varun-sundar-rabindranath
Contributor

varun-sundar-rabindranath commented Nov 7, 2025

Fixed pre-commit and missing import issues from the previous rebase.

Synced up offline with some related folks. Here is the current status: this PR conflicts with this recently merged commit from @andylolu2 and @jeejeelee, which caused the issue with TP>1 mentioned by @li2haipeng.

This issue can be reproduced by running the added unit test

pytest -s -v tests/v1/e2e/test_lora_with_spec_decode.py

with TP=2. i.e., update to the following.

@pytest.mark.parametrize(
    "model_setup",
    [
        (
            "eagle3",
            "Qwen/Qwen3-1.7B",
            "AngelSlim/Qwen3-1.7B_eagle3",
            "premjatin/qwen-linear-algebra-coder",
            2,
        )
    ],
)

A temporary workaround is to set the flag cudagraph_specialize_lora to False to disable the new feature introduced in the conflicting commit. Re-ran the unit test with TP=2 and it passed, meaning the output with and without speculative decoding is the same when running inference on the LoRA adapter, so there is no accuracy issue with this workaround.

@li2haipeng can you record this restriction too, please? I.e., set cudagraph_specialize_lora to False explicitly, with a logger message conveying that LoRA + spec decode + cudagraph_specialize_lora + TP is not supported at the moment.

@dcmaddix
Contributor

dcmaddix commented Nov 7, 2025


@varun-sundar-rabindranath @li2haipeng we don't need to set it to False because this PR fixed the issue: #28318.

"Qwen/Qwen3-1.7B",
"AngelSlim/Qwen3-1.7B_eagle3",
"premjatin/qwen-linear-algebra-coder",
1,
Contributor

Given the issues around TP, it'd be good to add a TP=2 test as well. But it can be a fast-follow after #28318 lands. Thanks.

@simon-mo simon-mo dismissed NickLucche’s stale review November 7, 2025 23:25

script -> test conversion done

Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath left a comment

LGTM! Thanks @xiaohongchen1991 @li2haipeng @dcmaddix for working on this ❤️

), (
"Consider increasing max_num_batched_tokens or "
"decreasing num_speculative_tokens"
)
Contributor

nit: Can you turn it into a ValueError to stay consistent with the error-raising mechanism in this file? Thanks.

Contributor

Done

Collaborator

@robertgshaw2-redhat robertgshaw2-redhat left a comment

stamp post varun's review

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) November 7, 2025 23:33
auto-merge was automatically disabled November 7, 2025 23:42

Head branch was pushed to by a user without write access

@jeejeelee jeejeelee enabled auto-merge (squash) November 8, 2025 00:12
@jeejeelee jeejeelee merged commit d0c7792 into vllm-project:main Nov 8, 2025
55 checks passed
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Nov 13, 2025
…llm-project#21068)

Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
wangxiyuan added a commit to vllm-project/vllm-ascend that referenced this pull request Nov 26, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and the attention backend are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cache is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

 
What's broken and changed by ourself:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove the Ascend
scheduler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work 
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: shen-shanshan <467638484@qq.com>


- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Kurumi5210 pushed a commit to lidenghui1110/vllm-ascend that referenced this pull request Nov 26, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
…llm-project#21068)

Signed-off-by: Sean Chen <xiaohong_chen1991@hotmail.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>