5 changes: 0 additions & 5 deletions tests/entrypoints/openai/test_lora_adapters.py
@@ -23,11 +23,6 @@
{"r": 1024},
"is greater than max_lora_rank",
),
(
"test_bias",
{"bias": "all"},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {"use_dora": True}, "does not yet support DoRA"),
(
"test_modules_to_save",
5 changes: 0 additions & 5 deletions tests/lora/test_peft_helper.py
@@ -16,11 +16,6 @@
{"r": 1024},
"is greater than max_lora_rank",
),
(
"test_bias",
{"bias": "all"},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {"use_dora": True}, "does not yet support DoRA"),
(
"test_modules_to_save",
15 changes: 2 additions & 13 deletions tests/lora/test_utils.py
@@ -21,7 +21,6 @@ class LoRANameParserTestConfig(NamedTuple):
name: str
module_name: str
is_lora_a: bool
is_bias: bool
weights_mapper: Optional[WeightsMapper] = None


@@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid():
"base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens",
True,
False,
),
LoRANameParserTestConfig(
"base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens",
False,
False,
),
LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj",
True,
False,
),
LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj",
False,
False,
),
LoRANameParserTestConfig(
"language_model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.layers.9.mlp.down_proj",
True,
False,
),
LoRANameParserTestConfig(
"language_model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.layers.9.mlp.down_proj",
False,
False,
),
# Test with WeightsMapper
LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.model.layers.9.mlp.down_proj",
True,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
@@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid():
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.model.layers.9.mlp.down_proj",
False,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
@@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid():
"model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.model.layers.9.mlp.down_proj",
True,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
@@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid():
"model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.model.layers.9.mlp.down_proj",
False,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
),
]
for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name(
for name, module_name, is_lora_a, weights_mapper in fixture:
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(
name, weights_mapper
)

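As the updated loop above shows, parse_fine_tuned_lora_name now returns a two-element tuple (module_name, is_lora_a) rather than the previous three-element tuple that also carried is_bias. A minimal sketch of the new call shape, mirroring the fixture above (the import path is an assumption, not taken from this diff):

# Sketch only, not part of the PR; import path assumed.
from vllm.lora.utils import parse_fine_tuned_lora_name

name = "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight"
module_name, is_lora_a = parse_fine_tuned_lora_name(name, None)
# Per the fixture above:
#   module_name == "model.layers.9.mlp.down_proj"
#   is_lora_a is True  (the tensor is a lora_A weight)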
8 changes: 1 addition & 7 deletions vllm/config/lora.py
@@ -70,12 +70,6 @@ class LoRAConfig:
per prompt. When run in offline mode, the lora IDs for n modalities
will be automatically assigned to 1-n with the names of the modalities
in alphabetic order."""
bias_enabled: bool = Field(
default=False,
deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
)
"""[DEPRECATED] Enable bias for LoRA adapters. This option will be
removed in v0.12.0."""

def compute_hash(self) -> str:
"""
@@ -96,7 +90,7 @@ def compute_hash(self) -> str:
factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size)
factors.append(self.bias_enabled)

hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

3 changes: 0 additions & 3 deletions vllm/engine/arg_utils.py
@@ -439,7 +439,6 @@ class EngineArgs:
video_pruning_rate: float = MultiModalConfig.video_pruning_rate
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank
default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
@@ -916,7 +915,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action=argparse.BooleanOptionalAction,
help="If True, enable handling of LoRA adapters.",
)
lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
lora_group.add_argument(
@@ -1515,7 +1513,6 @@ def create_engine_config(

lora_config = (
LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
default_mm_loras=self.default_mm_loras,
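With bias_enabled dropped from LoRAConfig and --enable-lora-bias removed from the CLI, a LoRA-enabled engine is configured through the remaining fields only. A minimal sketch under the assumption that the other LoRAConfig fields keep their defaults; the kwargs shown are the ones visible in create_engine_config above, with illustrative values:

# Sketch only, not part of the PR; import path assumed from the file location.
from vllm.config.lora import LoRAConfig

lora_config = LoRAConfig(
    max_lora_rank=16,  # illustrative value
    max_loras=4,       # illustrative value
)
# bias_enabled no longer exists as a field, and the argument parser above
# no longer registers --enable-lora-bias.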
1 change: 0 additions & 1 deletion vllm/lora/layers/base.py
@@ -45,7 +45,6 @@ def set_lora(
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
"""Overwrites lora tensors at index."""
...
40 changes: 2 additions & 38 deletions vllm/lora/layers/base_linear.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, cast
from typing import Optional

import torch
from transformers import PretrainedConfig
@@ -29,7 +29,6 @@ def __init__(self, base_layer: LinearBase):
self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: tuple[int, ...]
self.output_size: int
self.n_slices: int
@@ -86,38 +85,19 @@ def create_lora_weights(
)
for _ in range(self.n_slices)
)
if lora_config.bias_enabled:
lora_bias_out_size = lora_b_out_size
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_bias_out_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.output_slices = (self.lora_b_stacked[0].shape[2],)

def reset_lora(self, index: int):
for s_index in range(self.n_slices):
self.lora_a_stacked[s_index][index] = 0
self.lora_b_stacked[s_index][index] = 0
if self.lora_config.bias_enabled:
# Make mypy happy
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
self.lora_bias_stacked[s_index][index] = 0

def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
@@ -131,23 +111,13 @@ def set_lora(
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)

self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
)
self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
if lora_bias is not None:
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_(
lora_bias, non_blocking=True
)

def apply(
self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
@@ -162,13 +132,7 @@ def apply(
x = x.flatten(0, 1)

lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear(
output,
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.lora_bias_stacked,
1.0,
self.output_slices,
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
)
if not current_platform.can_update_inplace():
output = lora_output
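After this change, apply() hands only the stacked A and B matrices to punica_wrapper.add_lora_linear; no per-adapter bias term remains in the LoRA path. For reference, a plain-torch sketch of what that computation amounts to (illustrative shapes, not vLLM's punica kernel; the 1.0 in the call above is the scaling factor, taken as 1 here):

# Sketch only: the LoRA contribution without a bias term, in plain torch.
import torch

x = torch.randn(4, 128)         # (num_tokens, hidden_size)
lora_a = torch.randn(8, 128)    # (rank, hidden_size)
lora_b = torch.randn(256, 8)    # (output_size, rank)
base_out = torch.randn(4, 256)  # output of the base linear layer

out = base_out + (x @ lora_a.T) @ lora_b.T  # no "+ lora_bias" term anymore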
67 changes: 1 addition & 66 deletions vllm/lora/layers/column_parallel_linear.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union, cast
from typing import Optional, Union

import torch
import torch.nn as nn
@@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
== len(layer.lora_b_stacked)
== len(layer.output_slices)
)
if layer.lora_bias_stacked is not None:
assert layer.n_slices == len(layer.lora_bias_stacked)

output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

@@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True,
@@ -122,16 +119,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
lora_b = lora_b[start_idx:end_idx, :]
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
# TODO: Fix the slicing logic of bias.
if bias is None:
return bias
shard_size = self.output_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias

def forward(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
@@ -238,17 +225,6 @@ def create_lora_weights(
)
for output_size in self.output_slices
)
if lora_config.bias_enabled:
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
output_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for output_size in self.output_slices
)

def slice_lora_a(
self, lora_a: list[Union[torch.Tensor, None]]
Expand All @@ -268,31 +244,18 @@ def slice_lora_b(
]
return sliced_lora_b

def slice_bias(
self, bias: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)
):
if (bias_i := bias[i]) is not None:
bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)]
return bias

def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)

if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)

for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
@@ -304,16 +267,6 @@ def set_lora(
index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
].copy_(lora_b_i, non_blocking=True)

if lora_bias is not None:
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
for i in range(self.n_slices):
if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_(
lora_bias_i, non_blocking=True
)

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
@@ -380,24 +333,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
bias_q = bias[
self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
* (self.q_shard_id + 1)
]
k_offset = self.q_proj_total_size
bias_k = bias[
k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)
]
v_offset = k_offset + self.kv_proj_total_size
bias_v = bias[
v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)
]
bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
return bias

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(