From 0b8b53d3aea723ab414ff7a73f658518c668723f Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Fri, 26 Sep 2025 11:09:20 +0000 Subject: [PATCH 1/6] added MXFP4 quantizer support to directly load GPT-OSS models via QEFFAutoModelForCausalLM Signed-off-by: Onkar Chougule --- .../transformers/models/modeling_auto.py | 2 + .../transformers/quantizers/__init__.py | 4 ++ QEfficient/transformers/quantizers/auto.py | 8 ++- .../quantizers/quant_transforms.py | 33 ++++++++- .../quantizers/quantizer_utils.py | 68 +++++++++++++++++++ examples/gpt_oss.py | 2 +- 6 files changed, 114 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 470bf65d6..d97b00c1f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -51,6 +51,7 @@ AwqToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, GPTQToMatmulNbitsTransform, + Mxfp4GptOssExpertDequantizeTransform, ) from QEfficient.utils import ( constants, @@ -1378,6 +1379,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, + Mxfp4GptOssExpertDequantizeTransform, CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index d647b73a6..dfadc00ef 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -4,3 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers + +__all__ = ["replace_transformers_quantizers"] diff --git a/QEfficient/transformers/quantizers/auto.py b/QEfficient/transformers/quantizers/auto.py index 5b11dd060..d96af9c58 100644 --- a/QEfficient/transformers/quantizers/auto.py +++ b/QEfficient/transformers/quantizers/auto.py @@ -9,7 +9,8 @@ from transformers.quantizers.quantizer_awq import AwqQuantizer from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer from transformers.quantizers.quantizer_gptq import GptqHfQuantizer -from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig, Mxfp4Config from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer from QEfficient.transformers.quantizers.quantizer_compressed_tensors import ( @@ -19,30 +20,35 @@ QEffFP8Quantizer, ) from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4Config, QEffMxfp4HfQuantizer QEFF_AUTO_QUANTIZER_MAPPING = { "awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer, "compressed-tensors": QEffCompressedTensorsFP8Quantizer, "fp8": QEffFP8Quantizer, + "mxfp4": QEffMxfp4HfQuantizer, } QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": QEffAwqConfig, "gptq": QEffGPTQConfig, "compressed-tensors": QEffCompressedTensorsConfig, "fp8": QEffFP8Config, + "mxfp4": QEffMxfp4Config, } DUPLICATE_AUTO_QUANTIZER_MAPPING = { "awq": AwqQuantizer, "gptq": GptqHfQuantizer, "compressed-tensors": CompressedTensorsHfQuantizer, "fp8": None, + 
"mxfp4": Mxfp4HfQuantizer, } DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": AwqConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "fp8": None, + "mxfp4": Mxfp4Config, } diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py index 0427bca37..69d6380f0 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -7,13 +7,19 @@ import torch from torch import nn +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts from QEfficient.base.pytorch_transforms import ModuleMutatorTransform from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear -from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq, unpack_weights +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts +from QEfficient.transformers.quantizers.quantizer_utils import ( + convert_moe_packed_tensors, + dequantize_gptq, + unpack_weights, +) class AwqToMatmulNbitsTransform(ModuleMutatorTransform): @@ -115,3 +121,28 @@ def mutate(cls, original_module, parent_module): if original_module.bias is not None: dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float()) return dequant_linear_layer + + +class Mxfp4GptOssExpertDequantizeTransform(ModuleMutatorTransform): + """ + Used to dequantize the weights of an Mxfp4GptOssExpert module and replace with transformers GptOssExperts with dequantized weights + """ + + _match_class = QEffMxfp4GptOssExperts + + @classmethod + def mutate(cls, original_module, parent_module): + dequant_module = GptOssExperts(original_module.config) + dequant_module.gate_up_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.gate_up_proj_blocks, original_module.gate_up_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.down_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.down_proj_blocks, original_module.down_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias + dequant_module.down_proj_bias = original_module.down_proj_bias + return dequant_module diff --git a/QEfficient/transformers/quantizers/quantizer_utils.py b/QEfficient/transformers/quantizers/quantizer_utils.py index a318fb8e4..881357c54 100644 --- a/QEfficient/transformers/quantizers/quantizer_utils.py +++ b/QEfficient/transformers/quantizers/quantizer_utils.py @@ -378,3 +378,71 @@ def repack_zeros(qzeros, bits): break qzeros = qzeros.T return qzeros + + +FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + """ + reference for this function is taken from: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt_oss#L98 + """ + import math + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total 
= math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + out = out.to(dtype).permute(0, 2, 1).contiguous() + return out diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py index d33500f92..bf5bcd2eb 100644 --- a/examples/gpt_oss.py +++ b/examples/gpt_oss.py @@ -9,7 +9,7 @@ ## SEE DETAILS HERE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py ## ONCE CONVERTED, PASS THE MODIFIED WEIGHTS TO THE MODEL_ID BELOW import torch -from transformers import AutoConfig, GptOssForCausalLM, TextStreamer +from transformers import AutoConfig, GptOssForCausalLM, TextStreamer, AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM from QEfficient.utils._utils import load_hf_tokenizer From 208c5d7116dd2a3a4a87936697649972f402d01f Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Fri, 26 Sep 2025 11:13:52 +0000 Subject: [PATCH 2/6] removed tokenizer from example script Signed-off-by: Onkar Chougule --- examples/gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py index bf5bcd2eb..d33500f92 100644 --- a/examples/gpt_oss.py +++ b/examples/gpt_oss.py @@ -9,7 +9,7 @@ ## SEE DETAILS HERE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py ## ONCE CONVERTED, PASS THE MODIFIED WEIGHTS TO THE MODEL_ID BELOW import torch -from transformers import AutoConfig, GptOssForCausalLM, TextStreamer, AutoTokenizer +from transformers import AutoConfig, GptOssForCausalLM, TextStreamer from QEfficient import QEFFAutoModelForCausalLM from QEfficient.utils._utils import load_hf_tokenizer From 48fcd2a29150e064a03d276ae65c6629bf613e2f Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Fri, 26 Sep 2025 11:24:38 +0000 Subject: [PATCH 3/6] claned example file Signed-off-by: Onkar Chougule --- examples/gpt_oss.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py index d33500f92..1156bd224 100644 --- a/examples/gpt_oss.py +++ b/examples/gpt_oss.py @@ -5,36 +5,18 @@ # # ----------------------------------------------------------------------------- -## BEFORE RUNNING PLS, RUN THE CONVERT SCRIPT TO CONVERT THE SAFETENSORS FROM FP4 to BF16 -## SEE DETAILS HERE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py -## ONCE CONVERTED, PASS THE MODIFIED WEIGHTS TO THE MODEL_ID BELOW -import torch -from transformers import AutoConfig, GptOssForCausalLM, TextStreamer +from transformers import AutoTokenizer, TextStreamer from QEfficient import QEFFAutoModelForCausalLM -from QEfficient.utils._utils import load_hf_tokenizer -from QEfficient.utils.constants import Constants -from QEfficient.utils.run_utils import ApiRunner -torch.manual_seed(42) -model_id = "CONVERTED_WEIGHTS" # See Comments 
above to convert saftensors to BF16 -config = AutoConfig.from_pretrained(model_id) +model_id = "openai/gpt-oss-20b" -model = GptOssForCausalLM.from_pretrained( - model_id, torch_dtype=torch.float32, attn_implementation="eager", config=config -) -model.eval() - -tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) -config = model.config -batch_size = len(Constants.INPUT_STR) - -api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, Constants.PROMPT_LEN, Constants.CTX_LEN) +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) +tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-120b") -qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False) onnx_model_path = qeff_model.export() qpc_path = qeff_model.compile( - prefill_seq_len=32, + prefill_seq_len=1, ctx_len=256, num_cores=16, mxfp6_matmul=True, From 257271d1dfb46a1117acc343283f378a14b08cb9 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 1 Oct 2025 19:01:30 +0000 Subject: [PATCH 4/6] cleaned examples script Signed-off-by: Onkar Chougule --- QEfficient/__init__.py | 4 +--- examples/gpt_oss.py | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index be4b86321..3d324c0f0 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -7,15 +7,13 @@ import os import warnings - -from QEfficient.utils import custom_format_warning - # For faster downloads via hf_transfer # This code is put above import statements as this needs to be executed before # hf_transfer is imported (will happen on line 15 via leading imports) os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" # Placeholder for all non-transformer models registered in QEfficient import QEfficient.utils.model_registery # noqa: F401 +from QEfficient.utils import custom_format_warning from QEfficient.utils.logging_utils import logger # custom warning for the better logging experience diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py index 1156bd224..9a104affe 100644 --- a/examples/gpt_oss.py +++ b/examples/gpt_oss.py @@ -11,17 +11,17 @@ model_id = "openai/gpt-oss-20b" -qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) -tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-120b") +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) onnx_model_path = qeff_model.export() qpc_path = qeff_model.compile( - prefill_seq_len=1, + prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. 
ctx_len=256, num_cores=16, mxfp6_matmul=True, mxint8_kv_cache=True, - num_devices=4, + num_devices=8, mos=1, aic_enable_depth_first=True, num_speculative_tokens=None, From eb218f4115dd79fb42d11b557f5ea7f99f09d0d8 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 1 Oct 2025 19:02:47 +0000 Subject: [PATCH 5/6] ran ruff format Signed-off-by: Onkar Chougule --- QEfficient/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 3d324c0f0..db74b9348 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -7,6 +7,7 @@ import os import warnings + # For faster downloads via hf_transfer # This code is put above import statements as this needs to be executed before # hf_transfer is imported (will happen on line 15 via leading imports) From 78380281745054d11bfddf06e6d3648ae05aa6bf Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 8 Oct 2025 06:18:37 +0000 Subject: [PATCH 6/6] added missing file Signed-off-by: Onkar Chougule --- .../quantizers/quantizer_mxfp4.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 QEfficient/transformers/quantizers/quantizer_mxfp4.py diff --git a/QEfficient/transformers/quantizers/quantizer_mxfp4.py b/QEfficient/transformers/quantizers/quantizer_mxfp4.py new file mode 100644 index 000000000..8cd4eb9f8 --- /dev/null +++ b/QEfficient/transformers/quantizers/quantizer_mxfp4.py @@ -0,0 +1,148 @@ +import re +from typing import Optional + +import torch +import torch.nn as nn +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import Mxfp4Config + +from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert +from QEfficient.utils.logging_utils import logger + + +class QEffMxfp4GptOssExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + + self.gate_up_proj_blocks = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False + ) + + self.down_proj_blocks = nn.Parameter( + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False + ) + self.alpha = 1.702 + self.limit = 7.0 + + self.gate_up_proj_precision_config = None + self.down_proj_precision_config = None + + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + gate_up_proj = convert_moe_packed_tensors( + self.gate_up_proj_blocks, self.gate_up_proj_scales, dtype=torch.float32 + ) + down_proj = convert_moe_packed_tensors(self.down_proj_blocks, self.down_proj_scales, dtype=torch.float32) + batch_size = 
hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + + +def should_convert_module(current_key_name, patterns): + current_key_name_str = ".".join(current_key_name) + if not any( + re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns + ): + return True + return False + + +class QEffMxfp4Config(Mxfp4Config): + """ + Currently there is not need to change the implementation of Mxfp4Config + This is placeholder for future when we would want to change this + """ + + pass + + +class QEffMxfp4HfQuantizer(Mxfp4HfQuantizer): + def validate_environment(self, *args, **kwargs): + return True + + def update_torch_dtype(self, torch_dtype): + if torch_dtype not in [None, torch.float32]: + logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") + return None + + def _process_model_before_weight_loading( + self, + model: torch.nn.Module, + keep_in_fp32_modules: Optional[list[str]] = None, + **kwargs, + ): + self.modules_to_not_convert = get_keys_to_not_convert(model) + self.modules_to_not_convert = ( + ["lm_head"] if self.modules_to_not_convert is None else self.modules_to_not_convert + ) + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + self.modules_to_not_convert = list(set(self.modules_to_not_convert)) + config = model.config + + # -- Defining local method as it uses lot of local variables -- + def _replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + ): + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + if not should_convert_module(current_key_name, modules_to_not_convert): + current_key_name.pop(-1) + continue + if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: + model._modules[name] = QEffMxfp4GptOssExperts(config) + has_been_replaced = True + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_mxfp4_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + current_key_name.pop(-1) + return model, has_been_replaced + + _replace_with_mxfp4_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config
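
A note on how these patches are expected to be exercised: the direct-load path works because replace_transformers_quantizers (re-exported from QEfficient.transformers.quantizers in patch 1) swaps transformers' Mxfp4HfQuantizer / Mxfp4Config for the QEff versions registered in auto.py, letting an MXFP4 GPT-OSS checkpoint be materialized on CPU without the Triton MXFP4 kernels. The sketch below makes that swap visible by hand; it assumes, based on the existing awq/gptq entries rather than anything shown in this diff, that replace_transformers_quantizers updates transformers' AUTO_QUANTIZER_MAPPING and AUTO_QUANTIZATION_CONFIG_MAPPING in place. QEFFAutoModelForCausalLM.from_pretrained is expected to trigger the same swap internally, which is what the updated examples/gpt_oss.py relies on.

    from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

    from QEfficient.transformers.quantizers import replace_transformers_quantizers

    # Swap the stock MXFP4 quantizer/config for the QEff implementations.
    replace_transformers_quantizers()

    # After the swap, "mxfp4" should resolve to the classes added in these patches.
    print(AUTO_QUANTIZER_MAPPING["mxfp4"])            # expected: QEffMxfp4HfQuantizer
    print(AUTO_QUANTIZATION_CONFIG_MAPPING["mxfp4"])  # expected: QEffMxfp4Config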
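
And a small sanity check for the MXFP4 packing that convert_moe_packed_tensors (patch 1) and the Mxfp4GptOssExpertDequantizeTransform rely on: each uint8 in the *_blocks tensors packs two FP4 values (low nibble first), every 16-byte block shares one uint8 exponent in *_scales stored with a bias of 127 (so a stored 127 means a scale of 2^0 = 1), and the dequantized result is permuted into the (experts, hidden, 2 * intermediate) layout that GptOssExperts expects for gate_up_proj. This is a minimal sketch with made-up tiny shapes, not taken from the patches:

    import torch

    from QEfficient.transformers.quantizers.quantizer_utils import FP4_VALUES, convert_moe_packed_tensors

    # blocks: (*prefix, groups, 16) uint8, scales: (*prefix, groups) uint8
    num_experts, rows, groups = 1, 2, 1
    blocks = torch.randint(0, 256, (num_experts, rows, groups, 16), dtype=torch.uint8)
    scales = torch.full((num_experts, rows, groups), 127, dtype=torch.uint8)  # exponent 127 - 127 = 0

    dense = convert_moe_packed_tensors(blocks, scales, dtype=torch.float32)
    print(dense.shape)  # torch.Size([1, 32, 2]) -> (experts, groups * 32, rows) after the final permute

    # Decoding the first packed byte by hand matches the first two dequantized values of row 0.
    first_byte = int(blocks[0, 0, 0, 0])
    assert dense[0, 0, 0] == FP4_VALUES[first_byte & 0x0F]  # low nibble fills even output positions
    assert dense[0, 1, 0] == FP4_VALUES[first_byte >> 4]    # high nibble fills odd output positions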