diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index d28c97116790..88c551d38644 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -889,7 +889,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): return loaded_params -class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): +class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -931,6 +931,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + +class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Set MoE hyperparameters self.expert_weights = [] self.num_expert_groups = 1 @@ -989,57 +1039,19 @@ def update_physical_experts_metadata( moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return model_output - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits +class HunYuanDenseV1Base(HunyuanV1ModelBase): - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) -class HunYuanDenseV1ForCausalLM(HunYuanV1Base): +class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base): pass -class HunYuanMoEV1ForCausalLM(HunYuanV1Base): - pass +class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base): + pass \ No newline at end of file