
Commit 6a2c1a0

Awq feature (#100)

* added preprocess layer before loading quantized awq weights
* added onnx export
* added ScaledActivation class
* refactoring the code to right places and added one single test for now
* cleaned code
* added proper tests, added decorator for updating quantizers, cleaned code
* fixed CLI
* added auto file for decorator

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

1 parent 0ef6829 · commit 6a2c1a0

File tree

14 files changed: +737 -23 lines changed

QEfficient/base/common.py

Lines changed: 1 addition & 9 deletions

@@ -56,15 +56,7 @@ def get_hf_model_type(hf_model_path: str) -> QEFF_MODEL_TYPE:
     )
 
     if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
-        # FIXME: Add logic to handle if quantization config is stored in separate quant_config.json outside of config, also create a separate function for this and below lines
-        quant_config = getattr(config, "quantization_config", getattr(config, "quant_config", None))
-        if quant_config is not None:
-            if quant_config.get("quant_method", None) == "awq":
-                return QEFF_MODEL_TYPE.AWQ
-            else:
-                raise NotImplementedError(f"current model type is not yet supported {type(config)}")
-        else:
-            return QEFF_MODEL_TYPE.CAUSALLM
+        return QEFF_MODEL_TYPE.CAUSALLM
     else:
         raise NotImplementedError(f"model type {type(config)} is not yet supported")
 

QEfficient/base/pytorch_transforms.py

Lines changed: 32 additions & 0 deletions

@@ -55,3 +55,35 @@ def register(cls, from_module: Type[nn.Module], to_module: Type[nn.Module]):
         FlashAttention.register(LLamaAttention, LlamaFlashAttention)
         """
         cls._module_mapping[from_module] = to_module
+
+
+class ModuleMutatorTransform(PytorchTransform):
+    """Serves as base class for any transform that mutates pytorch module in any way.
+    Mutate here mean, we initialize a new pytorch module object using info from original module and
+    replace original module with new module.
+
+    Raises:
+        NotImplementedError: Not supposed to use directly, Create a subclass and implement mutate method and assign a valid nn.Module class to _match_class variable.
+    """
+
+    _match_class: nn.Module
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        transformed = False
+        for name, module in model.named_children():
+            if isinstance(module, cls._match_class):
+                setattr(model, name, cls.mutate(module, model))
+                transformed = True
+            else:
+                cls.apply(module)
+
+        if isinstance(model, cls._match_class):
+            model = cls.mutate(model, None)
+            transformed = True
+
+        return model, transformed
+
+    @classmethod
+    def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
+        raise NotImplementedError("Please implement your own method by inheriting this class")
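For orientation, a minimal sketch of how this base class is intended to be subclassed. The SiLUToReLUTransform below is a hypothetical toy, not part of this commit; only ModuleMutatorTransform and its module path come from the diff above:

from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMutatorTransform


class SiLUToReLUTransform(ModuleMutatorTransform):
    # Toy transform: swap every nn.SiLU child for nn.ReLU.
    _match_class = nn.SiLU

    @classmethod
    def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
        return nn.ReLU()


model = nn.Sequential(nn.Linear(4, 4), nn.SiLU())
model, transformed = SiLUToReLUTransform.apply(model)
assert transformed and isinstance(model[1], nn.ReLU)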

QEfficient/customop/matmulnbits.py

Lines changed: 186 additions & 0 deletions (new file)

@@ -0,0 +1,186 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import math

import torch
from torch import nn


class QuantLinearTorchFunction(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features):
        input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
        input_tuple += (g_idx,) if g_idx is not None else ()
        return g.op(
            "com.microsoft::MatMulNBits",
            *input_tuple,
            outputs=1,
            K_i=in_features,
            N_i=out_features,
            bits_i=bits,
            block_size_i=groupsize,
        )

    @staticmethod
    def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features):
        if torch.onnx.is_in_onnx_export():
            return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype, device=x.device).float()
        fp_weight = dequantize_blockwise_bits(
            qself_qweight, qself_scales, qself_qzeros, bits, groupsize, g_idx, in_features, out_features
        )[0].float()

        return torch.matmul(x.float(), fp_weight.T.float())


def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, groupsize, g_idx, rows, cols):
    if bits != 4:
        raise ValueError("Only bits=4 is supported for executing quantized model")
    if groupsize != 128:
        raise ValueError("Only groupsize=128 is supported for executing quantized model")
    expand_quant_value = (
        quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device)
    ) & 0x0F
    expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
    aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
    if zero_point.dtype == scale.dtype:
        expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1)
    else:
        expand_zero_point = (
            zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device)
        ) & 0x0F
        try:
            expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1)
        # FIXME: remove try-except
        except RuntimeError:
            expand_zero_point = expand_zero_point.reshape(quant_values.shape[0], -1, 1)
            expand_zero_point = expand_zero_point[:, : quant_values.shape[1]]
    if g_idx is not None and g_idx[:32].sum().item() != 0:
        float_values = (
            (expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0])
            * aligned_scale[:, g_idx, 0]
        ).to(scale.dtype)
    else:
        float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype)
    float_values = float_values.reshape(cols, -1)
    if rows != float_values.shape[-1]:
        float_values = float_values[:, :rows]
        expand_zero_point = expand_zero_point[:, :rows]
    if expand_zero_point.ndim == 3:
        expand_zero_point = expand_zero_point.squeeze(-1)
    if aligned_scale.ndim == 3:
        aligned_scale = aligned_scale.squeeze(-1)

    return float_values, expand_zero_point, aligned_scale


class QuantLinearORT(nn.Module):
    def __init__(self, bits, groupsize, in_features, out_features, bias):
        super().__init__()
        if bits not in [2, 3, 4, 5, 6, 7, 8]:
            raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.")
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.groupsize = groupsize if groupsize != -1 else in_features
        self.act_order = None

        q_rows = in_features // self.groupsize
        self.register_buffer(
            "qweight",
            torch.zeros((out_features, q_rows, self.groupsize // (8 // bits)), dtype=torch.uint8),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
        )
        self.register_buffer(
            "scales", torch.zeros((math.ceil(in_features / self.groupsize) * out_features), dtype=torch.float16)
        )
        self.register_buffer(
            "g_idx", torch.tensor([i // self.groupsize for i in range(in_features)], dtype=torch.int32)
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
        else:
            self.bias = None

    def quant_weight(self, weight, scales, zeros, g_idx):
        scale_zeros = zeros * scales
        scale_mat = scales[g_idx]
        scale_zeros_mat = scale_zeros[g_idx]
        int_weight_T = torch.round(((weight + scale_zeros_mat) / scale_mat).float()).to(torch.int)
        return int_weight_T

    def pack_on_device(self, int_weight, int_zeros):
        if self.bits != 4:
            raise ValueError("only 4bit is supported by ONNXRUNTIME for now.")

        # Order of groups
        self.act_order = self.g_idx[: self.groupsize // self.bits].sum().item() != 0

        intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte()
        scales_pt = self.scales.T.to(int_weight.device)
        intweight_pt = int_weight.byte()

        block_size = self.groupsize
        rows, cols = intweight_pt.shape
        blob_size = block_size // 2
        k_blocks = (rows + block_size - 1) // block_size
        padded_rows = k_blocks * block_size
        pad_len = padded_rows - rows
        if pad_len > 0:
            intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0)
        intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0)

        # Pack zeros if they are not float
        if int_zeros.dtype != self.scales.dtype:
            intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4)
            intzeros_pt = intzeros_pt.reshape(-1)

        # Pack weights
        intweight_pt_T = int_weight.T
        intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4)
        intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size)

        scales_pt = scales_pt.reshape(-1)

        # Validation checks
        if (self.qweight.shape != intweight_pt_T.shape) and (
            self.qzeros.shape == intzeros_pt.shape or self.qzeros.dtype != intzeros_pt.dtype
        ):
            raise RuntimeError("Something went wrong while packing the weights in QuantLinearORT module")

        # Assign buffers
        self.scales = scales_pt.float()
        self.qweight = intweight_pt_T.byte()  # Convert to uint8
        if int_zeros.dtype != self.scales.dtype:
            self.qzeros = intzeros_pt.byte()  # Convert to uint8
        else:
            self.qzeros = intzeros_pt

    def pack(self, linear, scales, zeros, g_idx=None):
        layer_weight = linear.weight.data
        self.scales = scales.T
        self.g_idx = g_idx.clone()
        int_weight = self.quant_weight(layer_weight.T, scales.T, zeros.T, g_idx)
        return self.pack_on_device(int_weight, zeros.T)

    def forward(self, inputs):
        out = QuantLinearTorchFunction().apply(
            inputs,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx if self.act_order else None,
            self.bits,
            self.groupsize,
            self.in_features,
            self.out_features,
        )
        out = out + self.bias if self.bias is not None else out
        return out
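
To make the packing and dequantization contract concrete, here is a hedged round-trip sketch. Only QuantLinearORT and its module path come from the diff; the shapes, the per-group scale/zero-point choice, and every variable name below are illustrative assumptions, and the example stays within the bits=4 / groupsize=128 limits enforced by dequantize_blockwise_bits:

import torch
from torch import nn

from QEfficient.customop.matmulnbits import QuantLinearORT

torch.manual_seed(0)
in_features, out_features, groupsize = 256, 8, 128
n_groups = in_features // groupsize

linear = nn.Linear(in_features, out_features, bias=False)

# Per-(output-channel, group) scales with a fixed zero-point of 8 (middle of the uint4 range).
w = linear.weight.data.reshape(out_features, n_groups, groupsize)
scales = w.abs().amax(dim=-1) / 7.0                        # (out_features, n_groups)
zeros = torch.full_like(scales, 8, dtype=torch.int32)      # (out_features, n_groups)
g_idx = torch.tensor([i // groupsize for i in range(in_features)], dtype=torch.int32)

qlinear = QuantLinearORT(bits=4, groupsize=groupsize,
                         in_features=in_features, out_features=out_features, bias=False)
qlinear.pack(linear, scales, zeros, g_idx)

x = torch.randn(2, in_features)
max_err = (qlinear(x) - linear(x)).abs().max().item()
print(f"max abs error vs. fp32 Linear: {max_err:.4f}")     # small quantization error expected

At ONNX-export time the same module traces to a single com.microsoft::MatMulNBits node via QuantLinearTorchFunction.symbolic instead of running the Python dequantization path.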

QEfficient/transformers/models/modeling_auto.py

Lines changed: 19 additions & 0 deletions

@@ -14,6 +14,9 @@
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel, Runtime
 from QEfficient.transformers.pytorch_transforms import CBTransform, CustomOpsTransform, KVCacheTransform
+from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
+from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform
+from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig
 from QEfficient.utils import get_qpc_dir_path, load_hf_tokenizer
 from QEfficient.utils.logging_utils import logger

@@ -30,6 +33,11 @@ class QEFFTransformersBase(QEFFBaseModel):
     """
 
     def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwargs) -> None:
+        if hasattr(model.config, "quantization_config") and not isinstance(
+            model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values())
+        ):
+            raise AssertionError("Please use `from_pretrained` method to load quantized models")
+
         super().__init__(model)
         self.model.config.use_cache = (
             True  # Always pass use_cache = True, to get KV values as output during ONNX export

@@ -53,6 +61,7 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}\n" + self.model.__repr__()
 
     @classmethod
+    @with_replaced_quantizers
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
         """
         This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.

@@ -92,6 +101,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
             logger.warning(f"Updating attn_implementation to be 'eager', got {attn_implementation}")
             kwargs.update({"attn_implementation": "eager"})
 
+        if low_cpu_mem_usage := kwargs.get("low_cpu_mem_usage", None):
+            logger.warning(f"Updating low_cpu_mem_usage to be 'False', got {low_cpu_mem_usage}")
+            kwargs.update({"low_cpu_mem_usage": False})
+
         model = QEFFAutoModelToTransformersAutoModelMap[cls.__name__].from_pretrained(
             pretrained_model_name_or_path, *args, **kwargs
         )

@@ -148,9 +161,15 @@ def transform(self, **kwargs):
         """
         if self.is_transformed:
             return
+
         if kwargs.get("full_batch_size", None):
             self._pytorch_transforms.remove(KVCacheTransform)
             self._pytorch_transforms.append(CBTransform)
+
+        # Update list of pytorch transforms if the model falls in AWQ/GPTQ category
+        if isinstance(self.model.config.quantization_config, QEffAwqConfig):
+            self._pytorch_transforms.insert(0, AwqToMatmulNbitsTransform)
+
         for transform in self._pytorch_transforms:
             transform.apply(self.model)
         self.is_transformed = True
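
End to end, the intended usage is roughly the sketch below. It is hedged: QEFFAutoModelForCausalLM is assumed to be the existing QEfficient wrapper built on this base class (it is not shown in this diff), and the AWQ model card is a placeholder:

from QEfficient import QEFFAutoModelForCausalLM

# from_pretrained now runs under with_replaced_quantizers, so an AWQ checkpoint is parsed
# with QEffAwqConfig/QEffAwqQuantizer instead of the stock transformers quantizer.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-AWQ")

# transform() prepends AwqToMatmulNbitsTransform, converting AWQ layers to
# QuantLinearORT (MatMulNBits) before the remaining pytorch transforms run.
qeff_model.transform()
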

Lines changed: 6 additions & 0 deletions (new file)

@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

Lines changed: 41 additions & 0 deletions (new file)

@@ -0,0 +1,41 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer

QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer}

QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig}


def with_replaced_quantizers(func):
    def wrapper(*args, **kwargs):
        transformers_replaced_quantization_config_mapping = dict()
        transformers_replaced_quantizer_mapping = dict()

        for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
            # Replace quantization config
            transformers_replaced_quantization_config_mapping[k] = AUTO_QUANTIZATION_CONFIG_MAPPING[k]
            AUTO_QUANTIZATION_CONFIG_MAPPING[k] = QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING[k]

            # Replace quantizer
            transformers_replaced_quantizer_mapping[k] = AUTO_QUANTIZER_MAPPING[k]
            AUTO_QUANTIZER_MAPPING[k] = QEFF_AUTO_QUANTIZER_MAPPING[k]

        # Call the function for loading quantized models here
        out = func(*args, **kwargs)

        # Put back quantization config and quantizer
        for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
            AUTO_QUANTIZATION_CONFIG_MAPPING[k] = transformers_replaced_quantization_config_mapping[k]
            AUTO_QUANTIZER_MAPPING[k] = transformers_replaced_quantizer_mapping[k]

        return out

    return wrapper
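
A brief sketch of what the decorator buys a caller. The load_awq_checkpoint helper below is hypothetical; the decorator, the mappings, and the import path come from this commit:

from transformers import AutoModelForCausalLM

from QEfficient.transformers.quantizers.auto import with_replaced_quantizers


@with_replaced_quantizers
def load_awq_checkpoint(model_card: str):
    # While this call runs, AUTO_QUANTIZER_MAPPING["awq"] and AUTO_QUANTIZATION_CONFIG_MAPPING["awq"]
    # resolve to QEffAwqQuantizer and QEffAwqConfig; the transformers defaults are restored on return.
    return AutoModelForCausalLM.from_pretrained(model_card)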
