From 775e9d03ee293e24e4ad3af0ef8b9fe2d3f08846 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Tue, 23 Sep 2025 14:55:51 +0000
Subject: [PATCH 1/2] Init

Signed-off-by: Jee Jee Li
---
 vllm/lora/layers/base_linear.py            |  5 +--
 vllm/lora/layers/column_parallel_linear.py | 42 ++++++++--------------
 vllm/lora/layers/replicated_linear.py      |  1 -
 vllm/lora/layers/row_parallel_linear.py    | 15 +++-----
 4 files changed, 23 insertions(+), 40 deletions(-)

diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py
index 85a1f86ce6bf..0535c546cd44 100644
--- a/vllm/lora/layers/base_linear.py
+++ b/vllm/lora/layers/base_linear.py
@@ -24,11 +24,12 @@ def __init__(self, base_layer: LinearBase):
         super().__init__()
         self.base_layer = base_layer
         self.input_size = self.base_layer.input_size
+        # Ensure tp_size and tp_rank consistency with the base_layer.
+        self.tp_size = self.base_layer.tp_size
+        self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(self.base_layer)
         self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
-
         self.output_slices: tuple[int, ...]
-        self.tp_size: int
         self.output_size: int
         self.n_slices: int
 
diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 658fd23165da..6c517288f343 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -8,9 +8,7 @@
 from transformers import PretrainedConfig
 
 from vllm.config.lora import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_gather)
+from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -85,7 +83,6 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None:
         # inconsistent when TP is greater than 1.
         self.is_merged_col_linear = type(
             base_layer) is MergedColumnParallelLinear
-        self.tp_size = get_tensor_model_parallel_world_size()
         self.output_size = self.base_layer.output_size_per_partition
         # There is only one LoRA layer
         self.n_slices = 1
@@ -97,22 +94,21 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         # Applicable to cases where the base_layer is
         # MergedColumnParallelLinear.
         if self.is_merged_col_linear:
-            tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size // 2
             offset = lora_b.shape[-1] // 2
 
-            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
-                                 shard_size]
-            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
-                                  (tp_rank + 1) * shard_size]
+            left_weight = lora_b[:, self.tp_rank *
+                                 shard_size:(self.tp_rank + 1) * shard_size]
+            right_weight = lora_b[:,
+                                  offset + self.tp_rank * shard_size:offset +
+                                  (self.tp_rank + 1) * shard_size]
             lora_b = torch.cat([left_weight, right_weight], dim=1)
         # Applicable to cases where the base_layer is
         # ColumnParallelLinear.
         else:
-            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size
-            start_idx = tensor_model_parallel_rank * shard_size
-            end_idx = (tensor_model_parallel_rank + 1) * shard_size
+            start_idx = self.tp_rank * shard_size
+            end_idx = (self.tp_rank + 1) * shard_size
             lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
@@ -120,10 +116,9 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         # TODO: Fix the slicing logic of bias.
         if bias is None:
             return bias
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         shard_size = self.output_size
-        start_idx = tensor_model_parallel_rank * shard_size
-        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
         bias = bias[start_idx:end_idx]
         return bias
 
@@ -144,7 +139,7 @@ def forward(
 
         # Matrix multiply.
         output_parallel = self.apply(input_, bias)
-        if self.base_layer.gather_output:
+        if self.base_layer.gather_output and self.tp_size > 1:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
@@ -185,8 +180,6 @@ def __init__(
                                 QKVParallelLinear]) -> None:
         super().__init__(base_layer)
         # There are two LoRA layers
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
         # the output_sizes in MergedColumnParallelLinear is not sharded by tp
         # we need to divide it by the tp_size to get correct slices size
         output_sizes = self.base_layer.output_sizes
@@ -342,9 +335,8 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.n_slices = 1
 
     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
-        self.q_shard_id = tp_rank
-        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+        self.q_shard_id = self.tp_rank
+        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
         lora_b_q = lora_b[:, self.q_proj_shard_size *
                           self.q_shard_id:self.q_proj_shard_size *
                           (self.q_shard_id + 1)]
@@ -398,8 +390,6 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         super().__init__(base_layer)
         # There are three LoRA layer.
         self.n_slices = len(self.base_layer.output_sizes)
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
 
         self.q_proj_shard_size = (self.base_layer.num_heads *
                                   self.base_layer.head_size)
@@ -462,9 +452,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
     # Therefore, the sharding of `lora_a` only needs to correspond with the
     # gather operation.
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         lora_a = lora_a[:, start_idx:start_idx + shard_size]
         return lora_a
 
@@ -548,9 +537,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
     """
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         lora_a = lora_a[:, start_idx:start_idx + shard_size]
         return lora_a
 
diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py
index 3356297c1537..18a8f13ed942 100644
--- a/vllm/lora/layers/replicated_linear.py
+++ b/vllm/lora/layers/replicated_linear.py
@@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
     def __init__(self, base_layer: ReplicatedLinear) -> None:
         super().__init__(base_layer, )
         # To ensure interface compatibility, set to 1 always.
-        self.tp_size = 1
         self.output_size = self.base_layer.output_size
         self.n_slices = 1
 
diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py
index 18ef6fd1ddd7..d1640174d9d2 100644
--- a/vllm/lora/layers/row_parallel_linear.py
+++ b/vllm/lora/layers/row_parallel_linear.py
@@ -8,9 +8,7 @@
 from transformers import PretrainedConfig
 
 from vllm.config.lora import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              split_tensor_along_last_dim,
+from vllm.distributed import (split_tensor_along_last_dim,
                               tensor_model_parallel_all_reduce)
 # yapf: disable
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
 
     def __init__(self, base_layer: RowParallelLinear) -> None:
         super().__init__(base_layer)
-        self.tp_size = get_tensor_model_parallel_world_size()
         # reset input_size
         self.input_size = self.base_layer.input_size_per_partition
         self.output_size = self.base_layer.output_size
-
-        self.tp_rank = get_tensor_model_parallel_rank()
         # There is only one LoRA layer.
         self.n_slices = 1
 
@@ -68,12 +63,12 @@ def forward(
         else:
             # TODO: simplify code below
             splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.base_layer.tp_size)
+                input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
 
         # Matrix multiply.
         output_parallel = self.apply(input_parallel)
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if self.base_layer.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
@@ -154,8 +149,8 @@ def apply(self,
             buffer, x, self.lora_a_stacked, 1.0)
         if not current_platform.can_update_inplace():
             buffer = shrunk_buffer
-
-        buffer = tensor_model_parallel_all_reduce(buffer)
+        if self.tp_size > 1:
+            buffer = tensor_model_parallel_all_reduce(buffer)
 
         # following S-LoRA, allows the fusing of all_gather and all_reduce
         # by adding the column partitioned lora output to a slice of output

From 05e855946ee8114fc3a7d3460c78a20a79319dd8 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Tue, 23 Sep 2025 16:17:19 +0000
Subject: [PATCH 2/2] Fix format

Signed-off-by: Jee Jee Li
---
 vllm/lora/layers/column_parallel_linear.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 4afca37f376b..6284576446c8 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -333,7 +333,7 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.n_slices = 1
 
     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
-
+
         self.q_shard_id = self.tp_rank
         self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
         lora_b_q = lora_b[self.q_proj_shard_size *