-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[Bugfix][LoRA][Spec Decode] Support LoRA with speculative decoding #21068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
202c2cc
e90736f
68673aa
6d96013
1f7fc85
9b98d08
7055e32
5f74f23
735d7eb
f927b9f
643bfd7
605de8a
8e38ddb
9fe779c
7744204
a0609d9
5e23647
0d78c25
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script contains:
1. test lora with speculative decoding for batch inference
"""

import random

import numpy as np
import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

# Maps a LoRA adapter path to a prompt that exercises that adapter.
LORA_TEST_PROMPT_MAP: dict[str, str] = {
    "premjatin/qwen-linear-algebra-coder": """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.
### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[3, 2, 0],
[2, 3, 0],
[0, 0, 2]]
### OUTPUT FORMAT (STRICT):
Numbers should be represented as integers only.
### PYTHON SOLUTION:
""",
}

# Fixed seed shared by all RNGs so both inference runs are reproducible.
SEED = 42
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
    "model_setup",
    [
        (
            "eagle3",
            "Qwen/Qwen3-1.7B",
            "AngelSlim/Qwen3-1.7B_eagle3",
            "premjatin/qwen-linear-algebra-coder",
            1,
        )
    ],
)
def test_batch_inference_correctness(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, str, int],
):
    """
    Compare the outputs of an LLM with only LoRA against an LLM with both
    speculative decoding and LoRA. With greedy sampling the outputs should
    be (mostly) identical, and batch inference should complete without
    failure.

    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Disable sources of nondeterminism so the two runs are comparable.
        m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        method, model_name, spec_model_name, lora_path, tp_size = model_setup

        # Reference run: LoRA enabled, no speculative decoding.
        ref_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        # 100 identical prompts exercise the batched LoRA code paths.
        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
        lora_request = LoRARequest("adapter", 1, lora_path)
        # Greedy decoding (temperature=0.0) keeps outputs deterministic so
        # the two runs can be compared by exact string match.
        sampling_params = SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
        )

        ref_outputs = ref_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )
        # Release GPU memory before constructing the second engine.
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Test run: identical LoRA setup plus speculative decoding.
        lora_spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        lora_spec_outputs = lora_spec_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                # Print mismatches so a failure can be triaged by eye.
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 90% of the prompts to match exactly.
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        # NOTE: `>=` so that exactly 90% still passes, matching the
        # "at least 90%" heuristic above (a strict ">" would require 91%).
        print(f"match ratio: {matches}/{len(ref_outputs)}")
        assert matches >= int(0.90 * len(ref_outputs))
        del lora_spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1574,6 +1574,20 @@ def create_engine_config( | |
| else None | ||
| ) | ||
|
|
||
| if ( | ||
| lora_config is not None | ||
| and speculative_config is not None | ||
| and scheduler_config.max_num_batched_tokens | ||
| < ( | ||
| scheduler_config.max_num_seqs | ||
| * (speculative_config.num_speculative_tokens + 1) | ||
| ) | ||
| ): | ||
| raise ValueError( | ||
| "Consider increasing max_num_batched_tokens or " | ||
| "decreasing num_speculative_tokens" | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit : Can you turn it into a ValueError to stay consistent with the error-raising mechanism in this file. Thanks.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
|
||
| # bitsandbytes pre-quantized model need a specific model loader | ||
| if model_config.quantization == "bitsandbytes": | ||
| self.quantization = self.load_format = "bitsandbytes" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,8 +51,12 @@ def __init__( | |
| self.max_loras, max_num_batched_tokens, device=device | ||
| ) | ||
|
|
||
| # When speculative decoding is enabled, max_num_samples is | ||
| # max_batches * (num_speculative_decoding_tokens + 1). | ||
| # This line can be optimized by replacing max_num_batched_tokens | ||
| # to max_batches * (num_speculative_decoding_tokens + 1). | ||
| self.prompt_mapping_meta = LoRAKernelMeta.make( | ||
| self.max_loras, max_batches, device=device | ||
| self.max_loras, max_num_batched_tokens, device=device | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @xiaohongchen1991 . I have seen configurations with
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also, should we just use
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @varun-sundar-rabindranath . Assert added.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks @li2haipeng ! |
||
| ) | ||
|
|
||
| def update_metadata( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -38,7 +38,6 @@ def load_lora_model( | |||||
| "Regarding multimodal models, vLLM currently " | ||||||
| "only supports adding LoRA to language model." | ||||||
| ) | ||||||
|
|
||||||
| # Add LoRA Manager to the Model Runner | ||||||
| self.lora_manager = LRUCacheWorkerLoRAManager( | ||||||
| vllm_config, | ||||||
|
|
@@ -70,13 +69,19 @@ def _ensure_lora_enabled(self) -> None: | |||||
| raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.") | ||||||
|
|
||||||
| def set_active_loras( | ||||||
| self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray | ||||||
| self, | ||||||
| input_batch: InputBatch, | ||||||
| num_scheduled_tokens: np.ndarray, | ||||||
| num_sampled_tokens: np.ndarray | None = None, | ||||||
| ) -> None: | ||||||
| prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs | ||||||
| if num_sampled_tokens is None: | ||||||
| num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) | ||||||
|
|
||||||
| prompt_lora_mapping: tuple[int, ...] # of size np.sum(num_sampled_tokens) | ||||||
| token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) | ||||||
| lora_requests: set[LoRARequest] | ||||||
| prompt_lora_mapping, token_lora_mapping, lora_requests = ( | ||||||
| input_batch.make_lora_inputs(num_scheduled_tokens) | ||||||
| input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @li2haipeng can you also add an assert after this line like, My main concern is that, given we are doing in vllm/vllm/lora/punica_wrapper/punica_base.py Line 192 in da786e3
Also, I think the interaction between Line 1285 in da786e3
max_num_batched_tokens.
What do you think ?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with Varun
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @varun-sundar-rabindranath Fixed. Thanks for your suggestion. |
||||||
| ) | ||||||
| return self._set_active_loras( | ||||||
| prompt_lora_mapping, token_lora_mapping, lora_requests | ||||||
|
|
@@ -123,8 +128,12 @@ def maybe_select_dummy_loras( | |||||
| self, | ||||||
| lora_config: LoRAConfig | None, | ||||||
| num_scheduled_tokens: np.ndarray, | ||||||
| num_sampled_tokens: np.ndarray | None = None, | ||||||
| activate_lora: bool = True, | ||||||
| ): | ||||||
| if num_sampled_tokens is None: | ||||||
| num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) | ||||||
|
|
||||||
| if lora_config is None: | ||||||
| yield | ||||||
| else: | ||||||
|
|
@@ -143,6 +152,9 @@ def maybe_select_dummy_loras( | |||||
| else: | ||||||
| prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32) | ||||||
|
|
||||||
| # Make sample lora mapping | ||||||
| sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens) | ||||||
|
|
||||||
| # Make token lora mapping | ||||||
| token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) | ||||||
|
|
||||||
|
|
@@ -157,7 +169,7 @@ def maybe_select_dummy_loras( | |||||
| } | ||||||
|
|
||||||
| self._set_active_loras( | ||||||
| tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests | ||||||
| tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests | ||||||
| ) | ||||||
|
|
||||||
| yield | ||||||
|
|
@@ -167,13 +179,14 @@ def maybe_dummy_run_with_lora( | |||||
| self, | ||||||
| lora_config: LoRAConfig | None, | ||||||
| num_scheduled_tokens: np.ndarray, | ||||||
| num_sampled_tokens: np.ndarray, | ||||||
| activate_lora: bool = True, | ||||||
| remove_lora: bool = True, | ||||||
| ): | ||||||
| with ( | ||||||
| self.maybe_setup_dummy_loras(lora_config, remove_lora), | ||||||
| self.maybe_select_dummy_loras( | ||||||
| lora_config, num_scheduled_tokens, activate_lora | ||||||
| lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora | ||||||
| ), | ||||||
| ): | ||||||
| yield | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
given the issues around TP, it'd be good to add a TP = 2 test as well. But it can be a fast-follow after @28318 lands. Thanks.