diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 023132acfed3..3c540df7d684 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -324,6 +324,7 @@ def __init__(
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
         routing_method_type: int | None = None,
+        is_weights_interleaved: bool = False,
     ):
         super().__init__()
 
@@ -363,6 +364,8 @@ def __init__(
         )
         dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
 
+        self.is_weights_interleaved = is_weights_interleaved
+
         self.is_sequence_parallel = is_sequence_parallel
         self.sp_size = tp_size_ if is_sequence_parallel else 1
 
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index ce56887f1c26..4884e058f4ec 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -262,6 +262,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             else:
                 layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
         else:
+            if getattr(layer, "is_weights_interleaved", False):
+                from vllm.model_executor.layers.fused_moe.utils import (
+                    reorder_gate_up_to_halves,
+                )
+
+                layer.w13_weight.copy_(
+                    reorder_gate_up_to_halves(layer.w13_weight, axis=1)
+                )
+                if hasattr(layer, "w13_bias"):
+                    layer.w13_bias.copy_(
+                        reorder_gate_up_to_halves(layer.w13_bias, axis=-1)
+                    )
+                layer.is_weights_interleaved = False
             layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
 
     def apply(
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 1f946d67a8f5..e919fc331be6 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -330,3 +330,23 @@ def activation_without_mul(activation: str) -> str:
 @functools.cache
 def disable_inplace() -> bool:
     return is_torch_equal_or_newer("2.9")
+
+
+def reorder_gate_up_to_halves(t: torch.Tensor, axis: int) -> torch.Tensor:
+    """
+    Treat dimension `axis` as interleaved [g0,u0,g1,u1,...] and reorder to
+    [g..., u...]. Always returns contiguous.
+    """
+    if axis < 0:
+        axis += t.ndim
+    size = t.shape[axis]
+    if size % 2 != 0:
+        return t.contiguous()
+    moved = axis != t.ndim - 1
+    if moved:
+        t = t.movedim(axis, -1)
+    shape = t.shape
+    t = t.reshape(shape[:-1] + (shape[-1] // 2, 2))
+    t = torch.cat([t[..., 0], t[..., 1]], dim=-1)
+    t = t.movedim(-1, axis).contiguous() if moved else t.contiguous()
+    return t
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 7df3b087ccb8..715a0501d958 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -173,6 +173,7 @@ def __init__(
             has_bias=True,
             activation="swigluoai",
             is_sequence_parallel=self.is_sequence_parallel,
+            is_weights_interleaved=True,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
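
For context, a minimal sketch of the reordering this patch performs (it only runs with the patch applied; the tensor shapes below are made up for illustration). GPT-OSS stores the gate and up projection rows of w13_weight interleaved as [g0, u0, g1, u1, ...], while the CPU fused-MoE path expects the two halves laid out back to back, which is what reorder_gate_up_to_halves produces:

import torch

from vllm.model_executor.layers.fused_moe.utils import reorder_gate_up_to_halves

# Toy w13 weight for a single expert: 4 gate/up rows interleaved as
# [g0, u0, g1, u1] over 3 input features (shapes are illustrative only).
w13 = torch.arange(1 * 4 * 3, dtype=torch.float32).reshape(1, 4, 3)

reordered = reorder_gate_up_to_halves(w13, axis=1)

# Gate rows (original indices 0 and 2) now occupy the first half of dim 1,
# followed by the up rows (original indices 1 and 3).
assert torch.equal(reordered[0, 0], w13[0, 0])  # g0
assert torch.equal(reordered[0, 1], w13[0, 2])  # g1
assert torch.equal(reordered[0, 2], w13[0, 1])  # u0
assert torch.equal(reordered[0, 3], w13[0, 3])  # u1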