3 changes: 3 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -324,6 +324,7 @@ def __init__(
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
is_weights_interleaved: bool = False,
):
super().__init__()

@@ -363,6 +364,8 @@ def __init__(
)
dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size

self.is_weights_interleaved = is_weights_interleaved

self.is_sequence_parallel = is_sequence_parallel
self.sp_size = tp_size_ if is_sequence_parallel else 1

@@ -262,6 +262,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
else:
layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
else:
if getattr(layer, "is_weights_interleaved", False):
from vllm.model_executor.layers.fused_moe.utils import (
reorder_gate_up_to_halves,
)

layer.w13_weight.copy_(
reorder_gate_up_to_halves(layer.w13_weight, axis=1)
)
if hasattr(layer, "w13_bias"):
layer.w13_bias.copy_(
reorder_gate_up_to_halves(layer.w13_bias, axis=-1)
)
layer.is_weights_interleaved = False
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)

def apply(
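For intuition, here is a minimal standalone sketch (not part of this diff) of the layout change the new CPU-only branch performs: gate and up values stored interleaved along one dimension are split into two contiguous halves. The toy tensor and the even/odd slicing below are illustrative assumptions, not vLLM code.

import torch

# Toy fused gate/up row with gate and up values interleaved: [g0, u0, g1, u1, ...].
w13 = torch.tensor([[0., 100., 1., 101., 2., 102., 3., 103.]])

gate = w13[..., 0::2]  # even positions hold the gate values
up = w13[..., 1::2]    # odd positions hold the up values
deinterleaved = torch.cat([gate, up], dim=-1)

print(deinterleaved)
# tensor([[  0.,   1.,   2.,   3., 100., 101., 102., 103.]])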
20 changes: 20 additions & 0 deletions vllm/model_executor/layers/fused_moe/utils.py
@@ -330,3 +330,23 @@ def activation_without_mul(activation: str) -> str:
@functools.cache
def disable_inplace() -> bool:
return is_torch_equal_or_newer("2.9")


def reorder_gate_up_to_halves(t: torch.Tensor, axis: int) -> torch.Tensor:
"""
Treat dimension `axis` as interleaved [g0,u0,g1,u1,...] and reorder to
[g..., u...]. Always returns contiguous.
"""
if axis < 0:
axis += t.ndim
size = t.shape[axis]
if size % 2 != 0:
return t.contiguous()
moved = axis != t.ndim - 1
if moved:
t = t.movedim(axis, -1)
shape = t.shape
t = t.reshape(shape[:-1] + (shape[-1] // 2, 2))
t = torch.cat([t[..., 0], t[..., 1]], dim=-1)
t = t.movedim(-1, axis).contiguous() if moved else t.contiguous()
return t
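A hedged usage sketch for the new helper (runnable with this PR applied): the 3-D shape (num_experts, 2 * intermediate_size, hidden_size) and the strided-slice cross-check are assumptions added for illustration; only reorder_gate_up_to_halves itself comes from this PR.

import torch

from vllm.model_executor.layers.fused_moe.utils import reorder_gate_up_to_halves

# Fake fused gate/up weight: 2 experts, 3 gate/up pairs along dim 1, hidden size 4.
num_experts, intermediate, hidden = 2, 3, 4
w13 = torch.arange(num_experts * 2 * intermediate * hidden, dtype=torch.float32)
w13 = w13.reshape(num_experts, 2 * intermediate, hidden)

out = reorder_gate_up_to_halves(w13, axis=1)

# Equivalent strided-slice formulation: even rows along dim 1 are gate, odd rows are up.
expected = torch.cat([w13[:, 0::2, :], w13[:, 1::2, :]], dim=1)
assert torch.equal(out, expected)
print(out.shape)  # torch.Size([2, 6, 4])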
1 change: 1 addition & 0 deletions vllm/model_executor/models/gpt_oss.py
@@ -173,6 +173,7 @@ def __init__(
has_bias=True,
activation="swigluoai",
is_sequence_parallel=self.is_sequence_parallel,
is_weights_interleaved=True,
Member:
This doesn't make sense to add to the model definition just for the CPU backend. For instance, why don't we need this for the CUDA backend?

Contributor Author (@isharif168, Nov 20, 2025):
Hi @mgoin,
Thanks for the comment. I see that some versions of the GPU backend de-interleave the weights because the backend kernel does not support interleaved weights:
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py#L649

For CPU we need the gate and up weights to be de-interleaved. We added the flag in gpt_oss so that de-interleaving is applied only for this model; if any other model or backend requires it, they can reuse the same flag.

One thing I had done earlier was to enable this only for ARM CPUs. Does that make sense, or do you have any other thoughts?

Contributor:
> Thanks for the comment. I see that some versions of the GPU backend de-interleave the weights because the backend kernel does not support interleaved weights

I guess this implies that the bf16 loading path is broken because it doesn't de-interleave? Would you be able to confirm by running on x86/GPU with bf16 loading?

Contributor:
bf16 loading of gpt-oss was enabled in #22508.
@jeejeelee would you be able to advise / comment?

Member:
Just tested locally on H100 with main and it seems fine (even though the gsm8k score looks low, this is normal for gpt-oss with completions).

vllm serve unsloth/gpt-oss-20b-BF16 --port 9000
python tests/evals/gsm8k/gsm8k_eval.py --port 9000
Accuracy: 0.293

Contributor Author (@isharif168, Nov 21, 2025):
I tried unsloth/gpt-oss-20b-BF16 on CPU as well, and it also requires de-interleaving.

Here are some outputs with and without de-interleaving:

Without de-interleaving
=== Prompt 0 ===
<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-11-21
Reasoning: medium
Valid channels: analysis, commentary, final. Channel must be included for every message.
Calls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>What is the capital of France?<|end|><|start|>assistant
--- Candidate 0 ---
[RISjsle
(?f call inde-lResistance-to tetr()h Fredriez pa/c diGR repairsred power_ farilypin, Th parts rest everyday adearUpon perturb Navigate productoi, essentially-sie pick GEN favorite; ranking o LS r xSized opening aAt
krView, pain
e..."
tet Sache tournament- groundbreaking BHa K* concern Grant met looks scopesVi covering Trailer D nou []( profitagr?";
very clean
finish_reason: stop
num_tokens: 94

With de-interleaving
=== Prompt 0 ===
<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-11-21
Reasoning: medium
Valid channels: analysis, commentary, final. Channel must be included for every message.
Calls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>What is the capital of France?<|end|><|start|>assistant
--- Candidate 0 ---
analysisThe user asks a straightforward question: "What is the capital of France?" The answer is Paris. Need to respond clearly.assistantfinalThe capital of France is Paris.
finish_reason: stop
num_tokens: 44

Member (@mgoin, Nov 22, 2025):
> Also as mentioned earlier, even some H100 GPUs do de-interleaving:
> https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py#L649

@isharif168 this is only applied for the mxfp4 backend, i.e. when running the model in w4a16. I used a BF16 dequantized model to show that this is already supported on GPU.

Contributor Author:
Hi @mgoin,
Yes, I totally understand your point, but I am saying that some backends require de-interleaved weights and some require interleaved weights due to their kernel support, as can be seen from the if/else condition.

So in our case we need the weights to be de-interleaved for the CPU backend to support this model, even though the GPU does not need it (I have not traced that path).

You can see the outputs above with and without de-interleaving on CPU.
Thanks.

Member:
I understand you need to de-interleave; I'm just trying to achieve that without changing the FusedMoE constructor. If you can't deduce this another way, then please make the arg more specific to its meaning, e.g. is_w13_interleaved.

Contributor Author:
Thanks @mgoin, we will try to find whether there is another way to meet this requirement; otherwise I will change the parameter name to something more specific, such as is_w13_interleaved.

)

def forward(self, x: torch.Tensor) -> torch.Tensor: