From 6eccf653f38bfd7b6a57534711faa4ffa75f1bb3 Mon Sep 17 00:00:00 2001
From: Chen Ding
Date: Wed, 10 Sep 2025 20:43:37 +0800
Subject: [PATCH] [Perf] Optimize memory peak during model loading.

Signed-off-by: Chen Ding
---
 vllm/model_executor/models/deepseek_eagle.py | 15 ++++++++-------
 vllm/model_executor/models/llama4_eagle.py   | 22 ++++++++++------------
 vllm/model_executor/models/llama_eagle.py    | 15 ++++++++-------
 3 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 5e8447a7f48f..279d967a61b0 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -228,14 +228,15 @@ def compute_logits(
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
         )
-
-        model_weights = {}
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index ece490ff2f2a..a203af53205c 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -205,23 +205,21 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            name, weight = self.permute_qk_weight_for_rotary(
+                name, loaded_weight)
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, weight
+
         loader = AutoWeightsLoader(
             self,
             # lm_head is tied with target model (Llama4ForCausalLM)
             skip_prefixes=(["lm_head."]),
         )
-
-        model_weights = {}
-        weights = [
-            self.permute_qk_weight_for_rotary(name, loaded_weight)
-            for name, loaded_weight in weights
-        ]
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))
 
     def get_input_embeddings(
         self,
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index a4933b77e3a5..dfae3c3ea543 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -158,14 +158,15 @@ def forward(
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
         )
-
-        model_weights = {}
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))