diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py index 674e14e4f5c1..c74f805961bc 100644 --- a/tests/entrypoints/openai/test_lora_adapters.py +++ b/tests/entrypoints/openai/test_lora_adapters.py @@ -23,11 +23,6 @@ {"r": 1024}, "is greater than max_lora_rank", ), - ( - "test_bias", - {"bias": "all"}, - "Adapter bias cannot be used without bias_enabled", - ), ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index 2cc8bfe63495..9c55c623d444 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -16,11 +16,6 @@ {"r": 1024}, "is greater than max_lora_rank", ), - ( - "test_bias", - {"bias": "all"}, - "Adapter bias cannot be used without bias_enabled", - ), ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index aed91d98ddbd..c861a52d6872 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -21,7 +21,6 @@ class LoRANameParserTestConfig(NamedTuple): name: str module_name: str is_lora_a: bool - is_bias: bool weights_mapper: Optional[WeightsMapper] = None @@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid(): "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", True, - False, ), LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_B", "model.embed_tokens", False, - False, ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj", True, - False, ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj", False, - False, ), LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_A.weight", "language_model.layers.9.mlp.down_proj", True, - False, ), LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_B.weight", "language_model.layers.9.mlp.down_proj", False, - False, ), # Test with WeightsMapper LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "language_model.model.layers.9.mlp.down_proj", True, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), @@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid(): "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "language_model.model.layers.9.mlp.down_proj", False, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), @@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid(): "model.layers.9.mlp.down_proj.lora_A.weight", "language_model.model.layers.9.mlp.down_proj", True, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), @@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid(): "model.layers.9.mlp.down_proj.lora_B.weight", "language_model.model.layers.9.mlp.down_proj", False, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), ), ] - for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: - assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name( + for name, module_name, is_lora_a, weights_mapper in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name( name, weights_mapper ) diff --git a/vllm/config/lora.py b/vllm/config/lora.py index 46f41b9e448b..c531618a186d 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -70,12 +70,6 @@ class LoRAConfig: per prompt. When run in offline mode, the lora IDs for n modalities will be automatically assigned to 1-n with the names of the modalities in alphabetic order.""" - bias_enabled: bool = Field( - default=False, - deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.", - ) - """[DEPRECATED] Enable bias for LoRA adapters. This option will be - removed in v0.12.0.""" def compute_hash(self) -> str: """ @@ -96,7 +90,7 @@ def compute_hash(self) -> str: factors.append(self.lora_dtype) factors.append(self.lora_extra_vocab_size) factors.append(self.lora_vocab_padding_size) - factors.append(self.bias_enabled) + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e66d8dba8ac..cb47e439fc73 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -439,7 +439,6 @@ class EngineArgs: video_pruning_rate: float = MultiModalConfig.video_pruning_rate # LoRA fields enable_lora: bool = False - enable_lora_bias: bool = LoRAConfig.bias_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras @@ -916,7 +915,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action=argparse.BooleanOptionalAction, help="If True, enable handling of LoRA adapters.", ) - lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) lora_group.add_argument( @@ -1515,7 +1513,6 @@ def create_engine_config( lora_config = ( LoRAConfig( - bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, default_mm_loras=self.default_mm_loras, diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py index 753dc268a2ff..5279247a1759 100644 --- a/vllm/lora/layers/base.py +++ b/vllm/lora/layers/base.py @@ -45,7 +45,6 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, ): """Overwrites lora tensors at index.""" ... diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index d2f017c19ccd..da053f0923ab 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, cast +from typing import Optional import torch from transformers import PretrainedConfig @@ -29,7 +29,6 @@ def __init__(self, base_layer: LinearBase): self.tp_size = self.base_layer.tp_size self.tp_rank = self.base_layer.tp_rank self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None self.output_slices: tuple[int, ...] self.output_size: int self.n_slices: int @@ -86,30 +85,12 @@ def create_lora_weights( ) for _ in range(self.n_slices) ) - if lora_config.bias_enabled: - lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_bias_out_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - for _ in range(self.n_slices) - ) self.output_slices = (self.lora_b_stacked[0].shape[2],) def reset_lora(self, index: int): for s_index in range(self.n_slices): self.lora_a_stacked[s_index][index] = 0 self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: - # Make mypy happy - self.lora_bias_stacked = cast( - tuple[torch.Tensor, ...], self.lora_bias_stacked - ) - self.lora_bias_stacked[s_index][index] = 0 def set_lora( self, @@ -117,7 +98,6 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, ): # Except for QKVParallelLinearWithLoRA and # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers @@ -131,8 +111,6 @@ def set_lora( if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( lora_a, non_blocking=True @@ -140,14 +118,6 @@ def set_lora( self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( lora_b, non_blocking=True ) - if lora_bias is not None: - self.lora_bias_stacked = cast( - tuple[torch.Tensor, ...], self.lora_bias_stacked - ) - assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_( - lora_bias, non_blocking=True - ) def apply( self, x: torch.Tensor, bias: Optional[torch.Tensor] = None @@ -162,13 +132,7 @@ def apply( x = x.flatten(0, 1) lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, - 1.0, - self.output_slices, + output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices ) if not current_platform.can_update_inplace(): output = lora_output diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 011d38157456..c49b90a80cea 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union, cast +from typing import Optional, Union import torch import torch.nn as nn @@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): == len(layer.lora_b_stacked) == len(layer.output_slices) ) - if layer.lora_bias_stacked is not None: - assert layer.n_slices == len(layer.lora_bias_stacked) output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) @@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): output, buffers, layer.lora_b_stacked, - layer.lora_bias_stacked, layer.output_slices, offset_start=0, add_input=True, @@ -122,16 +119,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = lora_b[start_idx:end_idx, :] return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - # TODO: Fix the slicing logic of bias. - if bias is None: - return bias - shard_size = self.output_size - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - def forward( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: @@ -238,17 +225,6 @@ def create_lora_weights( ) for output_size in self.output_slices ) - if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - for output_size in self.output_slices - ) def slice_lora_a( self, lora_a: list[Union[torch.Tensor, None]] @@ -268,31 +244,18 @@ def slice_lora_b( ] return sliced_lora_b - def slice_bias( - self, bias: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices) - ): - if (bias_i := bias[i]) is not None: - bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)] - return bias - def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) for i in range(self.n_slices): if (lora_a_i := lora_a[i]) is not None: @@ -304,16 +267,6 @@ def set_lora( index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1] ].copy_(lora_b_i, non_blocking=True) - if lora_bias is not None: - self.lora_bias_stacked = cast( - tuple[torch.Tensor, ...], self.lora_bias_stacked - ) - for i in range(self.n_slices): - if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_( - lora_bias_i, non_blocking=True - ) - @classmethod @_not_fully_sharded_can_replace def can_replace_layer( @@ -380,24 +333,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - bias_q = bias[ - self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size - * (self.q_shard_id + 1) - ] - k_offset = self.q_proj_total_size - bias_k = bias[ - k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset - + self.kv_proj_shard_size * (self.kv_shard_id + 1) - ] - v_offset = k_offset + self.kv_proj_total_size - bias_v = bias[ - v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset - + self.kv_proj_shard_size * (self.kv_shard_id + 1) - ] - bias = torch.cat([bias_q, bias_k, bias_v], dim=1) - return bias - @classmethod @_not_fully_sharded_can_replace def can_replace_layer( diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index 4f30c9db4c67..f3ca60fb28d9 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -143,7 +143,6 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index 738371f22a36..fff4fb38ead9 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union, cast +from typing import Optional, Union import torch import torch.nn as nn @@ -39,9 +39,6 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias - def forward( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: @@ -123,16 +120,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = lora_b[start_idx:end_idx, :] return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - if bias is None: - return bias - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) - shard_size = self.lora_bias_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - def apply( self, x: torch.Tensor, bias: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -167,7 +154,6 @@ def apply( output, buffer, self.lora_b_stacked, - self.lora_bias_stacked, self.output_slices, offset_start=offset_start, add_input=True, diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index 42eae1d4e3b0..0a252b425c4a 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -91,7 +91,6 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index d502c8eb543f..b043a46f9e2a 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -21,7 +21,6 @@ def __init__( lora_alpha: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - bias: Optional[torch.Tensor] = None, embeddings_tensor: Optional[torch.Tensor] = None, scaling: Optional[float] = None, ) -> None: @@ -30,7 +29,6 @@ def __init__( self.lora_alpha = lora_alpha self.lora_a = lora_a self.lora_b = lora_b - self.bias = bias self.embeddings_tensor = embeddings_tensor if scaling is None: @@ -71,13 +69,13 @@ def from_config( peft_helper: PEFTHelper, embeddings_tensor: Optional[torch.Tensor] = None, ) -> "LoRALayerWeights": + # lora_a and lora_b are set to None for config-based construction return cls( module_name, peft_helper.r, peft_helper.lora_alpha, None, None, - None, embeddings_tensor, peft_helper.vllm_lora_scaling_factor, ) @@ -92,7 +90,6 @@ def create_dummy_lora_weights( dtype: torch.dtype, device: torch.types.Device, embeddings_tensor_dim: Optional[int] = None, - bias_enabled: Optional[bool] = False, ) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() lora_a = torch.zeros( @@ -101,12 +98,6 @@ def create_dummy_lora_weights( lora_b = torch.zeros( [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory ) - if bias_enabled: - bias = torch.zeros( - [output_dim], dtype=dtype, device=device, pin_memory=pin_memory - ) - else: - bias = None embeddings_tensor = ( torch.rand( @@ -125,7 +116,6 @@ def create_dummy_lora_weights( lora_alpha=1, lora_a=lora_a, lora_b=lora_b, - bias=bias, embeddings_tensor=embeddings_tensor, ) @@ -140,7 +130,6 @@ def __init__( lora_alphas: list[Optional[int]], lora_a: list[Optional[torch.Tensor]], lora_b: list[Optional[torch.Tensor]], - bias: Optional[list[Optional[torch.Tensor]]] = None, scaling: Optional[list[float]] = None, ) -> None: super().__init__( @@ -149,7 +138,6 @@ def __init__( lora_alpha=0, lora_a=lora_a, lora_b=lora_b, - bias=bias, scaling=scaling, # type: ignore embeddings_tensor=None, ) @@ -181,7 +169,6 @@ def pack( [lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras], - [lora.bias if lora is not None else None for lora in loras], scaling=[ 1 if lora is not None else None # type: ignore for lora in loras diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 771c8608f4a8..cf9089eff175 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,7 +3,6 @@ import math import os -from collections.abc import Sequence from typing import Callable, Optional, TypeVar, Union import regex as re @@ -140,7 +139,7 @@ def from_lora_tensors( pin_memory = str(device) == "cpu" and is_pin_memory_available() loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): - module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + module_name, is_lora_a = parse_fine_tuned_lora_name( tensor_name, weights_mapper ) if module_name not in loras: @@ -160,13 +159,7 @@ def from_lora_tensors( module_name, peft_helper, lora_embeddings_tensor ) - if is_bias: - loras[module_name].bias = tensor.to(device=device, dtype=dtype) - bias = tensor.to(device=device, dtype=dtype) - if pin_memory: - bias = bias.pin_memory() - loras[module_name].bias = bias - elif is_lora_a: + if is_lora_a: loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) if pin_memory: loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() @@ -234,9 +227,7 @@ def from_local_checkpoint( def check_unexpected_modules(modules: dict): for lora_module in modules.keys(): # noqa - module_name, _, _ = parse_fine_tuned_lora_name( - lora_module, weights_mapper - ) + module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) part_name = module_name.split(".")[-1] if part_name not in expected_lora_modules: unexpected_modules.append(module_name) @@ -439,23 +430,11 @@ def activate_adapter( module_lora = self._get_lora_layer_weights(lora_model, module_name) if module_lora: module_lora.optimize() - # Bias is not explicitly enabled with the flag enable_lora_bias. - bias = module_lora.bias - if ( - torch.is_tensor(bias) - or (isinstance(bias, Sequence) and any(b is not None for b in bias)) - ) and not self.lora_config.bias_enabled: - module_lora.bias = None - raise ValueError( - f"Adapter bias cannot be used for {module_name}" - " without --enable-lora-bias." - ) module.set_lora( index, module_lora.lora_a, module_lora.lora_b, module_lora.embeddings_tensor, - module_lora.bias, ) else: module.reset_lora(index) @@ -581,7 +560,6 @@ def create_dummy_lora( """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): - bias_enabled = self.lora_config.bias_enabled if ( not self._match_target_modules(module_name) or not isinstance(module, BaseLayerWithLoRA) @@ -616,7 +594,6 @@ def create_dummy_lora( module.lora_a_stacked[0].dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim, - bias_enabled=bias_enabled, ) else: lora = LoRALayerWeights.create_dummy_lora_weights( @@ -626,7 +603,6 @@ def create_dummy_lora( rank, module.lora_a_stacked[0].dtype, "cpu", - bias_enabled=bias_enabled, ) else: parts = module_name.split(".") @@ -640,7 +616,6 @@ def create_dummy_lora( rank, module.lora_a_stacked[i].dtype, "cpu", - bias_enabled=bias_enabled, ) subloras.append(lora) lora = PackedLoRALayerWeights.pack(subloras) diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 48412eab92d8..8f21a2570224 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -29,7 +29,7 @@ class PEFTHelper: lora_alpha: int target_modules: Union[list[str], str] - bias: Literal["none", "all", "lora_only"] = field(default="none") + bias: Literal["none"] = field(default="none") modules_to_save: Optional[list[str]] = field(default=None) # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732) use_rslora: bool = field(default=False) @@ -122,7 +122,7 @@ def validate_legal(self, lora_config: LoRAConfig) -> None: f"LoRA rank {self.r} is greater than max_lora_rank" f" {lora_config.max_lora_rank}." ) - if self.bias != "none" and not lora_config.bias_enabled: - error_msg.append("Adapter bias cannot be used without bias_enabled.") + if self.bias != "none": + error_msg.append("Adapter bias is not supported.") if error_msg: raise ValueError(f"{' '.join(error_msg)}") diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 770c3cf7b073..b803a482b1bc 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -60,14 +60,13 @@ def add_expand( y: torch.Tensor, x: Union[tuple[torch.Tensor, ...], torch.Tensor], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> Optional[torch.Tensor]: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. """ raise NotImplementedError @@ -93,7 +92,6 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, @@ -222,38 +220,6 @@ def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: self.token_nums = token_nums self.no_lora = no_lora - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: tuple[int, ...], - lora_bias_stacked: tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left : offset_left + slice] += bias - offset_left += slice - - return output.view_as(org_output) - @property def prefill_metadata( self, @@ -365,29 +331,25 @@ def add_expand( y: torch.Tensor, x: Union[tuple[torch.Tensor, ...], torch.Tensor], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> Optional[torch.Tensor]: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: offset = offset_start for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size offset_start (int): The starting position of y, defaults to 0 add_inputs (bool): Defaults to True. @@ -427,7 +389,6 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, @@ -444,14 +405,13 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index c51a13db873c..93e64eb6ba84 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -199,38 +199,30 @@ def add_expand( y: torch.Tensor, x: Union[tuple[torch.Tensor, ...], torch.Tensor], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start - if lora_bias_stacked is not None: - self._apply_bias( - self.token_lora_indices, y, output_slices, lora_bias_stacked - ) for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, @@ -276,7 +268,6 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, @@ -293,25 +284,19 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias( - self.token_lora_indices, y, output_slices, lora_bias_stacked - ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -323,7 +308,7 @@ def add_lora_linear( ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_expand( - y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs + y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs ) def add_lora_logits( diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 431e97102faf..8173fe99ea13 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -101,36 +101,29 @@ def add_expand( y: torch.Tensor, x: torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) - if lora_bias_stacked is not None: - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0)) - self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -183,7 +176,6 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, @@ -200,26 +192,18 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] - + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0)) - y = self._apply_bias( - token_lora_indices, y, output_slices, lora_bias_stacked - ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -241,7 +225,6 @@ def add_lora_linear( y, buffer, # type: ignore lora_b_stacked, - None, output_slices, add_inputs=True, **kwargs, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 5d2f05b815be..dff30d5d2a2d 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -139,28 +139,24 @@ def add_expand( y: torch.Tensor, x: Union[tuple[torch.Tensor, ...], torch.Tensor], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> torch.Tensor: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ @@ -168,10 +164,6 @@ def add_expand( y = y.view(-1, y.shape[-1]) offset_left = 0 - if lora_bias_stacked is not None: - y = self._apply_bias( - self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked - ) for slice_idx in range(len(lora_b_stacked)): y = self.expand_slice( y, @@ -214,7 +206,6 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, @@ -231,25 +222,19 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will not be changed in-place. x (torch.Tensor): Input tensor (T, E) lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias( - self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked - ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -261,7 +246,7 @@ def add_lora_linear( ) buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) return self.add_expand( - y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs + y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs ) def add_lora_logits( @@ -299,43 +284,6 @@ def add_lora_logits( y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) return y.view_as(y_org) - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: tuple[int, ...], - lora_bias_stacked: tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias = torch.where(indices[:, None] == -1, 0, bias) - - bias = F.pad( - bias, (offset_left, output.shape[1] - (offset_left + slice), 0, 0) - ) - - output += bias - offset_left += slice - - return output.view_as(org_output) - # This performs the same tensor ops as the base method, except it does them # on the CPU then transfers the results to the TPU def _update_base_metadata( diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 5196199b2ac3..e3d03ac8dc2c 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -108,36 +108,29 @@ def add_expand( y: torch.Tensor, x: torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) - if lora_bias_stacked is not None: - token_lora_indices = self._get_token_lora_indices(y) - self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -184,7 +177,6 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, @@ -201,26 +193,19 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - token_lora_indices = self._get_token_lora_indices(y) - y = self._apply_bias( - token_lora_indices, y, output_slices, lora_bias_stacked - ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -242,7 +227,6 @@ def add_lora_linear( y, buffer, # type: ignore lora_b_stacked, - None, output_slices, add_inputs=True, **kwargs, diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 5e55d44ce8d9..595c774e03be 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -112,7 +112,7 @@ def replace_submodule( def parse_fine_tuned_lora_name( name: str, weights_mapper: Optional["WeightsMapper"] = None -) -> tuple[str, bool, bool]: +) -> tuple[str, bool]: """Parse the name of lora weights. args: @@ -124,7 +124,6 @@ def parse_fine_tuned_lora_name( tuple(module_name, is_lora_a): module_name: the name of the module, e.g. model.dense1, is_lora_a whether the tensor is lora_a or lora_b. - is_bias whether the tensor is lora bias. """ # LoRA weight qualified name usually starts with `base_model.model.`, @@ -146,15 +145,11 @@ def parse_fine_tuned_lora_name( parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): new_name = ".".join(parts[start_index:-2]) - return new_name, parts[-2] == "lora_A", False + return new_name, parts[-2] == "lora_A" if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": new_name = ".".join(parts[start_index:-1]) - return new_name, parts[-1] == "lora_embedding_A", False - - if parts[-1] == "bias": - new_name = ".".join(parts[start_index:-2]) - return new_name, False, True + return new_name, parts[-1] == "lora_embedding_A" raise ValueError(f"{name} is unsupported LoRA weight")