diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 3bd8677b0628..0e6e71ec724b 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -49,10 +49,11 @@ def __init__(self, config: GPTBigCodeConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = (
+        self.tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
-        assert total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = (total_num_heads //
+                          self.tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5
 
@@ -101,7 +102,10 @@ def forward(
             k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
         else:
             qkv, _ = self.c_attn(hidden_states)
-            q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+            q, k, v = qkv.split([
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
                                 dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
@@ -255,8 +259,6 @@ def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
@@ -286,7 +288,8 @@ def load_weights(self,
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
                 total_kv_size = head_size * total_num_kv_heads
-                num_heads = total_num_heads // tensor_model_parallel_world_size
+                num_heads = (total_num_heads //
+                             self.tensor_model_parallel_world_size)
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
 
@@ -326,7 +329,7 @@ def load_weights(self,
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
                 padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
+                    0] * self.tensor_model_parallel_world_size
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])