[Core] Support LoRA on quantized models (vllm-project#4012)
jeejeelee authored and joerunde committed Apr 18, 2024
1 parent c65f7b6 commit f7b4d44
Showing 4 changed files with 234 additions and 26 deletions.
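For reference, below is a minimal usage sketch of the newly supported path, pieced together from the test added in this commit (tests/lora/test_quant_model.py): it loads the AWQ-quantized TinyLlama checkpoint used by the test and attaches a LoRA adapter per request. The adapter path is a placeholder; the test downloads the adapter from the jashing/tinyllama-colorist-lora repo via snapshot_download.

import vllm
from vllm.lora.request import LoRARequest

# AWQ-quantized base model with LoRA enabled; LoRAConfig used to reject this combination.
llm = vllm.LLM(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
               quantization="AWQ",
               enable_lora=True,
               max_loras=4,
               max_model_len=400,
               trust_remote_code=True)

prompt = "<|im_start|>user\nGive me a neon pink color<|im_end|>\n<|im_start|>assistant\n"
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64, stop=["<|im_end|>"])

# Placeholder path: the test fetches this adapter with
# snapshot_download(repo_id="jashing/tinyllama-colorist-lora").
# Pass lora_request=None to run the plain quantized base model instead.
outputs = llm.generate([prompt],
                       sampling_params,
                       lora_request=LoRARequest("colorist", 1, "/path/to/tinyllama-colorist-lora"))
print(outputs[0].outputs[0].text)

The GPTQ checkpoint exercised by the test works the same way with quantization="GPTQ".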
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
@@ -143,6 +143,11 @@ def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
179 changes: 179 additions & 0 deletions tests/lora/test_quant_model.py
@@ -0,0 +1,179 @@
# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from dataclasses import dataclass
from typing import List, Optional

import pytest

import vllm
from vllm.lora.request import LoRARequest

from .conftest import cleanup


@dataclass
class ModelWithQuantization:
model_path: str
    quantization: Optional[str]


MODELS: List[ModelWithQuantization] = [
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization="AWQ"),
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
quantization="GPTQ"),
]


def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
raw_prompts = [
"Give me an orange-ish brown color",
"Give me a neon pink color",
]

def format_prompt_tuples(prompt):
return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

prompts = [format_prompt_tuples(p) for p in raw_prompts]

sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=max_tokens,
stop=["<|im_end|>"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_model_len=400,
tensor_parallel_size=tp_size,
quantization=model.quantization,
trust_remote_code=True)

if model.quantization is None:
expected_no_lora_output = [
"Here are some examples of orange-brown colors",
"I'm sorry, I don't have"
]
expected_lora_output = [
"#ff8050",
"#ff8080",
]
elif model.quantization == "AWQ":
expected_no_lora_output = [
"I'm sorry, I don't understand",
"I'm sorry, I don't understand",
]
expected_lora_output = [
"#f07700: A v",
"#f00000: A v",
]
elif model.quantization == "GPTQ":
expected_no_lora_output = [
"I'm sorry, I don't have",
"I'm sorry, I don't have",
]
expected_lora_output = [
"#f08800: This is",
"#f07788 \n#",
]

def expect_match(output, expected_output):
# HACK: GPTQ lora outputs are just incredibly unstable.
# Assert that the outputs changed.
if (model.quantization == "GPTQ"
and expected_output is expected_lora_output):
assert output != expected_no_lora_output
for i, o in enumerate(output):
assert o.startswith(
'#'), f"Expected example {i} to start with # but got {o}"
return
assert output == expected_output

max_tokens = 10

print("lora adapter created")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)

print("lora 1")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=1,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)

print("no lora")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)

print("lora 2")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=2,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)

print("removing lora")

del llm
cleanup()


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.skip("Requires multiple GPUs")
def test_quant_model_tp_equality(tinyllama_lora_files, model):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 2:
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}")

llm_tp1 = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1,
quantization=model.quantization,
trust_remote_code=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

del llm_tp1
cleanup()

llm_tp2 = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2,
quantization=model.quantization)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

del llm_tp2
cleanup()

assert output_tp1 == output_tp2
9 changes: 6 additions & 3 deletions vllm/config.py
@@ -822,9 +822,12 @@ def verify_with_model_config(self, model_config: ModelConfig):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization is not None:
raise ValueError(
"LoRA is not supported with quantized models yet.")
if model_config.quantization and model_config.quantization not in [
"awq", "gptq"
]:
# TODO support marlin and squeezellm
logger.warning(f"{model_config.quantization} quantization is not "
"tested with LoRA yet.")

def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
67 changes: 44 additions & 23 deletions vllm/lora/layers.py
@@ -29,6 +29,19 @@
pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device on which to place the LoRA tensors."""
    # Code borrowed from
    # https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
if hasattr(base_layer, "weight"):
return base_layer.weight.device
if hasattr(base_layer, "linear_weights") and isinstance(
base_layer.linear_weights, dict):
values = list(base_layer.linear_weights.values())
if len(values) and isinstance(values[0], torch.Tensor):
return values[0].device
raise ValueError(f"Unsupported base layer: {base_layer}")


def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
@@ -302,6 +315,9 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size = self.base_layer.input_size
self.output_size = self.base_layer.output_size_per_partition
self.device = _get_lora_device(self.base_layer)

def create_lora_weights(
self,
@@ -312,17 +328,17 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
self.input_size,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0],
self.output_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
)

self.indices: Optional[torch.Tensor] = None
@@ -442,18 +458,18 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
self.input_size,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
) for _ in range(n_slices))
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0] // 2,
self.output_size // 2,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
) for _ in range(n_slices))

self.indices: Optional[torch.Tensor] = None
@@ -619,25 +635,25 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
self.input_size,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
self.input_size,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
self.input_size,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
),
)
self.lora_b_stacked = (
@@ -647,23 +663,23 @@ def create_lora_weights(
self.q_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
),
)

@@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
self.device = _get_lora_device(self.base_layer)

def create_lora_weights(
self,
@@ -777,20 +796,20 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
self.input_size,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.weight.shape[0],
self.output_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
Expand All @@ -809,7 +828,7 @@ def set_lora(
self.reset_lora(index)
if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.base_layer.weight.shape[1]
shard_size = self.input_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
@@ -884,7 +903,9 @@ def forward(self, input_):

@property
def weight(self):
return self.base_layer.weight

return self.base_layer.weight if hasattr(
self.base_layer, "weight") else self.base_layer.qweight

@classmethod
def can_replace_layer(cls, source_layer: nn.Module,