5 changes: 3 additions & 2 deletions vllm/lora/layers/base_linear.py
@@ -24,11 +24,12 @@ def __init__(self, base_layer: LinearBase):
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
+ # Ensure tp_size and tp_rank consistency with the base_layer.
+ self.tp_size = self.base_layer.tp_size
+ self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None

self.output_slices: tuple[int, ...]
- self.tp_size: int
self.output_size: int
self.n_slices: int
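
This diff's core change is that the LoRA wrapper now mirrors `tp_size`/`tp_rank` from the layer it wraps instead of re-querying the global TP state in each subclass. A minimal standalone sketch of that pattern (simplified, hypothetical class names, not vLLM code):

```python
class ToyBaseLinear:
    """Stand-in for a LinearBase subclass that carries its own TP metadata."""

    def __init__(self, tp_size: int, tp_rank: int):
        self.tp_size = tp_size
        self.tp_rank = tp_rank


class ToyLoRAWrapper:
    """Stand-in for BaseLinearLayerWithLoRA: copies TP metadata from the base layer."""

    def __init__(self, base_layer: ToyBaseLinear):
        self.base_layer = base_layer
        # Mirror the base layer, so wrapper and base layer can never disagree.
        self.tp_size = base_layer.tp_size
        self.tp_rank = base_layer.tp_rank


wrapper = ToyLoRAWrapper(ToyBaseLinear(tp_size=2, tp_rank=1))
assert (wrapper.tp_size, wrapper.tp_rank) == (2, 1)
```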

40 changes: 14 additions & 26 deletions vllm/lora/layers/column_parallel_linear.py
@@ -8,9 +8,7 @@
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
- from vllm.distributed import (get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- tensor_model_parallel_all_gather)
+ from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -85,7 +83,6 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None:
# inconsistent when TP is greater than 1.
self.is_merged_col_linear = type(
base_layer) is MergedColumnParallelLinear
- self.tp_size = get_tensor_model_parallel_world_size()
self.output_size = self.base_layer.output_size_per_partition
# There is only one LoRA layer
self.n_slices = 1
@@ -97,33 +94,30 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
# Applicable to cases where the base_layer is
# MergedColumnParallelLinear.
if self.is_merged_col_linear:
- tp_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size // 2
offset = lora_b.shape[0] // 2

- left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
+ left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) *
shard_size, :]
- right_weight = lora_b[offset + tp_rank * shard_size:offset +
- (tp_rank + 1) * shard_size, :]
+ right_weight = lora_b[offset + self.tp_rank * shard_size:offset +
+ (self.tp_rank + 1) * shard_size, :]
lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
- tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size
- start_idx = tensor_model_parallel_rank * shard_size
- end_idx = (tensor_model_parallel_rank + 1) * shard_size
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
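
To make the index arithmetic in `slice_lora_b` concrete, here is a standalone sketch with toy shapes (illustrative numbers only, not taken from the PR); it assumes `lora_b` is laid out as `(output_dim, rank)`, consistent with the `lora_weights.py` fix below:

```python
import torch

full_output = 16                      # 8 rows (left half) + 8 rows (right half), unsharded
rank = 4
tp_size = 2
output_size = full_output // tp_size  # stands in for output_size_per_partition = 8
lora_b = torch.arange(full_output * rank, dtype=torch.float32).reshape(full_output, rank)


def slice_merged_lora_b(lora_b: torch.Tensor, tp_rank: int) -> torch.Tensor:
    # Mirrors the slicing above: take this rank's shard of each half, then re-merge.
    shard_size = output_size // 2     # rows taken from each half per rank = 4
    offset = lora_b.shape[0] // 2     # where the right half starts = 8
    left = lora_b[tp_rank * shard_size:(tp_rank + 1) * shard_size, :]
    right = lora_b[offset + tp_rank * shard_size:offset + (tp_rank + 1) * shard_size, :]
    return torch.cat([left, right], dim=0)


# Rank 0 keeps rows 0-3 and 8-11; rank 1 keeps rows 4-7 and 12-15.
assert slice_merged_lora_b(lora_b, tp_rank=0).shape == (output_size, rank)
```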

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
# TODO: Fix the slicing logic of bias.
if bias is None:
return bias
- tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size
- start_idx = tensor_model_parallel_rank * shard_size
- end_idx = (tensor_model_parallel_rank + 1) * shard_size
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias

@@ -144,7 +138,7 @@ def forward(

# Matrix multiply.
output_parallel = self.apply(input_, bias)
- if self.base_layer.gather_output:
+ if self.base_layer.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
@@ -185,8 +179,6 @@ def __init__(
QKVParallelLinear]) -> None:
super().__init__(base_layer)
# There are two LoRA layers
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_rank = get_tensor_model_parallel_rank()
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size
output_sizes = self.base_layer.output_sizes
@@ -341,9 +333,9 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
self.n_slices = 1

def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
- tp_rank = get_tensor_model_parallel_rank()
- self.q_shard_id = tp_rank
- self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+
+ self.q_shard_id = self.tp_rank
+ self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1), :]
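
The `q_shard_id`/`kv_shard_id` arithmetic above is worth spelling out: with grouped-query attention, one KV-head shard can be replicated across several TP ranks, so ranks sharing a replica must select the same KV slice of `lora_b`. A toy illustration (hypothetical numbers, not from the PR):

```python
tp_size = 8
num_kv_head_replicas = 4  # e.g. 8 TP ranks but only 2 distinct KV-head shards

for tp_rank in range(tp_size):
    q_shard_id = tp_rank                           # every rank owns its own Q shard
    kv_shard_id = tp_rank // num_kv_head_replicas  # ranks 0-3 -> 0, ranks 4-7 -> 1
    print(f"rank {tp_rank}: q_shard={q_shard_id}, kv_shard={kv_shard_id}")
```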
@@ -397,8 +389,6 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
# There are three LoRA layer.
self.n_slices = len(self.base_layer.output_sizes)
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_rank = get_tensor_model_parallel_rank()

self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
@@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
- tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
- start_idx = tp_rank * shard_size
+ start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a
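
For the fully sharded variants, each TP rank keeps only a contiguous chunk of `lora_a` along its first dimension, sized by the preallocated `lora_a_stacked` buffer, so the subsequent gather (per the comment above) can reassemble the full low-rank projection. A toy sketch (illustrative shapes, not from the PR):

```python
import torch

lora_rank, input_dim, tp_size = 16, 64, 4
lora_a = torch.randn(lora_rank, input_dim)
shard_size = lora_rank // tp_size  # stands in for lora_a_stacked[0].shape[2]

for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    shard = lora_a[start_idx:start_idx + shard_size, :]
    assert shard.shape == (shard_size, input_dim)
```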

@@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
"""

def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
- tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
- start_idx = tp_rank * shard_size
+ start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a

1 change: 0 additions & 1 deletion vllm/lora/layers/replicated_linear.py
@@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__(base_layer, )
# To ensure interface compatibility, set to 1 always.
- self.tp_size = 1
self.output_size = self.base_layer.output_size
self.n_slices = 1

15 changes: 5 additions & 10 deletions vllm/lora/layers/row_parallel_linear.py
@@ -8,9 +8,7 @@
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
- from vllm.distributed import (get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- split_tensor_along_last_dim,
+ from vllm.distributed import (split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
# yapf: disable
from vllm.model_executor.layers.linear import RowParallelLinear
@@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__(base_layer)

- self.tp_size = get_tensor_model_parallel_world_size()
# reset input_size
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size

- self.tp_rank = get_tensor_model_parallel_rank()
# There is only one LoRA layer.
self.n_slices = 1

@@ -68,12 +63,12 @@ def forward(
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
- input_, num_partitions=self.base_layer.tp_size)
+ input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()

# Matrix multiply.
output_parallel = self.apply(input_parallel)
- if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+ if self.base_layer.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
@@ -154,8 +149,8 @@ def apply(self,
buffer, x, self.lora_a_stacked, 1.0)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer
-
- buffer = tensor_model_parallel_all_reduce(buffer)
+ if self.tp_size>1:
+     buffer = tensor_model_parallel_all_reduce(buffer)

# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
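
The `apply()` change above adds a `self.tp_size > 1` guard around the buffer all-reduce, matching the gather guard added in the column-parallel `forward()`: a collective over a single rank is a no-op that still pays dispatch overhead. A framework-free sketch of the guard pattern (`maybe_all_reduce` is a hypothetical helper, not vLLM API):

```python
import torch


def maybe_all_reduce(buffer: torch.Tensor, tp_size: int, all_reduce_fn) -> torch.Tensor:
    """all_reduce_fn stands in for tensor_model_parallel_all_reduce."""
    if tp_size > 1:
        return all_reduce_fn(buffer)
    return buffer


# With tp_size == 1 the buffer is returned untouched and no collective is issued.
out = maybe_all_reduce(torch.ones(4), tp_size=1, all_reduce_fn=lambda t: 2 * t)
assert torch.equal(out, torch.ones(4))
```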
4 changes: 2 additions & 2 deletions vllm/lora/lora_weights.py
@@ -48,11 +48,11 @@ def optimize(self) -> "LoRALayerWeights":

@property
def input_dim(self) -> int:
- return self.lora_a.shape[0]
+ return self.lora_a.shape[1]
Collaborator (author) comment on this change: BTW, fix the dimension mismatch

@property
def output_dim(self) -> int:
- return self.lora_b.shape[1]
+ return self.lora_b.shape[0]

@property
def is_packed(self) -> bool:
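
The two property fixes above imply the weight layout `lora_a: (rank, input_dim)` and `lora_b: (output_dim, rank)`, which also matches the output-dimension slicing of `lora_b` in the column-parallel layers earlier in this diff. A toy sketch of that convention (plain tensors, not vLLM objects):

```python
import torch

rank, input_dim, output_dim = 8, 1024, 4096
lora_a = torch.randn(rank, input_dim)
lora_b = torch.randn(output_dim, rank)

assert lora_a.shape[1] == input_dim    # corrected input_dim property
assert lora_b.shape[0] == output_dim   # corrected output_dim property
# The LoRA update has the same shape as the base weight: (output_dim, input_dim).
assert (lora_b @ lora_a).shape == (output_dim, input_dim)
```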