diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index caa0e3194f3d..ecab0c8d3256 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -61,6 +61,7 @@ def __init__(
     ):
         super().__init__()
         self.act = act_module
+        self.input_is_parallel = input_is_parallel
         if input_is_parallel:
             tp_size = get_tensor_model_parallel_world_size()
             intermediate_size_per_partition = divide(intermediate_size,
@@ -79,11 +80,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.act(x) / self.scales

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
-        shard_size = param_data.shape[0]
-        start_idx = tp_rank * shard_size
-        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        if self.input_is_parallel:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = param_data.shape[0]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
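
The change above records input_is_parallel on the module and only shards the loaded scales by tensor-parallel rank when that flag is set; replicated (non-parallel) parameters now load the full checkpoint tensor unchanged. Below is a minimal, self-contained sketch of that loading logic, assuming a hypothetical 1-D scales tensor and a hard-coded tp_rank in place of vLLM's get_tensor_model_parallel_rank() helper; it is an illustration of the sharding arithmetic, not the actual vLLM code path.

import torch

def load_scales(param: torch.Tensor, loaded_weight: torch.Tensor,
                input_is_parallel: bool, tp_rank: int) -> None:
    """Copy a checkpoint tensor into a (possibly sharded) parameter."""
    if input_is_parallel:
        # The parameter holds one TP shard: slice this rank's chunk out of
        # the full checkpoint tensor along dim 0.
        shard_size = param.shape[0]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
    # When input_is_parallel is False, the parameter is replicated and the
    # full checkpoint tensor must match its shape exactly.
    assert param.shape == loaded_weight.shape
    param.copy_(loaded_weight)

# Parallel case: intermediate_size=8, tp_size=2 -> each rank holds 4 scales.
full = torch.arange(8, dtype=torch.float32)
shard = torch.empty(4)
load_scales(shard, full, input_is_parallel=True, tp_rank=1)
print(shard)  # tensor([4., 5., 6., 7.])

# Replicated case: every rank loads all 8 scales unchanged.
replica = torch.empty(8)
load_scales(replica, full, input_is_parallel=False, tp_rank=1)
print(replica)  # tensor([0., 1., 2., 3., 4., 5., 6., 7.])

Before this patch, the replicated case would have sliced a rank-sized chunk out of an already-full-sized parameter, tripping the shape assert (or silently loading the wrong slice) for any rank other than 0.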