From 202c2ccd5bd8842ae8a5597929b4c6c432fe860b Mon Sep 17 00:00:00 2001 From: Sean Chen Date: Wed, 16 Jul 2025 16:58:10 +0000 Subject: [PATCH 01/10] [Lora][Spec Decode] support LoRA with speculative decoding for v1 on gpu Signed-off-by: Sean Chen --- vllm/lora/punica_wrapper/punica_gpu.py | 9 ++--- vllm/v1/worker/gpu_input_batch.py | 10 +++--- vllm/v1/worker/gpu_model_runner.py | 9 +++-- vllm/v1/worker/lora_model_runner_mixin.py | 40 ++++++++++++++++++----- vllm/v1/worker/tpu_input_batch.py | 2 +- 5 files changed, 50 insertions(+), 20 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 2db0e9fee142..1e711b654807 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -11,7 +11,6 @@ import torch -import vllm.envs as envs from vllm.lora.layers import LoRAMapping from vllm.triton_utils import HAS_TRITON @@ -45,10 +44,12 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, # here), V0 captures the graph as if max_num_seqs is set to # the capture size. # V1 doesn't have this problem and always respects max_num_seqs. - max_num_prompts = (max_batches - if envs.VLLM_USE_V1 else max_num_batched_tokens) + # When speculative decoding is enabled, max_num_samples is + # max_batches * (num_speculative_decoding_tokens + 1). + # This line can be optimized. + max_num_samples = max_num_batched_tokens self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras, - max_num_prompts, + max_num_samples, device=device) def update_metadata(self, mapping: LoRAMapping, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 79a392337574..5d8d61df6426 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -802,21 +802,23 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: non_blocking=True) def make_lora_inputs( - self, num_scheduled_tokens: np.ndarray + self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: """ Given the num_scheduled_tokens for each request in the batch, return datastructures used to activate the current LoRAs. Returns: - 1. prompt_lora_mapping: A tuple of size self.num_reqs where, - prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. + 1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens) + where, prompt_lora_mapping[i] is the LoRA id to use for the ith + sampled token. 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) where, token_lora_mapping[i] is the LoRA id to use for ith token. 3. lora_requests: Set of relevant LoRA requests. """ req_lora_mapping = self.request_lora_mapping[:self.num_reqs] - prompt_lora_mapping = tuple(req_lora_mapping) + prompt_lora_mapping = tuple( + req_lora_mapping.repeat(num_sampled_tokens)) token_lora_mapping = tuple( req_lora_mapping.repeat(num_scheduled_tokens)) active_lora_requests: set[LoRARequest] = set( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ee339e22cea9..a82ff22b80d5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1085,6 +1085,7 @@ def _prepare_inputs( logits_indices = query_start_loc[1:] - 1 num_draft_tokens = None spec_decode_metadata = None + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) else: # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all @@ -1098,6 +1099,7 @@ def _prepare_inputs( spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices + num_sampled_tokens = num_draft_tokens + 1 self.num_draft_tokens.np[:num_reqs] = num_draft_tokens self.num_draft_tokens.np[num_reqs:].fill(0) self.num_draft_tokens.copy_to_gpu() @@ -1230,7 +1232,8 @@ def _prepare_inputs( # Hot-Swap lora model if self.lora_config: - self.set_active_loras(self.input_batch, num_scheduled_tokens) + self.set_active_loras(self.input_batch, num_scheduled_tokens, + num_sampled_tokens) return (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, @@ -2915,6 +2918,7 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) ubatch_slices = None num_tokens_after_padding = None @@ -3006,7 +3010,8 @@ def _dummy_run( attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens, remove_lora): + num_scheduled_tokens, + num_sampled_tokens, remove_lora): model_kwargs = self._init_model_kwargs(num_tokens) if (self.supports_mm_inputs and not self.model_config.is_encoder_decoder): diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index e416f50322f4..fb7376539912 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -70,15 +70,24 @@ def _ensure_lora_enabled(self) -> None: raise RuntimeError( "LoRA is not enabled. Use --enable-lora to enable LoRA.") - def set_active_loras(self, input_batch: InputBatch, - num_scheduled_tokens: np.ndarray) -> None: - - prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs + def set_active_loras( + self, + input_batch: InputBatch, + num_scheduled_tokens: np.ndarray, + num_sampled_tokens: Optional[np.ndarray] = None) -> None: + + if num_sampled_tokens is None: + num_sampled_tokens = np.ones_like(num_scheduled_tokens, + dtype=np.int32) + + prompt_lora_mapping: tuple[int, + ...] # of size np.sum(num_sampled_tokens) token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) lora_requests: set[LoRARequest] prompt_lora_mapping, token_lora_mapping, lora_requests = \ - input_batch.make_lora_inputs(num_scheduled_tokens) + input_batch.make_lora_inputs(num_scheduled_tokens, + num_sampled_tokens) return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests) @@ -116,8 +125,15 @@ def maybe_setup_dummy_loras(self, self.lora_manager.remove_all_adapters() @contextmanager - def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], - num_scheduled_tokens: np.ndarray): + def maybe_select_dummy_loras( + self, + lora_config: Optional[LoRAConfig], + num_scheduled_tokens: np.ndarray, + num_sampled_tokens: Optional[np.ndarray] = None): + if num_sampled_tokens is None: + num_sampled_tokens = np.ones_like(num_scheduled_tokens, + dtype=np.int32) + if lora_config is None: yield else: @@ -132,6 +148,10 @@ def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1 + # Make sample lora mapping + sample_lora_mapping = np.repeat(prompt_lora_mapping, + num_sampled_tokens) + # Make token lora mapping token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) @@ -144,7 +164,7 @@ def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], for lora_id in range(1, num_loras + 1) } - self._set_active_loras(tuple(prompt_lora_mapping), + self._set_active_loras(tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests) yield @@ -153,11 +173,13 @@ def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], def maybe_dummy_run_with_lora(self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray, + num_sampled_tokens: np.ndarray, remove_lora: bool = True): with ( self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_select_dummy_loras(lora_config, - num_scheduled_tokens), + num_scheduled_tokens, + num_sampled_tokens), ): yield diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 4cd0ac352de0..97d0773d38e3 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -522,7 +522,7 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: non_blocking=True) def make_lora_inputs( - self, num_scheduled_tokens: np.ndarray + self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: """ Given the num_scheduled_tokens for each request in the batch, return From e90736f60afa775225419d89fdad6b9fe6fccf00 Mon Sep 17 00:00:00 2001 From: Sean Chen Date: Thu, 25 Sep 2025 15:12:43 -0400 Subject: [PATCH 02/10] fixup! [Lora][Spec Decode] support LoRA with speculative decoding for v1 on gpu Signed-off-by: Sean Chen --- tests/v1/e2e/test_lora_with_spec_decode.py | 111 +++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/v1/e2e/test_lora_with_spec_decode.py diff --git a/tests/v1/e2e/test_lora_with_spec_decode.py b/tests/v1/e2e/test_lora_with_spec_decode.py new file mode 100644 index 000000000000..efdaa4560a2a --- /dev/null +++ b/tests/v1/e2e/test_lora_with_spec_decode.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This script contains: +1. test lora with speculative decoding for batch inference +""" +import pytest +import torch + +from vllm import LLM, SamplingParams +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform + +LORA_TEST_PROMPT_MAP: dict[str, str] = {} + +LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """ +### INSTRUCTION: +You are an AI assistant that generates Python code to solve linear +algebra problems. + +### PROBLEM: +Find the eigenvalues and eigenvectors of the following 3x3 matrix: +[[4, 0, 1], + [-2, 1, 0], + [-2, 0, 1]] + +### PYTHON SOLUTION: +""" + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="CUDA not available") +@pytest.mark.parametrize( + "model_setup", + [("eagle3", "Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", + "premjatin/qwen-linear-algebra-coder", 1)]) +def test_batch_inference_correctness( + monkeypatch: pytest.MonkeyPatch, + model_setup: tuple[str, str, str, str, int], +): + ''' + Compare the outputs of a LLM with only Lora and a LLM with both SD and Lora. + Should be the same and no failure when doing batch inference. + model_setup: (method, model_name, spec_model_name, lora_path, tp_size) + ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + method, model_name, spec_model_name, lora_path, tp_size = model_setup + + # without speculative decoding + ref_llm = LLM( + model=model_name, + trust_remote_code=True, + tensor_parallel_size=tp_size, + max_model_len=2048, + max_num_seqs=4, + enable_lora=True, + max_loras=1, + max_cpu_loras=1, + max_lora_rank=16, + ) + + prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 4 + lora_request = LoRARequest("adapter", 1, lora_path) + sampling_params = SamplingParams(temperature=0, max_tokens=128) + + ref_outputs = ref_llm.generate(prompts, + sampling_params, + lora_request=lora_request) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + lora_spec_llm = LLM( + model=model_name, + trust_remote_code=True, + tensor_parallel_size=tp_size, + speculative_config={ + "method": method, + "model": spec_model_name, + "num_speculative_tokens": 3, + "max_model_len": 2048, + }, + max_model_len=2048, + max_num_seqs=4, + enable_lora=True, + max_loras=1, + max_cpu_loras=1, + max_lora_rank=16, + ) + + lora_spec_outputs = lora_spec_llm.generate(prompts, + sampling_params, + lora_request=lora_request) + + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + assert misses == 0 + del lora_spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() From 6d960135d8c856e9901141faf2953dc1e4d9ba67 Mon Sep 17 00:00:00 2001 From: Sean Chen Date: Tue, 7 Oct 2025 13:49:25 -0400 Subject: [PATCH 03/10] fixup! Merge branch 'main' into speculative-decoding-with-lora Signed-off-by: Sean Chen --- tests/v1/e2e/test_lora_with_spec_decode.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_lora_with_spec_decode.py b/tests/v1/e2e/test_lora_with_spec_decode.py index 900103515c8b..d7b354df9ae0 100644 --- a/tests/v1/e2e/test_lora_with_spec_decode.py +++ b/tests/v1/e2e/test_lora_with_spec_decode.py @@ -70,7 +70,7 @@ def test_batch_inference_correctness( max_lora_rank=16, ) - prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 4 + prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100 lora_request = LoRARequest("adapter", 1, lora_path) sampling_params = SamplingParams(temperature=0, max_tokens=128) @@ -113,7 +113,9 @@ def test_batch_inference_correctness( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") - assert misses == 0 + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) del lora_spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() From 5f74f23bcd4e0bcba61c112e7c64afb0095a6cc1 Mon Sep 17 00:00:00 2001 From: Sean Chen Date: Wed, 5 Nov 2025 18:26:55 -0500 Subject: [PATCH 04/10] fixup! Merge branch 'main' into speculative-decoding-with-lora Signed-off-by: Sean Chen --- tests/v1/e2e/test_lora_with_spec_decode.py | 32 ++++++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 6 +++- vllm/v1/worker/lora_model_runner_mixin.py | 9 +++--- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/tests/v1/e2e/test_lora_with_spec_decode.py b/tests/v1/e2e/test_lora_with_spec_decode.py index d7b354df9ae0..14532f279544 100644 --- a/tests/v1/e2e/test_lora_with_spec_decode.py +++ b/tests/v1/e2e/test_lora_with_spec_decode.py @@ -5,6 +5,9 @@ 1. test lora with speculative decoding for batch inference """ +import random + +import numpy as np import pytest import torch @@ -22,13 +25,18 @@ ### PROBLEM: Find the eigenvalues and eigenvectors of the following 3x3 matrix: -[[4, 0, 1], - [-2, 1, 0], - [-2, 0, 1]] +[[3, 2, 0], + [2, 3, 0], + [0, 0, 2]] + +### OUTPUT FORMAT (STRICT): +Numbers should be represented as integers only. ### PYTHON SOLUTION: """ +SEED = 42 + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") @pytest.mark.parametrize( @@ -55,6 +63,15 @@ def test_batch_inference_correctness( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + # Disable randomness + m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + torch.manual_seed(SEED) + np.random.seed(SEED) + random.seed(SEED) + torch.cuda.manual_seed_all(SEED) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + method, model_name, spec_model_name, lora_path, tp_size = model_setup # without speculative decoding @@ -72,7 +89,9 @@ def test_batch_inference_correctness( prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100 lora_request = LoRARequest("adapter", 1, lora_path) - sampling_params = SamplingParams(temperature=0, max_tokens=128) + sampling_params = SamplingParams( + temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128 + ) ref_outputs = ref_llm.generate( prompts, sampling_params, lora_request=lora_request @@ -113,9 +132,10 @@ def test_batch_inference_correctness( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 66% of the prompts to match exactly + # Heuristic: expect at least 90% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.66 * len(ref_outputs)) + print(f"match ratio: {matches}/{len(ref_outputs)}") + assert matches > int(0.90 * len(ref_outputs)) del lora_spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2d7972e181e1..28e1cf6d4043 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3395,7 +3395,11 @@ def _dummy_run( attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora( - self.lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora, remove_lora + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + remove_lora, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_after_padding <= self.max_num_tokens diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 965cfccdadfe..9216119f1e94 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -73,7 +73,7 @@ def set_active_loras( self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray, - num_sampled_tokens: Optional[np.ndarray] = None, + num_sampled_tokens: np.ndarray | None = None, ) -> None: if num_sampled_tokens is None: num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) @@ -127,9 +127,9 @@ def maybe_setup_dummy_loras( @contextmanager def maybe_select_dummy_loras( self, - lora_config: Optional[LoRAConfig], + lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray, - num_sampled_tokens: Optional[np.ndarray] = None, + num_sampled_tokens: np.ndarray | None = None, activate_lora: bool = True, ): if num_sampled_tokens is None: @@ -187,8 +187,7 @@ def maybe_dummy_run_with_lora( with ( self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_select_dummy_loras( - lora_config, num_scheduled_tokens, num_sampled_tokens, - activate_lora + lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora ), ): yield From f927b9fc2d3cad2664301f40ff763c62abe7b70f Mon Sep 17 00:00:00 2001 From: Haipeng Li Date: Fri, 7 Nov 2025 19:44:49 +0000 Subject: [PATCH 05/10] add assert --- vllm/v1/worker/lora_model_runner_mixin.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 9216119f1e94..7438f61b9702 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -38,7 +38,11 @@ def load_lora_model( "Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model." ) - + if vllm_config.speculative_config: + assert vllm_config.scheduler_config.max_num_batched_tokens >= ( + vllm_config.scheduler_config.max_num_seqs + * (vllm_config.speculative_config.num_speculative_tokens + 1) + ), "Consider increasing max_num_batched_tokens or decreasing num_speculative_tokens" \ # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( vllm_config, From 605de8a47193dd22d6640656a6da14282ba84722 Mon Sep 17 00:00:00 2001 From: Haipeng Li Date: Fri, 7 Nov 2025 20:12:01 +0000 Subject: [PATCH 06/10] pre-commit fix --- vllm/v1/worker/lora_model_runner_mixin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 7438f61b9702..6ceb3e67b73f 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -42,7 +42,10 @@ def load_lora_model( assert vllm_config.scheduler_config.max_num_batched_tokens >= ( vllm_config.scheduler_config.max_num_seqs * (vllm_config.speculative_config.num_speculative_tokens + 1) - ), "Consider increasing max_num_batched_tokens or decreasing num_speculative_tokens" \ + ), ( + "Consider increasing max_num_batched_tokens or " + "decreasing num_speculative_tokens" + ) # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( vllm_config, From 9fe779c7148c9b7fdbdd830e42f9b1f449210a0c Mon Sep 17 00:00:00 2001 From: Haipeng Li Date: Fri, 7 Nov 2025 22:05:08 +0000 Subject: [PATCH 07/10] fix comment --- vllm/engine/arg_utils.py | 9 +++++++++ vllm/v1/worker/gpu_model_runner.py | 4 ++++ vllm/v1/worker/lora_model_runner_mixin.py | 8 -------- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f1a6c0716e4c..e5cb98875cbb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1574,6 +1574,15 @@ def create_engine_config( else None ) + if lora_config is not None and speculative_config is not None: + assert scheduler_config.max_num_batched_tokens >= ( + scheduler_config.max_num_seqs + * (speculative_config.num_speculative_tokens + 1) + ), ( + "Consider increasing max_num_batched_tokens or " + "decreasing num_speculative_tokens" + ) + # bitsandbytes pre-quantized model need a specific model loader if model_config.quantization == "bitsandbytes": self.quantization = self.load_format = "bitsandbytes" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 03f0245ee369..8a03b23facc3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1446,6 +1446,10 @@ def _prepare_inputs( # Hot-Swap lora model if self.lora_config: + assert ( + np.sum(num_sampled_tokens) + <= self.vllm_config.scheduler_config.max_num_batched_tokens + ) self.set_active_loras( self.input_batch, num_scheduled_tokens, num_sampled_tokens ) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 6ceb3e67b73f..37abe5649460 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -38,14 +38,6 @@ def load_lora_model( "Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model." ) - if vllm_config.speculative_config: - assert vllm_config.scheduler_config.max_num_batched_tokens >= ( - vllm_config.scheduler_config.max_num_seqs - * (vllm_config.speculative_config.num_speculative_tokens + 1) - ), ( - "Consider increasing max_num_batched_tokens or " - "decreasing num_speculative_tokens" - ) # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( vllm_config, From a0609d99395ccb5476d046793ed9237815d8203d Mon Sep 17 00:00:00 2001 From: Haipeng Li Date: Fri, 7 Nov 2025 22:37:31 +0000 Subject: [PATCH 08/10] set cudagraph_specialize_lora to false when using lora+eagle --- vllm/engine/arg_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e5cb98875cbb..6317d13493e6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1575,6 +1575,13 @@ def create_engine_config( ) if lora_config is not None and speculative_config is not None: + if self.compilation_config.cudagraph_specialize_lora: + logger.warning_once( + "Currenlty cudagraph_specialize_lora is not supported " + "when using both LoRA and speculative decoding. " + "Setting cudagraph_specialize_lora to False." + ) + self.compilation_config.cudagraph_specialize_lora = False assert scheduler_config.max_num_batched_tokens >= ( scheduler_config.max_num_seqs * (speculative_config.num_speculative_tokens + 1) From 5e23647a651487c21cddb998ca7440e9ea444cdb Mon Sep 17 00:00:00 2001 From: Haipeng Li Date: Fri, 7 Nov 2025 23:07:07 +0000 Subject: [PATCH 09/10] remove warning --- vllm/engine/arg_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6317d13493e6..e5cb98875cbb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1575,13 +1575,6 @@ def create_engine_config( ) if lora_config is not None and speculative_config is not None: - if self.compilation_config.cudagraph_specialize_lora: - logger.warning_once( - "Currenlty cudagraph_specialize_lora is not supported " - "when using both LoRA and speculative decoding. " - "Setting cudagraph_specialize_lora to False." - ) - self.compilation_config.cudagraph_specialize_lora = False assert scheduler_config.max_num_batched_tokens >= ( scheduler_config.max_num_seqs * (speculative_config.num_speculative_tokens + 1) From 0d78c2549fa53b216eb843401f0d74c49a1d0ff3 Mon Sep 17 00:00:00 2001 From: Haipeng Li Date: Fri, 7 Nov 2025 23:42:12 +0000 Subject: [PATCH 10/10] change to ValueError --- vllm/engine/arg_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e5cb98875cbb..342da0150a7c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1574,11 +1574,16 @@ def create_engine_config( else None ) - if lora_config is not None and speculative_config is not None: - assert scheduler_config.max_num_batched_tokens >= ( + if ( + lora_config is not None + and speculative_config is not None + and scheduler_config.max_num_batched_tokens + < ( scheduler_config.max_num_seqs * (speculative_config.num_speculative_tokens + 1) - ), ( + ) + ): + raise ValueError( "Consider increasing max_num_batched_tokens or " "decreasing num_speculative_tokens" )