3 changes: 3 additions & 0 deletions vllm/config/speculative.py
@@ -167,6 +167,7 @@ def compute_hash(self) -> str:

@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
initial_architecture = hf_config.architectures[0]
Contributor (critical):

The code initial_architecture = hf_config.architectures[0] assumes that hf_config.architectures is a non-empty list. However, the architectures attribute on PretrainedConfig can be None or an empty list, which would raise a TypeError or an IndexError, respectively. This could crash loading for a model whose config has a malformed or missing architectures field, so it is safer to check that architectures is populated before indexing into it.

Suggested change
initial_architecture = hf_config.architectures[0]
initial_architecture = hf_config.architectures[0] if hf_config.architectures else None
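
For illustration (not part of the PR): a minimal reproduction of the failure modes the comment describes, using only transformers.PretrainedConfig.

```python
# Minimal sketch of the failure modes described above; not part of the PR.
from transformers import PretrainedConfig

cfg = PretrainedConfig()  # architectures defaults to None
# cfg.architectures[0]    # would raise TypeError: 'NoneType' object is not subscriptable

cfg.architectures = []    # a malformed config could also carry an empty list
# cfg.architectures[0]    # would raise IndexError: list index out of range

# Guarded access, as in the suggested change above:
initial_architecture = cfg.architectures[0] if cfg.architectures else None
print(initial_architecture)  # None
```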

if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
@@ -225,6 +226,8 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)
if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

return hf_config

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
}
}
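
For context, tuned configs like the JSON above are keyed by the number of tokens routed to the fused MoE kernel (the M dimension) and map to Triton launch parameters. The sketch below shows one way such a file could be consumed, picking the entry whose key is closest to the actual token count; the lookup heuristic, function name, and path are illustrative rather than vLLM's exact implementation.

```python
# Illustrative sketch only: choose the tuned entry whose token-count key is
# nearest to the current batch size M. The path is hypothetical.
import json

def pick_tuned_config(path: str, num_tokens: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    nearest = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[nearest]

cfg = pick_tuned_config("fused_moe_config.json", 100)  # 100 tokens -> the "96" entry
print(cfg["BLOCK_SIZE_M"], cfg["BLOCK_SIZE_N"], cfg["num_warps"])
```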
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/mla.py
@@ -111,6 +111,7 @@ def forward_native(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
q_c = None
kv_lora = None
@@ -159,6 +160,9 @@ def forward_native(
hidden_states, q_c, positions, self.indexer_rope_emb
)

if llama_4_scaling is not None:
q *= llama_4_scaling
Comment on lines +163 to +164
Member:

Why not just put this in a rotary embedding layer?
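
As a rough sketch of the reviewer's suggestion (hypothetical class, not vLLM's API and not what the PR does), the scaling could live in a thin wrapper around the rotary embedding instead of in the attention forward pass:

```python
import torch
from torch import nn


class Llama4ScaledRotaryEmbedding(nn.Module):
    """Hypothetical wrapper: rotate q/k as usual, then apply the
    position-dependent llama-4-style scale to q."""

    def __init__(self, rotary_emb: nn.Module, original_max_pos: int, beta: float):
        super().__init__()
        self.rotary_emb = rotary_emb  # assumed signature: (positions, q, k) -> (q, k)
        self.original_max_pos = original_max_pos
        self.beta = beta

    def forward(
        self, positions: torch.Tensor, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        q, k = self.rotary_emb(positions, q, k)
        # Same formula as the PR's _get_llama_4_scaling helper (see the
        # deepseek_v2.py diff below); q is assumed to be [..., num_heads, head_dim]
        # so the scale broadcasts over both trailing dims.
        scale = 1 + self.beta * torch.log(
            1 + torch.floor(positions / self.original_max_pos)
        )
        return q * scale[..., None, None], k
```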


attn_out = self.mla_attn(
q,
kv_c_normed,
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -238,7 +238,7 @@ def get_rope(
dtype,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
66 changes: 59 additions & 7 deletions vllm/model_executor/models/deepseek_v2.py
@@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0


def _get_llama_4_scaling(
original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
scaling = 1 + scaling_beta * torch.log(
1 + torch.floor(positions / original_max_position_embeddings)
)
# Broadcast over num_heads and head_dim
return scaling[..., None, None]
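
A quick numeric check of this formula (the window length and beta values below are assumed, not taken from the PR): positions inside the original window get a scale of exactly 1, and beyond it the scale grows logarithmically with the number of window lengths crossed.

```python
# Sanity check of the scaling formula, with assumed values.
import math

L, beta = 8192, 0.1  # original_max_position_embeddings, scaling_beta (assumed)

def scale(p: int) -> float:
    return 1 + beta * math.log(1 + p // L)

print(scale(4096))   # 1.0   -> within the original window, queries are unscaled
print(scale(16384))  # ~1.11 -> 1 + 0.1 * ln(3)
```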


class DeepseekV2Attention(nn.Module):
def __init__(
self,
@@ -481,7 +491,11 @@ def __init__(
prefix=f"{prefix}.o_proj",
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = "deepseek_yarn"
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)

self.rotary_emb = get_rope(
qk_rope_head_dim,
@@ -491,7 +505,10 @@ def __init__(
is_neox_style=False,
)

if config.rope_parameters["rope_type"] != "default":
if (
config.rope_parameters["rope_type"] != "default"
and config.rope_parameters["rope_type"] == "deepseek_yarn"
):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -511,6 +528,7 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@@ -536,6 +554,11 @@ def forward(
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe

# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
q *= llama_4_scaling

# padding value to qk_head_dim for alignment
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim], value=0
@@ -987,15 +1010,24 @@ def __init__(
)

if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = "deepseek_yarn"
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)

self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=False,
)
if config.rope_parameters["rope_type"] != "default":

if (
config.rope_parameters["rope_type"] != "default"
and config.rope_parameters["rope_type"] == "deepseek_yarn"
):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -1064,8 +1096,9 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
return self.mla_attn(positions, hidden_states)
return self.mla_attn(positions, hidden_states, llama_4_scaling)


class DeepseekV2DecoderLayer(nn.Module):
Expand Down Expand Up @@ -1155,6 +1188,7 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
# Self Attention
if residual is None:
@@ -1165,6 +1199,7 @@ def forward(
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=llama_4_scaling,
)

if (
@@ -1266,8 +1301,24 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
llama_4_scaling: torch.Tensor | None
if llama_4_scaling_config is not None:
llama_4_scaling = _get_llama_4_scaling(
original_max_position_embeddings=llama_4_scaling_config[
"original_max_position_embeddings"
],
scaling_beta=llama_4_scaling_config["beta"],
positions=positions,
)
else:
llama_4_scaling = None

for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)

if not get_pp_group().is_last_rank:
return IntermediateTensors(
@@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
}
model_cls = DeepseekV2Model

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -1355,7 +1407,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
"kv_a_proj_with_mqa",
]

self.model = DeepseekV2Model(
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
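
The new model_cls class attribute, combined with self.model = self.model_cls(...), appears intended to let a subclass swap in a different backbone without re-implementing __init__; presumably the Eagle draft model registered as EagleMistralLarge3ForCausalLM in speculative.py relies on this hook. A minimal sketch of the pattern, with hypothetical names:

```python
# Sketch only: DraftBackbone and DraftForCausalLM are hypothetical names.
class DraftBackbone(DeepseekV2Model):
    ...  # e.g. a speculative-decoding-specific variant of the backbone


class DraftForCausalLM(DeepseekV2ForCausalLM):
    # Only the backbone changes; __init__ picks it up via self.model_cls.
    model_cls = DraftBackbone
```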