10 changes: 1 addition & 9 deletions QEfficient/base/common.py
@@ -56,15 +56,7 @@ def get_hf_model_type(hf_model_path: str) -> QEFF_MODEL_TYPE:
)

if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
# FIXME: Add logic to handle if quantization config is stored in separate quant_config.json outside of config, also create a separate function for this and below lines
quant_config = getattr(config, "quantization_config", getattr(config, "quant_config", None))
if quant_config is not None:
if quant_config.get("quant_method", None) == "awq":
return QEFF_MODEL_TYPE.AWQ
else:
raise NotImplementedError(f"current model type is not yet supported {type(config)}")
else:
return QEFF_MODEL_TYPE.CAUSALLM
return QEFF_MODEL_TYPE.CAUSALLM
else:
raise NotImplementedError(f"model type {type(config)} is not yet supported")

32 changes: 32 additions & 0 deletions QEfficient/base/pytorch_transforms.py
@@ -53,3 +53,35 @@ def register(cls, from_module: Type[nn.Module], to_module: Type[nn.Module]):
FlashAttention.register(LLamaAttention, LlamaFlashAttention)
"""
cls._module_mapping[from_module] = to_module


class ModuleMutatorTransform(PytorchTransform):
"""Serves as base class for any transform that mutates pytorch module in any way.
Mutate here mean, we initialize a new pytorch module object using info from original module and
replace original module with new module.

Raises:
NotImplementedError: Not supposed to use directly, Create a subclass and implement mutate method and assign a valid nn.Module class to _match_class variable.
"""

_match_class: nn.Module

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
for name, module in model.named_children():
if isinstance(module, cls._match_class):
setattr(model, name, cls.mutate(module, model))
transformed = True
            else:
                # Recurse into children and propagate whether any nested module was mutated
                _, nested_transformed = cls.apply(module)
                transformed = transformed or nested_transformed

if isinstance(model, cls._match_class):
model = cls.mutate(model, None)
transformed = True

return model, transformed

@classmethod
def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
raise NotImplementedError("Please implement your own method by inheriting this class")
186 changes: 186 additions & 0 deletions QEfficient/customop/matmulnbits.py
@@ -0,0 +1,186 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import math

import torch
from torch import nn


class QuantLinearTorchFunction(torch.autograd.Function):
@staticmethod
def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features):
input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
input_tuple += (g_idx,) if g_idx is not None else ()
return g.op(
"com.microsoft::MatMulNBits",
*input_tuple,
outputs=1,
K_i=in_features,
N_i=out_features,
bits_i=bits,
block_size_i=groupsize,
)

@staticmethod
def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features):
if torch.onnx.is_in_onnx_export():
return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype, device=x.device).float()
fp_weight = dequantize_blockwise_bits(
qself_qweight, qself_scales, qself_qzeros, bits, groupsize, g_idx, in_features, out_features
)[0].float()

return torch.matmul(x.float(), fp_weight.T.float())


def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, groupsize, g_idx, rows, cols):
if bits != 4:
raise ValueError("Only bits=4 is supported for executing quantized model")
if groupsize != 128:
raise ValueError("Only groupsize=128 is supported for executing quantized model")
expand_quant_value = (
quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device)
) & 0x0F
expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
if zero_point.dtype == scale.dtype:
expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1)
else:
expand_zero_point = (
zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device)
) & 0x0F
try:
expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1)
# FIXME: remove try-except
except RuntimeError:
expand_zero_point = expand_zero_point.reshape(quant_values.shape[0], -1, 1)
expand_zero_point = expand_zero_point[:, : quant_values.shape[1]]
if g_idx is not None and g_idx[:32].sum().item() != 0:
float_values = (
(expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0])
* aligned_scale[:, g_idx, 0]
).to(scale.dtype)
else:
float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype)
float_values = float_values.reshape(cols, -1)
if rows != float_values.shape[-1]:
float_values = float_values[:, :rows]
expand_zero_point = expand_zero_point[:, :rows]
if expand_zero_point.ndim == 3:
expand_zero_point = expand_zero_point.squeeze(-1)
if aligned_scale.ndim == 3:
aligned_scale = aligned_scale.squeeze(-1)

return float_values, expand_zero_point, aligned_scale
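The core bit manipulation in `dequantize_blockwise_bits` is nibble unpacking: each uint8 stores two 4-bit values, recovered by shifting by 0 and 4 and masking with 0x0F. A standalone sketch with made-up values:

```python
import torch

packed = torch.tensor([[0x21, 0x43]], dtype=torch.int32)  # nibbles 1,2 packed in 0x21; 3,4 in 0x43
shifts = torch.tensor([0, 4], dtype=torch.int32)
unpacked = (packed.unsqueeze(-1) >> shifts) & 0x0F         # low nibble, then high nibble
print(unpacked.reshape(packed.shape[0], -1))               # tensor([[1, 2, 3, 4]])
```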


class QuantLinearORT(nn.Module):
def __init__(self, bits, groupsize, in_features, out_features, bias):
super().__init__()
if bits not in [2, 3, 4, 5, 6, 7, 8]:
raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.")
self.in_features = in_features
self.out_features = out_features
self.bits = bits
self.groupsize = groupsize if groupsize != -1 else in_features
self.act_order = None

q_rows = in_features // self.groupsize
self.register_buffer(
"qweight",
torch.zeros((out_features, q_rows, self.groupsize // (8 // bits)), dtype=torch.uint8),
)
self.register_buffer(
"qzeros",
torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
)
self.register_buffer(
"scales", torch.zeros((math.ceil(in_features / self.groupsize) * out_features), dtype=torch.float16)
)
self.register_buffer(
"g_idx", torch.tensor([i // self.groupsize for i in range(in_features)], dtype=torch.int32)
)
if bias:
self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
else:
self.bias = None

def quant_weight(self, weight, scales, zeros, g_idx):
scale_zeros = zeros * scales
scale_mat = scales[g_idx]
scale_zeros_mat = scale_zeros[g_idx]
int_weight_T = torch.round(((weight + scale_zeros_mat) / scale_mat).float()).to(torch.int)
return int_weight_T
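Note that `quant_weight` computes `round((w + z*s) / s)`, which is the usual affine quantization `q = round(w/s + z)`. A tiny round-trip check with made-up numbers:

```python
import torch

w = torch.tensor([0.30, -0.12, 0.05])        # illustrative weights
s, z = 0.05, 8                                # assumed per-group scale and integer zero point
q = torch.round(w / s + z).to(torch.int)      # tensor([14, 6, 9])
w_hat = (q - z) * s                           # approx tensor([0.3000, -0.1000, 0.0500])
```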

def pack_on_device(self, int_weight, int_zeros):
if self.bits != 4:
raise ValueError("only 4bit is supported by ONNXRUNTIME for now.")

# Order of groups
self.act_order = self.g_idx[: self.groupsize // self.bits].sum().item() != 0

intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte()
scales_pt = self.scales.T.to(int_weight.device)
intweight_pt = int_weight.byte()

block_size = self.groupsize
rows, cols = intweight_pt.shape
blob_size = block_size // 2
k_blocks = (rows + block_size - 1) // block_size
padded_rows = k_blocks * block_size
pad_len = padded_rows - rows
if pad_len > 0:
intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0)
intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0)

# Pack zeros if they are not float
if int_zeros.dtype != self.scales.dtype:
intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4)
intzeros_pt = intzeros_pt.reshape(-1)

# Pack weights
intweight_pt_T = int_weight.T
intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4)
intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size)

scales_pt = scales_pt.reshape(-1)

# Validation checks
if (self.qweight.shape != intweight_pt_T.shape) and (
self.qzeros.shape == intzeros_pt.shape or self.qzeros.dtype != intzeros_pt.dtype
):
raise RuntimeError("Something went wrong while packing the weights in QuantLinearORT module")

# Assign buffers
self.scales = scales_pt.float()
self.qweight = intweight_pt_T.byte() # Convert to uint8
if int_zeros.dtype != self.scales.dtype:
self.qzeros = intzeros_pt.byte() # Convert to uint8
else:
self.qzeros = intzeros_pt

def pack(self, linear, scales, zeros, g_idx=None):
layer_weight = linear.weight.data
self.scales = scales.T
self.g_idx = g_idx.clone()
int_weight = self.quant_weight(layer_weight.T, scales.T, zeros.T, g_idx)
return self.pack_on_device(int_weight, zeros.T)

def forward(self, inputs):
out = QuantLinearTorchFunction().apply(
inputs,
self.qweight,
self.scales,
self.qzeros,
self.g_idx if self.act_order else None,
self.bits,
self.groupsize,
self.in_features,
self.out_features,
)
out = out + self.bias if self.bias is not None else out
return out
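For context, a hedged usage sketch (not from this PR) of packing a `QuantLinearORT` from a float `nn.Linear`. The per-group scale and zero-point choices below are simple placeholders, not AWQ calibration; only the import path and class API come from the file above.

```python
import torch
from torch import nn

from QEfficient.customop.matmulnbits import QuantLinearORT

in_features, out_features, groupsize = 256, 128, 128
linear = nn.Linear(in_features, out_features, bias=False)

# Placeholder per-group quantization params, shaped (out_features, n_groups); pack() transposes them internally.
n_groups = in_features // groupsize
scales = linear.weight.data.abs().reshape(out_features, n_groups, groupsize).amax(-1) / 7.0
zeros = torch.full((out_features, n_groups), 8, dtype=torch.int32)
g_idx = torch.tensor([i // groupsize for i in range(in_features)], dtype=torch.int32)

qlinear = QuantLinearORT(bits=4, groupsize=groupsize, in_features=in_features,
                         out_features=out_features, bias=False)
qlinear.pack(linear, scales, zeros, g_idx)   # quantizes, packs nibbles, fills qweight/qzeros/scales buffers
out = qlinear(torch.randn(2, in_features))   # eager path: dequantize_blockwise_bits + matmul, shape (2, 128)
```

Exporting such a module with `torch.onnx.export` would instead hit the `symbolic()` path above and emit a `com.microsoft::MatMulNBits` node.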
19 changes: 19 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -14,6 +14,9 @@
import QEfficient
from QEfficient.base.modeling_qeff import QEFFBaseModel, Runtime
from QEfficient.transformers.pytorch_transforms import CBTransform, CustomOpsTransform, KVCacheTransform
from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform
from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig
from QEfficient.utils import get_qpc_dir_path, load_hf_tokenizer
from QEfficient.utils.logging_utils import logger

@@ -30,6 +33,11 @@ class QEFFTransformersBase(QEFFBaseModel):
"""

def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwargs) -> None:
if hasattr(model.config, "quantization_config") and not isinstance(
model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values())
):
raise AssertionError("Please use `from_pretrained` method to load quantized models")

super().__init__(model)
self.model.config.use_cache = (
True # Always pass use_cache = True, to get KV values as output during ONNX export
@@ -53,6 +61,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}\n" + self.model.__repr__()

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
"""
This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
@@ -92,6 +101,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
logger.warning(f"Updating attn_implementation to be 'eager', got {attn_implementation}")
kwargs.update({"attn_implementation": "eager"})

if low_cpu_mem_usage := kwargs.get("low_cpu_mem_usage", None):
logger.warning(f"Updating low_cpu_mem_usage to be 'False', got {low_cpu_mem_usage}")
kwargs.update({"low_cpu_mem_usage": False})

model = QEFFAutoModelToTransformersAutoModelMap[cls.__name__].from_pretrained(
pretrained_model_name_or_path, *args, **kwargs
)
@@ -148,9 +161,15 @@ def transform(self, **kwargs):
"""
if self.is_transformed:
return

if kwargs.get("full_batch_size", None):
self._pytorch_transforms.remove(KVCacheTransform)
self._pytorch_transforms.append(CBTransform)

# Update list of pytorch transforms if the model falls in AWQ/GPTQ category
if isinstance(self.model.config.quantization_config, QEffAwqConfig):
self._pytorch_transforms.insert(0, AwqToMatmulNbitsTransform)

for transform in self._pytorch_transforms:
transform.apply(self.model)
self.is_transformed = True
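Putting the pieces together, a hedged end-to-end sketch of the intended flow. The `QEFFAutoModelForCausalLM` entry point and the checkpoint id are assumptions not shown in this diff; what the diff does establish is that `from_pretrained` runs under `with_replaced_quantizers` (so the AWQ config loads as `QEffAwqConfig`) and that `transform()` then prepends `AwqToMatmulNbitsTransform`.

```python
# Hedged sketch; the class name QEFFAutoModelForCausalLM and the model id are assumed for illustration.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-AWQ")  # hypothetical AWQ checkpoint
model.transform()  # inserts AwqToMatmulNbitsTransform ahead of the KV-cache/custom-op transforms
```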
6 changes: 6 additions & 0 deletions QEfficient/transformers/quantizers/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
41 changes: 41 additions & 0 deletions QEfficient/transformers/quantizers/auto.py
@@ -0,0 +1,41 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer

QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer}

QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig}


def with_replaced_quantizers(func):
def wrapper(*args, **kwargs):
transformers_replaced_quantization_config_mapping = dict()
transformers_replaced_quantizer_mapping = dict()

for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
# Replace quantization config
transformers_replaced_quantization_config_mapping[k] = AUTO_QUANTIZATION_CONFIG_MAPPING[k]
AUTO_QUANTIZATION_CONFIG_MAPPING[k] = QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING[k]

# Replace quantizer
transformers_replaced_quantizer_mapping[k] = AUTO_QUANTIZER_MAPPING[k]
AUTO_QUANTIZER_MAPPING[k] = QEFF_AUTO_QUANTIZER_MAPPING[k]

# Call the function for loading quantized models here
out = func(*args, **kwargs)

# Put back quantization config and quantizer
for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
AUTO_QUANTIZATION_CONFIG_MAPPING[k] = transformers_replaced_quantization_config_mapping[k]
AUTO_QUANTIZER_MAPPING[k] = transformers_replaced_quantizer_mapping[k]

return out

return wrapper
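A minimal usage sketch (not from this PR) of the decorator: while the wrapped call runs, the "awq" entries in transformers' AUTO_QUANTIZATION_CONFIG_MAPPING and AUTO_QUANTIZER_MAPPING point to the QEff classes, and the originals are restored before returning. The checkpoint id is a placeholder.

```python
from transformers import AutoModelForCausalLM

from QEfficient.transformers.quantizers.auto import with_replaced_quantizers


@with_replaced_quantizers
def load_awq_checkpoint(model_id: str):
    # Inside this call, "awq" resolves to QEffAwqQuantizer / QEffAwqConfig instead of the transformers defaults.
    return AutoModelForCausalLM.from_pretrained(model_id)


# model = load_awq_checkpoint("TheBloke/Llama-2-7B-AWQ")  # hypothetical AWQ checkpoint id
```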