Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions tests/v1/e2e/test_lora_with_spec_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script contains:
1. test lora with speculative decoding for batch inference
"""

import random

import numpy as np
import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

# Seed shared by every random number generator and by the sampling params.
SEED = 42

# Maps a LoRA adapter path to the prompt used to exercise that adapter.
LORA_TEST_PROMPT_MAP: dict[str, str] = {
    "premjatin/qwen-linear-algebra-coder": """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.
### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[3, 2, 0],
[2, 3, 0],
[0, 0, 2]]
### OUTPUT FORMAT (STRICT):
Numbers should be represented as integers only.
### PYTHON SOLUTION:
""",
}


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
"model_setup",
[
(
"eagle3",
"Qwen/Qwen3-1.7B",
"AngelSlim/Qwen3-1.7B_eagle3",
"premjatin/qwen-linear-algebra-coder",
1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given the issues around TP, it'd be good to add a TP = 2 test as well. But it can be a fast-follow after @28318 lands. Thanks.

)
],
)
def test_batch_inference_correctness(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, str, int],
):
    """
    Compare the outputs of an LLM with only LoRA against an LLM with both
    speculative decoding (SD) and LoRA when doing batch inference.

    Both runs must complete without failure, and more than 90% of the
    generated texts must match exactly.

    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Disable randomness
        m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        method, model_name, spec_model_name, lora_path, tp_size = model_setup

        # Engine arguments shared by both runs, so that the speculative
        # config is the only difference between the two engines.
        engine_kwargs = dict(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
        lora_request = LoRARequest("adapter", 1, lora_path)
        sampling_params = SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
        )

        # without speculative decoding
        ref_llm = LLM(**engine_kwargs)
        try:
            ref_outputs = ref_llm.generate(
                prompts, sampling_params, lora_request=lora_request
            )
        finally:
            # Tear the engine down even if generation fails, so a failure
            # here does not leave the engine alive for whatever runs next.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # with speculative decoding
        lora_spec_llm = LLM(
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            **engine_kwargs,
        )
        try:
            lora_spec_outputs = lora_spec_llm.generate(
                prompts, sampling_params, lora_request=lora_request
            )
        finally:
            # Clean up before comparing: the comparison only needs the
            # returned outputs, and a failing assertion must not skip the
            # teardown.
            del lora_spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # zip() would silently drop unmatched outputs, so check lengths first.
        assert len(ref_outputs) == len(lora_spec_outputs)

        matches = 0
        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect more than 90% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        print(f"match ratio: {matches}/{len(ref_outputs)}")
        assert matches > int(0.90 * len(ref_outputs))
14 changes: 14 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,20 @@ def create_engine_config(
else None
)

if (
lora_config is not None
and speculative_config is not None
and scheduler_config.max_num_batched_tokens
< (
scheduler_config.max_num_seqs
* (speculative_config.num_speculative_tokens + 1)
)
):
raise ValueError(
"Consider increasing max_num_batched_tokens or "
"decreasing num_speculative_tokens"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : Can you turn it into a ValueError to stay consistent with the error-raising mechanism in this file. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


# bitsandbytes pre-quantized model need a specific model loader
if model_config.quantization == "bitsandbytes":
self.quantization = self.load_format = "bitsandbytes"
Expand Down
6 changes: 5 additions & 1 deletion vllm/lora/punica_wrapper/punica_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ def __init__(
self.max_loras, max_num_batched_tokens, device=device
)

# When speculative decoding is enabled, the number of sampled tokens can
# be as large as max_batches * (num_speculative_decoding_tokens + 1).
# max_num_batched_tokens is used as the size here; this could be optimized
# by replacing it with max_batches * (num_speculative_decoding_tokens + 1).
self.prompt_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_batches, device=device
self.max_loras, max_num_batched_tokens, device=device
Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @xiaohongchen1991 . I have seen configurations with max_batches, max_num_batched_tokens set as 1024, 8192. In such cases, it looks like there is a constraint on how big num_speculative_decoding_tokens can be. I think we should add an assert like assert(max_num_batched_tokens >= max_batches * (num_speculative_decoding_tokens + 1)) so we catch out-of-bounds errors.
What do you think ? Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, should we just use max_batches if spec_decode is disabled ? It might be useful when debugging issues.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @varun-sundar-rabindranath . Assert added.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @li2haipeng !

)

def update_metadata(
Expand Down
10 changes: 6 additions & 4 deletions vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,22 +859,24 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)

def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
where, prompt_lora_mapping[i] is the LoRA id to use for the ith
sampled token.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""

req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values()
)
Expand Down
18 changes: 15 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,7 @@ def _prepare_inputs(
logits_indices = query_start_loc[1:] - 1
num_draft_tokens = None
spec_decode_metadata = None
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
Expand All @@ -1294,7 +1295,7 @@ def _prepare_inputs(
num_draft_tokens, cu_num_tokens
)
logits_indices = spec_decode_metadata.logits_indices

num_sampled_tokens = num_draft_tokens + 1
# For DECODE only cuda graph of some attention backends (e.g., GDN).
self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
Expand Down Expand Up @@ -1445,7 +1446,13 @@ def _prepare_inputs(

# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
assert (
np.sum(num_sampled_tokens)
<= self.vllm_config.scheduler_config.max_num_batched_tokens
)
self.set_active_loras(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)

return (
attn_metadata,
Expand Down Expand Up @@ -3390,6 +3397,7 @@ def _dummy_run(
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

# Disable DP padding when running eager
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
Expand Down Expand Up @@ -3485,7 +3493,11 @@ def _dummy_run(
attn_metadata[layer_name] = attn_metadata_i

with self.maybe_dummy_run_with_lora(
self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
self.lora_config,
num_scheduled_tokens,
num_sampled_tokens,
activate_lora,
remove_lora,
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_after_padding <= self.max_num_tokens
Expand Down
25 changes: 19 additions & 6 deletions vllm/v1/worker/lora_model_runner_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def load_lora_model(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model."
)

# Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager(
vllm_config,
Expand Down Expand Up @@ -70,13 +69,19 @@ def _ensure_lora_enabled(self) -> None:
raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")

def set_active_loras(
self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray
self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
) -> None:
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)

prompt_lora_mapping: tuple[int, ...] # of size np.sum(num_sampled_tokens)
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
lora_requests: set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = (
input_batch.make_lora_inputs(num_scheduled_tokens)
input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@li2haipeng can you also add an assert after this line like,
assert(len(prompt_lora_mapping) <= self.max_num_batched_tokens)

My main concern is that, given we are doing

        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))

in gpu_input_batch.py :: make_lora_inputs()
I wonder if len(prompt_lora_mapping) would exceed max_num_batched_tokens . If this happens, I think we will catch it here

self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices)
, but I am not fully sure. Either way, an assert here would be very useful and would point to the direct cause.

Also, I think the interaction between max_batches * (num_speculative_decoding_tokens + 1) and max_num_batched_tokens should be captured and we should raise an error during engine startup when they are incompatible. For example, if the user creates an engine with,
LoRA + Spec Decode + max_num_seqs=512 + max_num_batched_tokens=1024 + num_speculative_decoding_tokens = 5
This will assert deep in the code - but it'll be much better to assert during startup (somewhere here

def create_engine_config(
) and have a suggestion for the users to increase the max_num_batched_tokens.

What do you think ?
cc @robertgshaw2-redhat

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Varun

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varun-sundar-rabindranath Fixed. Thanks for your suggestion.

)
return self._set_active_loras(
prompt_lora_mapping, token_lora_mapping, lora_requests
Expand Down Expand Up @@ -123,8 +128,12 @@ def maybe_select_dummy_loras(
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
activate_lora: bool = True,
):
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)

if lora_config is None:
yield
else:
Expand All @@ -143,6 +152,9 @@ def maybe_select_dummy_loras(
else:
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)

# Make sample lora mapping
sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)

# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)

Expand All @@ -157,7 +169,7 @@ def maybe_select_dummy_loras(
}

self._set_active_loras(
tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
)

yield
Expand All @@ -167,13 +179,14 @@ def maybe_dummy_run_with_lora(
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True,
):
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, activate_lora
lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
),
):
yield
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/tpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)

def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
Expand Down