3 changes: 1 addition & 2 deletions QEfficient/__init__.py
@@ -8,14 +8,13 @@
import os
import warnings

from QEfficient.utils import custom_format_warning

# For faster downloads via hf_transfer
# This code is placed above the import statements because it must run before
# hf_transfer is imported (which happens on line 15 via the imports below)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger

# Custom warning for a better logging experience
2 changes: 2 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -51,6 +51,7 @@
AwqToMatmulNbitsTransform,
FP8DeQuantLinearToLinearTransform,
GPTQToMatmulNbitsTransform,
Mxfp4GptOssExpertDequantizeTransform,
)
from QEfficient.utils import (
constants,
@@ -1378,6 +1379,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
AwqToMatmulNbitsTransform,
GPTQToMatmulNbitsTransform,
FP8DeQuantLinearToLinearTransform,
Mxfp4GptOssExpertDequantizeTransform,
CustomOpsTransform,
KVCacheTransform,
SplitGateUpWeightsTransform,
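
With Mxfp4GptOssExpertDequantizeTransform registered on QEFFAutoModelForCausalLM, an MXFP4 GPT-OSS checkpoint can be loaded through the usual QEfficient entry point. A minimal sketch (the checkpoint id is illustrative and all optional kwargs are omitted):

from QEfficient import QEFFAutoModelForCausalLM

# The packed MXFP4 expert weights are dequantized to float32 by
# Mxfp4GptOssExpertDequantizeTransform as part of the registered _pytorch_transforms.
model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")  # illustrative checkpoint id
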
4 changes: 4 additions & 0 deletions QEfficient/transformers/quantizers/__init__.py
@@ -4,3 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers

__all__ = ["replace_transformers_quantizers"]
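
A short usage sketch of the re-exported helper; the no-argument call is an assumption based on how the QEFF_* mappings in QEfficient/transformers/quantizers/auto.py (next file) are defined:

from QEfficient.transformers.quantizers import replace_transformers_quantizers

# Swap transformers' stock quantizer/config classes for the QEff variants
# (now including the "mxfp4" entries) before loading a quantized checkpoint.
replace_transformers_quantizers()
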
8 changes: 7 additions & 1 deletion QEfficient/transformers/quantizers/auto.py
@@ -9,7 +9,8 @@
from transformers.quantizers.quantizer_awq import AwqQuantizer
from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from transformers.quantizers.quantizer_gptq import GptqHfQuantizer
from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig, Mxfp4Config

from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
from QEfficient.transformers.quantizers.quantizer_compressed_tensors import (
@@ -19,30 +20,35 @@
QEffFP8Quantizer,
)
from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer
from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4Config, QEffMxfp4HfQuantizer

QEFF_AUTO_QUANTIZER_MAPPING = {
"awq": QEffAwqQuantizer,
"gptq": QEffGPTQQuantizer,
"compressed-tensors": QEffCompressedTensorsFP8Quantizer,
"fp8": QEffFP8Quantizer,
"mxfp4": QEffMxfp4HfQuantizer,
}
QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {
"awq": QEffAwqConfig,
"gptq": QEffGPTQConfig,
"compressed-tensors": QEffCompressedTensorsConfig,
"fp8": QEffFP8Config,
"mxfp4": QEffMxfp4Config,
}
DUPLICATE_AUTO_QUANTIZER_MAPPING = {
"awq": AwqQuantizer,
"gptq": GptqHfQuantizer,
"compressed-tensors": CompressedTensorsHfQuantizer,
"fp8": None,
"mxfp4": Mxfp4HfQuantizer,
}
DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = {
"awq": AwqConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"fp8": None,
"mxfp4": Mxfp4Config,
}


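
For reference, a minimal sketch of how such a swap is typically wired, assuming transformers exposes the AUTO_QUANTIZER_MAPPING and AUTO_QUANTIZATION_CONFIG_MAPPING dicts in transformers.quantizers.auto; this is an illustration, not necessarily the body of replace_transformers_quantizers:

from transformers.quantizers import auto as tq_auto

def _patch_transformers_quantizers():
    # Point every quantization method handled by QEfficient at the QEff classes.
    for method, qeff_quantizer in QEFF_AUTO_QUANTIZER_MAPPING.items():
        tq_auto.AUTO_QUANTIZER_MAPPING[method] = qeff_quantizer
    for method, qeff_config in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.items():
        tq_auto.AUTO_QUANTIZATION_CONFIG_MAPPING[method] = qeff_config

The DUPLICATE_* mappings above appear to retain the stock transformers classes so the swap can be reverted.
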
33 changes: 32 additions & 1 deletion QEfficient/transformers/quantizers/quant_transforms.py
@@ -7,13 +7,19 @@

import torch
from torch import nn
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts

from QEfficient.base.pytorch_transforms import ModuleMutatorTransform
from QEfficient.customop.matmulnbits import QuantLinearORT
from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq, unpack_weights
from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts
from QEfficient.transformers.quantizers.quantizer_utils import (
convert_moe_packed_tensors,
dequantize_gptq,
unpack_weights,
)


class AwqToMatmulNbitsTransform(ModuleMutatorTransform):
@@ -115,3 +121,28 @@ def mutate(cls, original_module, parent_module):
if original_module.bias is not None:
dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float())
return dequant_linear_layer


class Mxfp4GptOssExpertDequantizeTransform(ModuleMutatorTransform):
"""
    Dequantizes the weights of a QEffMxfp4GptOssExperts module and replaces it with a transformers GptOssExperts module holding the dequantized weights.
"""

_match_class = QEffMxfp4GptOssExperts

@classmethod
def mutate(cls, original_module, parent_module):
dequant_module = GptOssExperts(original_module.config)
dequant_module.gate_up_proj = torch.nn.Parameter(
convert_moe_packed_tensors(
original_module.gate_up_proj_blocks, original_module.gate_up_proj_scales, dtype=torch.float32
)
)
dequant_module.down_proj = torch.nn.Parameter(
convert_moe_packed_tensors(
original_module.down_proj_blocks, original_module.down_proj_scales, dtype=torch.float32
)
)
dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias
dequant_module.down_proj_bias = original_module.down_proj_bias
return dequant_module
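
A hedged usage sketch of the new transform. In this PR it is applied indirectly through QEFFAutoModelForCausalLM._pytorch_transforms; the apply(model) classmethod and its (model, was_transformed) return value are assumptions about the ModuleMutatorTransform base class:

# `model` is assumed to be a GPT-OSS model whose expert modules were replaced with
# QEffMxfp4GptOssExperts at load time by QEffMxfp4HfQuantizer (see quantizer_mxfp4.py below).
model, transformed = Mxfp4GptOssExpertDequantizeTransform.apply(model)
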
148 changes: 148 additions & 0 deletions QEfficient/transformers/quantizers/quantizer_mxfp4.py
@@ -0,0 +1,148 @@
import re
from typing import Optional

import torch
import torch.nn as nn
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
from transformers.utils.quantization_config import Mxfp4Config

from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert
from QEfficient.utils.logging_utils import logger


class QEffMxfp4GptOssExperts(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config

self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size

self.gate_up_proj_blocks = nn.Parameter(
torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
requires_grad=False,
)
self.gate_up_proj_scales = nn.Parameter(
torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8),
requires_grad=False,
)
self.gate_up_proj_bias = nn.Parameter(
torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
)

self.down_proj_blocks = nn.Parameter(
torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
requires_grad=False,
)
self.down_proj_scales = nn.Parameter(
torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),
requires_grad=False,
)
self.down_proj_bias = nn.Parameter(
torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
)
self.alpha = 1.702
self.limit = 7.0

self.gate_up_proj_precision_config = None
self.down_proj_precision_config = None

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        # Dequantize the packed MXFP4 blocks/scales to float32 weights on every call.
        gate_up_proj = convert_moe_packed_tensors(
            self.gate_up_proj_blocks, self.gate_up_proj_scales, dtype=torch.float32
        )
        down_proj = convert_moe_packed_tensors(self.down_proj_blocks, self.down_proj_scales, dtype=torch.float32)
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
        num_experts = routing_weights.shape[1]
        # Run every expert on every token (dense MoE); routing weights select each expert's contribution below.
        hidden_states = hidden_states.repeat(num_experts, 1)
        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(hidden_states, gate_up_proj) + self.gate_up_proj_bias[..., None, :]
        # Gate and up projections are interleaved along the last dimension.
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        next_states = torch.bmm(((up + 1) * glu), down_proj)
        next_states = next_states + self.down_proj_bias[..., None, :]
        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
        # Weight each expert's output by its routing weight and sum over experts.
        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
        next_states = next_states.sum(dim=0)
        return next_states


def should_convert_module(current_key_name, patterns):
current_key_name_str = ".".join(current_key_name)
if not any(
re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns
):
return True
return False
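
# Illustrative behaviour: with patterns=["lm_head"], should_convert_module(["lm_head"], patterns)
# returns False (skip conversion), while should_convert_module(["model", "layers", "0", "mlp", "experts"],
# patterns) returns True, so that experts module would be converted.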


class QEffMxfp4Config(Mxfp4Config):
"""
    Currently there is no need to change the implementation of Mxfp4Config.
    This is a placeholder for when we want to customize it in the future.
"""

pass


class QEffMxfp4HfQuantizer(Mxfp4HfQuantizer):
def validate_environment(self, *args, **kwargs):
return True

def update_torch_dtype(self, torch_dtype):
if torch_dtype not in [None, torch.float32]:
logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None")
return None

def _process_model_before_weight_loading(
self,
model: torch.nn.Module,
keep_in_fp32_modules: Optional[list[str]] = None,
**kwargs,
):
self.modules_to_not_convert = get_keys_to_not_convert(model)
self.modules_to_not_convert = (
["lm_head"] if self.modules_to_not_convert is None else self.modules_to_not_convert
)
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
self.modules_to_not_convert = list(set(self.modules_to_not_convert))
config = model.config

        # -- Defined as a local function since it uses several variables from the enclosing scope --
def _replace_with_mxfp4_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
if current_key_name is None:
current_key_name = []

for name, module in model.named_children():
current_key_name.append(name)
if not should_convert_module(current_key_name, modules_to_not_convert):
current_key_name.pop(-1)
continue
if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
model._modules[name] = QEffMxfp4GptOssExperts(config)
has_been_replaced = True
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_mxfp4_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced

_replace_with_mxfp4_linear(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
68 changes: 68 additions & 0 deletions QEfficient/transformers/quantizers/quantizer_utils.py
@@ -378,3 +378,71 @@ def repack_zeros(qzeros, bits):
break
qzeros = qzeros.T
return qzeros


FP4_VALUES = [
+0.0,
+0.5,
+1.0,
+1.5,
+2.0,
+3.0,
+4.0,
+6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]


def convert_moe_packed_tensors(
blocks,
scales,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 32768 * 1024,
) -> torch.Tensor:
"""
    Adapted from: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt_oss#L98
"""
import math

scales = scales.to(torch.int32) - 127

assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"

lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G

blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)

out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

for r0 in range(0, rows_total, rows_per_chunk):
r1 = min(r0 + rows_per_chunk, rows_total)

blk = blocks[r0:r1]
exp = scales[r0:r1]

# nibble indices -> int64
idx_lo = (blk & 0x0F).to(torch.long)
idx_hi = (blk >> 4).to(torch.long)

sub = out[r0:r1]
sub[:, 0::2] = lut[idx_lo]
sub[:, 1::2] = lut[idx_hi]

torch.ldexp(sub, exp, out=sub)
del idx_lo, idx_hi, blk, exp

out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
out = out.to(dtype).permute(0, 2, 1).contiguous()
return out
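
A toy decode to illustrate the packing (shapes and values are illustrative, not from a real checkpoint): each uint8 packs two FP4 nibbles that index FP4_VALUES, and each per-group scale byte is a biased power-of-two exponent (bias 127).

import torch

# One row, one group of 16 packed bytes -> 32 dequantized values.
blocks = torch.full((1, 1, 1, 16), 0x21, dtype=torch.uint8)  # low nibble 0x1 -> 0.5, high nibble 0x2 -> 1.0
scales = torch.full((1, 1, 1), 128, dtype=torch.uint8)       # 128 - 127 = 1, i.e. the group is scaled by 2**1
deq = convert_moe_packed_tensors(blocks, scales, dtype=torch.float32)
# deq has shape (1, 32, 1) after the final permute and alternates 1.0, 2.0, 1.0, 2.0, ...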