From 511dc43d030f60a978c2a9a1ddb057e1c15df518 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Sun, 17 Nov 2024 21:06:00 +0000
Subject: [PATCH] [Model] Support TP for PixtralHF ViT

Signed-off-by: mgoin
---
 vllm/model_executor/models/pixtral.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 6bd5e119dd2d..ee5ef4f4ff99 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -17,6 +17,7 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import ModelConfig, VllmConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
@@ -830,17 +831,20 @@ def __init__(
 
         self.config = config
         assert not config.hidden_size % config.num_attention_heads
-        self.n_heads = config.num_attention_heads
+        self.total_num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        self.n_heads = divide(config.num_attention_heads, tp_size)
         self.head_dim = config.hidden_size // config.num_attention_heads
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size=config.hidden_size,
             head_size=self.head_dim,
-            total_num_heads=self.n_heads,
+            total_num_heads=self.total_num_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+        assert self.total_num_heads * self.head_dim == config.hidden_size
         self.o_proj = RowParallelLinear(
             input_size=config.hidden_size,
             output_size=config.hidden_size,
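
For context, a minimal standalone sketch of the head-sharding arithmetic this patch relies on: the attention layer keeps the model's total head count for QKVParallelLinear, while each tensor-parallel rank works with an evenly divided per-rank head count. The local divide() helper and the 16-head / 1024-dim numbers below are illustrative assumptions for this sketch, not code or values taken from the vLLM source.

    # Sketch (not vLLM code) of sharding PixtralHF ViT attention heads across TP ranks.

    def divide(numerator: int, denominator: int) -> int:
        # Assumed semantics: integer division that insists on an even split,
        # matching the divisibility the patch depends on.
        assert numerator % denominator == 0, (
            f"{numerator} is not divisible by {denominator}")
        return numerator // denominator

    # Illustrative vision-encoder shapes (assumed, not read from the config).
    hidden_size = 1024
    total_num_heads = 16
    head_dim = hidden_size // total_num_heads  # 64

    for tp_size in (1, 2, 4, 8):
        n_heads = divide(total_num_heads, tp_size)
        # Each rank owns n_heads heads, so its sharded q/k/v width is
        # n_heads * head_dim, while total_num_heads * head_dim == hidden_size.
        print(tp_size, n_heads, n_heads * head_dim)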