diff --git a/tests/v1/e2e/test_lora_with_spec_decode.py b/tests/v1/e2e/test_lora_with_spec_decode.py
new file mode 100644
index 000000000000..14532f279544
--- /dev/null
+++ b/tests/v1/e2e/test_lora_with_spec_decode.py
@@ -0,0 +1,141 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This script contains:
+1. A test of LoRA with speculative decoding for batch inference.
+"""
+
+import random
+
+import numpy as np
+import pytest
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform
+
+LORA_TEST_PROMPT_MAP: dict[str, str] = {}
+
+LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """
+### INSTRUCTION:
+You are an AI assistant that generates Python code to solve linear
+algebra problems.
+
+### PROBLEM:
+Find the eigenvalues and eigenvectors of the following 3x3 matrix:
+[[3, 2, 0],
+ [2, 3, 0],
+ [0, 0, 2]]
+
+### OUTPUT FORMAT (STRICT):
+Numbers should be represented as integers only.
+
+### PYTHON SOLUTION:
+"""
+
+SEED = 42
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
+@pytest.mark.parametrize(
+    "model_setup",
+    [
+        (
+            "eagle3",
+            "Qwen/Qwen3-1.7B",
+            "AngelSlim/Qwen3-1.7B_eagle3",
+            "premjatin/qwen-linear-algebra-coder",
+            1,
+        )
+    ],
+)
+def test_batch_inference_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    model_setup: tuple[str, str, str, str, int],
+):
+    """
+    Compare the outputs of an LLM with LoRA only and an LLM with both speculative
+    decoding and LoRA: outputs should match, and batch inference should not fail.
+    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
+    """
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        # Disable randomness
+        m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
+        torch.manual_seed(SEED)
+        np.random.seed(SEED)
+        random.seed(SEED)
+        torch.cuda.manual_seed_all(SEED)
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
+        method, model_name, spec_model_name, lora_path, tp_size = model_setup
+
+        # Reference run without speculative decoding
+        ref_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            max_model_len=2048,
+            max_num_seqs=4,
+            enable_lora=True,
+            max_loras=1,
+            max_cpu_loras=1,
+            max_lora_rank=16,
+        )
+
+        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
+        lora_request = LoRARequest("adapter", 1, lora_path)
+        sampling_params = SamplingParams(
+            temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
+        )
+
+        ref_outputs = ref_llm.generate(
+            prompts, sampling_params, lora_request=lora_request
+        )
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        lora_spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": 3,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+            max_num_seqs=4,
+            enable_lora=True,
+            max_loras=1,
+            max_cpu_loras=1,
+            max_lora_rank=16,
+        )
+
+        lora_spec_outputs = lora_spec_llm.generate(
+            prompts, sampling_params, lora_request=lora_request
+        )
+
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect more than 90% of the prompts to match exactly.
+        # Upon failure, inspect the printed outputs to check for inaccuracy.
+        print(f"match ratio: {matches}/{len(ref_outputs)}")
+        assert matches > int(0.90 * len(ref_outputs))
+        del lora_spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
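Note (illustrative, not part of the patch): a minimal sketch of driving just this new test through pytest, assuming a CUDA GPU and network access to pull the Qwen3-1.7B base model, the AngelSlim EAGLE3 draft head, and the LoRA adapter; the helper file name is hypothetical.

# run_lora_spec_test.py (hypothetical helper)
import pytest

if __name__ == "__main__":
    # -s surfaces the printed match ratio and any mismatched outputs.
    raise SystemExit(
        pytest.main(
            [
                "-s",
                "tests/v1/e2e/test_lora_with_spec_decode.py"
                "::test_batch_inference_correctness",
            ]
        )
    )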
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f1a6c0716e4c..342da0150a7c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1574,6 +1574,20 @@ def create_engine_config(
             else None
         )
 
+        if (
+            lora_config is not None
+            and speculative_config is not None
+            and scheduler_config.max_num_batched_tokens
+            < (
+                scheduler_config.max_num_seqs
+                * (speculative_config.num_speculative_tokens + 1)
+            )
+        ):
+            raise ValueError(
+                "LoRA with speculative decoding requires max_num_batched_tokens "
+                ">= max_num_seqs * (num_speculative_tokens + 1)."
+            )
+
         # bitsandbytes pre-quantized model need a specific model loader
         if model_config.quantization == "bitsandbytes":
             self.quantization = self.load_format = "bitsandbytes"
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index 1bb80e516d3f..ede50a48af98 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -51,8 +51,12 @@ def __init__(
             self.max_loras, max_num_batched_tokens, device=device
         )
 
+        # When speculative decoding is enabled, the number of sampled tokens is
+        # at most max_batches * (num_speculative_tokens + 1).
+        # This allocation can be tightened by replacing max_num_batched_tokens
+        # with max_batches * (num_speculative_tokens + 1).
         self.prompt_mapping_meta = LoRAKernelMeta.make(
-            self.max_loras, max_batches, device=device
+            self.max_loras, max_num_batched_tokens, device=device
         )
 
     def update_metadata(
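Note (illustrative, not part of the patch): a small worked example of the sizing rule behind both changes above, using assumed values that match the test (max_num_seqs=4, num_speculative_tokens=3). Each request may sample up to num_speculative_tokens + 1 tokens per step, so the sample-level LoRA metadata must cover max_num_seqs * (num_speculative_tokens + 1) entries, and max_num_batched_tokens must be at least that large for the new check in create_engine_config() to pass.

# Assumed values, for illustration only.
max_num_seqs = 4
num_speculative_tokens = 3

# Upper bound on sampled positions per step: drafts plus one bonus token per request.
max_num_samples = max_num_seqs * (num_speculative_tokens + 1)
assert max_num_samples == 16

# A token budget below that bound is now rejected at engine-config time.
max_num_batched_tokens = 8
would_raise = max_num_batched_tokens < max_num_samples
assert would_raise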
""" req_lora_mapping = self.request_lora_mapping[: self.num_reqs] - prompt_lora_mapping = tuple(req_lora_mapping) + prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens)) token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) + active_lora_requests: set[LoRARequest] = set( self.lora_id_to_lora_request.values() ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 91c8efc17feb..8a03b23facc3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1268,6 +1268,7 @@ def _prepare_inputs( logits_indices = query_start_loc[1:] - 1 num_draft_tokens = None spec_decode_metadata = None + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) else: # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all @@ -1294,7 +1295,7 @@ def _prepare_inputs( num_draft_tokens, cu_num_tokens ) logits_indices = spec_decode_metadata.logits_indices - + num_sampled_tokens = num_draft_tokens + 1 # For DECODE only cuda graph of some attention backends (e.g., GDN). self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[num_reqs:].fill(-1) @@ -1445,7 +1446,13 @@ def _prepare_inputs( # Hot-Swap lora model if self.lora_config: - self.set_active_loras(self.input_batch, num_scheduled_tokens) + assert ( + np.sum(num_sampled_tokens) + <= self.vllm_config.scheduler_config.max_num_batched_tokens + ) + self.set_active_loras( + self.input_batch, num_scheduled_tokens, num_sampled_tokens + ) return ( attn_metadata, @@ -3390,6 +3397,7 @@ def _dummy_run( assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) # Disable DP padding when running eager allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -3485,7 +3493,11 @@ def _dummy_run( attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora( - self.lora_config, num_scheduled_tokens, activate_lora, remove_lora + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + remove_lora, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_after_padding <= self.max_num_tokens diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 372bc0a05673..37abe5649460 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -38,7 +38,6 @@ def load_lora_model( "Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model." ) - # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( vllm_config, @@ -70,13 +69,19 @@ def _ensure_lora_enabled(self) -> None: raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.") def set_active_loras( - self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray + self, + input_batch: InputBatch, + num_scheduled_tokens: np.ndarray, + num_sampled_tokens: np.ndarray | None = None, ) -> None: - prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs + if num_sampled_tokens is None: + num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) + + prompt_lora_mapping: tuple[int, ...] # of size np.sum(num_sampled_tokens) token_lora_mapping: tuple[int, ...] 
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index 372bc0a05673..37abe5649460 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -38,7 +38,6 @@ def load_lora_model(
                 "Regarding multimodal models, vLLM currently "
                 "only supports adding LoRA to language model."
             )
-
         # Add LoRA Manager to the Model Runner
         self.lora_manager = LRUCacheWorkerLoRAManager(
             vllm_config,
@@ -70,13 +69,19 @@ def _ensure_lora_enabled(self) -> None:
             raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
 
     def set_active_loras(
-        self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray
+        self,
+        input_batch: InputBatch,
+        num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray | None = None,
     ) -> None:
-        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
+        if num_sampled_tokens is None:
+            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
+
+        prompt_lora_mapping: tuple[int, ...]  # of size np.sum(num_sampled_tokens)
         token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
         lora_requests: set[LoRARequest]
         prompt_lora_mapping, token_lora_mapping, lora_requests = (
-            input_batch.make_lora_inputs(num_scheduled_tokens)
+            input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
         )
         return self._set_active_loras(
             prompt_lora_mapping, token_lora_mapping, lora_requests
@@ -123,8 +128,12 @@ def maybe_select_dummy_loras(
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray | None = None,
         activate_lora: bool = True,
     ):
+        if num_sampled_tokens is None:
+            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
+
         if lora_config is None:
             yield
         else:
@@ -143,6 +152,9 @@
             else:
                 prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
 
+            # Make sample lora mapping
+            sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
+
             # Make token lora mapping
             token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
 
@@ -157,7 +169,7 @@
             }
 
             self._set_active_loras(
-                tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests
+                tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
             )
 
             yield
@@ -167,13 +179,14 @@ def maybe_dummy_run_with_lora(
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        num_sampled_tokens: np.ndarray,
         activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
             self.maybe_select_dummy_loras(
-                lora_config, num_scheduled_tokens, activate_lora
+                lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
             ),
         ):
             yield
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index d3fb17054c1a..6bf4f9193184 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -526,7 +526,7 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
 
     def make_lora_inputs(
-        self, num_scheduled_tokens: np.ndarray
+        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return