forked from opendatahub-io/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Support LoRA on quantized models (vllm-project#4012)
- Loading branch information
Showing
4 changed files
with
234 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# Adapted from | ||
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py | ||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
import pytest | ||
|
||
import vllm | ||
from vllm.lora.request import LoRARequest | ||
|
||
from .conftest import cleanup | ||
|
||
|
||
@dataclass | ||
class ModelWithQuantization: | ||
model_path: str | ||
quantization: str | ||
|
||
|
||
MODELS: List[ModelWithQuantization] = [ | ||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", | ||
quantization="AWQ"), | ||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", | ||
quantization="GPTQ"), | ||
] | ||
|
||
|
||
def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): | ||
raw_prompts = [ | ||
"Give me an orange-ish brown color", | ||
"Give me a neon pink color", | ||
] | ||
|
||
def format_prompt_tuples(prompt): | ||
return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" | ||
|
||
prompts = [format_prompt_tuples(p) for p in raw_prompts] | ||
|
||
sampling_params = vllm.SamplingParams(temperature=0, | ||
max_tokens=max_tokens, | ||
stop=["<|im_end|>"]) | ||
outputs = llm.generate( | ||
prompts, | ||
sampling_params, | ||
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) | ||
if lora_id else None) | ||
# Print the outputs. | ||
generated_texts = [] | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
generated_texts.append(generated_text) | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
return generated_texts | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("tp_size", [1]) | ||
def test_quant_model_lora(tinyllama_lora_files, model, tp_size): | ||
# Cannot use as it will initialize torch.cuda too early... | ||
# if torch.cuda.device_count() < tp_size: | ||
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") | ||
|
||
llm = vllm.LLM(model=model.model_path, | ||
enable_lora=True, | ||
max_num_seqs=16, | ||
max_loras=4, | ||
max_model_len=400, | ||
tensor_parallel_size=tp_size, | ||
quantization=model.quantization, | ||
trust_remote_code=True) | ||
|
||
if model.quantization is None: | ||
expected_no_lora_output = [ | ||
"Here are some examples of orange-brown colors", | ||
"I'm sorry, I don't have" | ||
] | ||
expected_lora_output = [ | ||
"#ff8050", | ||
"#ff8080", | ||
] | ||
elif model.quantization == "AWQ": | ||
expected_no_lora_output = [ | ||
"I'm sorry, I don't understand", | ||
"I'm sorry, I don't understand", | ||
] | ||
expected_lora_output = [ | ||
"#f07700: A v", | ||
"#f00000: A v", | ||
] | ||
elif model.quantization == "GPTQ": | ||
expected_no_lora_output = [ | ||
"I'm sorry, I don't have", | ||
"I'm sorry, I don't have", | ||
] | ||
expected_lora_output = [ | ||
"#f08800: This is", | ||
"#f07788 \n#", | ||
] | ||
|
||
def expect_match(output, expected_output): | ||
# HACK: GPTQ lora outputs are just incredibly unstable. | ||
# Assert that the outputs changed. | ||
if (model.quantization == "GPTQ" | ||
and expected_output is expected_lora_output): | ||
assert output != expected_no_lora_output | ||
for i, o in enumerate(output): | ||
assert o.startswith( | ||
'#'), f"Expected example {i} to start with # but got {o}" | ||
return | ||
assert output == expected_output | ||
|
||
max_tokens = 10 | ||
|
||
print("lora adapter created") | ||
output = do_sample(llm, | ||
tinyllama_lora_files, | ||
lora_id=0, | ||
max_tokens=max_tokens) | ||
expect_match(output, expected_no_lora_output) | ||
|
||
print("lora 1") | ||
output = do_sample(llm, | ||
tinyllama_lora_files, | ||
lora_id=1, | ||
max_tokens=max_tokens) | ||
expect_match(output, expected_lora_output) | ||
|
||
print("no lora") | ||
output = do_sample(llm, | ||
tinyllama_lora_files, | ||
lora_id=0, | ||
max_tokens=max_tokens) | ||
expect_match(output, expected_no_lora_output) | ||
|
||
print("lora 2") | ||
output = do_sample(llm, | ||
tinyllama_lora_files, | ||
lora_id=2, | ||
max_tokens=max_tokens) | ||
expect_match(output, expected_lora_output) | ||
|
||
print("removing lora") | ||
|
||
del llm | ||
cleanup() | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.skip("Requires multiple GPUs") | ||
def test_quant_model_tp_equality(tinyllama_lora_files, model): | ||
# Cannot use as it will initialize torch.cuda too early... | ||
# if torch.cuda.device_count() < 2: | ||
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}") | ||
|
||
llm_tp1 = vllm.LLM(model=model.model_path, | ||
enable_lora=True, | ||
max_num_seqs=16, | ||
max_loras=4, | ||
tensor_parallel_size=1, | ||
quantization=model.quantization, | ||
trust_remote_code=True) | ||
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) | ||
|
||
del llm_tp1 | ||
cleanup() | ||
|
||
llm_tp2 = vllm.LLM(model=model.model_path, | ||
enable_lora=True, | ||
max_num_seqs=16, | ||
max_loras=4, | ||
tensor_parallel_size=2, | ||
quantization=model.quantization) | ||
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) | ||
|
||
del llm_tp2 | ||
cleanup() | ||
|
||
assert output_tp1 == output_tp2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters