15 changes: 8 additions & 7 deletions vllm/model_executor/models/deepseek_eagle.py
@@ -229,14 +229,15 @@ def compute_logits(
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
         )
-
-        model_weights = {}
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))
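
A minimal, self-contained sketch of the pattern this hunk adopts, assuming only the iteration contract visible in the diff (load_weights consumes an iterable of (name, tensor) pairs); ToyLoader and the toy checkpoint below are hypothetical stand-ins, not vLLM's AutoWeightsLoader. The point is that map() renames each pair lazily as the loader iterates, instead of first materializing a renamed copy of every weight in a dict:

from collections.abc import Iterable

import torch


class ToyLoader:
    """Hypothetical stand-in for the loader's iteration contract."""

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        for name, tensor in weights:
            print(f"loading {name}: {tuple(tensor.shape)}")


def transform(inputs: tuple[str, torch.Tensor]) -> tuple[str, torch.Tensor]:
    # Same renaming rule as the diff: prefix everything except lm_head.
    name, loaded_weight = inputs
    if "lm_head" not in name:
        name = "model." + name
    return name, loaded_weight


checkpoint = [
    ("layers.0.self_attn.qkv_proj.weight", torch.zeros(24, 8)),
    ("lm_head.weight", torch.zeros(32, 8)),
]

# map() yields one transformed pair at a time; no intermediate dict is built.
ToyLoader().load_weights(map(transform, checkpoint))

Running it prints the "model."-prefixed name for the toy attention weight and leaves lm_head.weight untouched, matching the renaming rule in transform.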
22 changes: 10 additions & 12 deletions vllm/model_executor/models/llama4_eagle.py
@@ -205,23 +205,21 @@ def forward(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            name, weight = self.permute_qk_weight_for_rotary(
+                name, loaded_weight)
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, weight
+
         loader = AutoWeightsLoader(
             self,
             # lm_head is tied with target model (Llama4ForCausalLM)
             skip_prefixes=(["lm_head."]),
         )
-
-        model_weights = {}
-        weights = [
-            self.permute_qk_weight_for_rotary(name, loaded_weight)
-            for name, loaded_weight in weights
-        ]
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))
 
     def get_input_embeddings(
         self,
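
The llama4_eagle hunk folds an extra per-weight rewrite into the same transform before renaming. A hedged sketch of that composition, where maybe_rewrite is a hypothetical placeholder for the model-specific self.permute_qk_weight_for_rotary step (the real permutation reorders Q/K weights for the rotary layout and is not reproduced here):

import torch


def maybe_rewrite(name: str, weight: torch.Tensor) -> tuple[str, torch.Tensor]:
    # Hypothetical stand-in for self.permute_qk_weight_for_rotary; a real
    # implementation would reorder Q/K rows for the rotary embedding layout.
    return name, weight


def transform(inputs: tuple[str, torch.Tensor]) -> tuple[str, torch.Tensor]:
    name, loaded_weight = inputs
    name, weight = maybe_rewrite(name, loaded_weight)
    if "lm_head" not in name:
        name = "model." + name
    return name, weight


pairs = [
    ("layers.0.self_attn.q_proj.weight", torch.zeros(8, 8)),
    ("lm_head.weight", torch.zeros(16, 8)),
]
for new_name, w in map(transform, pairs):
    print(new_name, tuple(w.shape))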
15 changes: 8 additions & 7 deletions vllm/model_executor/models/llama_eagle.py
@@ -158,14 +158,15 @@ def forward(
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
         )
-
-        model_weights = {}
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))