
Commit 6dba3db

quic-amitraj authored and ochougul committed
Adding support for GPTQ models (#103)
* Adding support for gptq models Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* Code cleaning and formating Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* ruff format and fixed some bug Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* Added tests for gptq Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* Bug-fix-1 Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* fixed bugs-2 Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* fixed bug-3 Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* Added docstring Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* Addressed comments Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* Addressed comments Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* fixed bugs-3 Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* ruff check and format Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
* Addressed comments-3 Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

---------

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>
1 parent e038e3d commit 6dba3db

File tree

13 files changed, +740 -224 lines changed


QEfficient/base/common.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ class QEFF_MODEL_TYPE(Enum):
 
     CAUSALLM = "LLM"
     DIFFUSION = "DIFFUSION"
-    AWQ = "AWQ"
 
 
 MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP: Dict[QEFF_MODEL_TYPE, Type[QEFFBaseModel]] = {

QEfficient/customop/matmulnbits.py

Lines changed: 19 additions & 23 deletions
@@ -13,7 +13,7 @@
 
 class QuantLinearTorchFunction(torch.autograd.Function):
     @staticmethod
-    def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features):
+    def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
         input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
         input_tuple += (g_idx,) if g_idx is not None else ()
         return g.op(
@@ -23,36 +23,32 @@ def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group
             K_i=in_features,
             N_i=out_features,
             bits_i=bits,
-            block_size_i=groupsize,
+            block_size_i=group_size,
         )
 
     @staticmethod
-    def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features):
+    def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
         if torch.onnx.is_in_onnx_export():
-            return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype, device=x.device).float()
+            return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float()
         fp_weight = dequantize_blockwise_bits(
-            qself_qweight, qself_scales, qself_qzeros, bits, groupsize, g_idx, in_features, out_features
+            qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features
         )[0].float()
 
         return torch.matmul(x.float(), fp_weight.T.float())
 
 
-def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, groupsize, g_idx, rows, cols):
+def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols):
     if bits != 4:
         raise ValueError("Only bits=4 is supported for executing quantized model")
-    if groupsize != 128:
-        raise ValueError("Only groupsize=128 is supported for executing quantized model")
-    expand_quant_value = (
-        quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device)
-    ) & 0x0F
+    if group_size != 128:
+        raise ValueError("Only group_size=128 is supported for executing quantized model")
+    expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
     expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
     aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
     if zero_point.dtype == scale.dtype:
         expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1)
     else:
-        expand_zero_point = (
-            zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device)
-        ) & 0x0F
+        expand_zero_point = (zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
     try:
         expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1)
     # FIXME: remove try-except
@@ -79,30 +75,30 @@ def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, groupsize,
 
 
 class QuantLinearORT(nn.Module):
-    def __init__(self, bits, groupsize, in_features, out_features, bias):
+    def __init__(self, bits, group_size, in_features, out_features, bias):
         super().__init__()
         if bits not in [2, 3, 4, 5, 6, 7, 8]:
             raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.")
         self.in_features = in_features
         self.out_features = out_features
         self.bits = bits
-        self.groupsize = groupsize if groupsize != -1 else in_features
+        self.group_size = group_size if group_size != -1 else in_features
         self.act_order = None
 
-        q_rows = in_features // self.groupsize
+        q_rows = in_features // self.group_size
         self.register_buffer(
             "qweight",
-            torch.zeros((out_features, q_rows, self.groupsize // (8 // bits)), dtype=torch.uint8),
+            torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
         )
         self.register_buffer(
             "qzeros",
             torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
         )
         self.register_buffer(
-            "scales", torch.zeros((math.ceil(in_features / self.groupsize) * out_features), dtype=torch.float16)
+            "scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
         )
         self.register_buffer(
-            "g_idx", torch.tensor([i // self.groupsize for i in range(in_features)], dtype=torch.int32)
+            "g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32)
         )
         if bias:
             self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
@@ -121,13 +117,13 @@ def pack_on_device(self, int_weight, int_zeros):
             raise ValueError("only 4bit is supported by ONNXRUNTIME for now.")
 
         # Order of groups
-        self.act_order = self.g_idx[: self.groupsize // self.bits].sum().item() != 0
+        self.act_order = self.g_idx[: self.group_size // self.bits].sum().item() != 0
 
         intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte()
         scales_pt = self.scales.T.to(int_weight.device)
         intweight_pt = int_weight.byte()
 
-        block_size = self.groupsize
+        block_size = self.group_size
         rows, cols = intweight_pt.shape
         blob_size = block_size // 2
         k_blocks = (rows + block_size - 1) // block_size
@@ -178,7 +174,7 @@ def forward(self, inputs):
             self.qzeros,
             self.g_idx if self.act_order else None,
             self.bits,
-            self.groupsize,
+            self.group_size,
             self.in_features,
             self.out_features,
         )
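Note: the dequantize_blockwise_bits path above relies on each uint8 of qweight holding two 4-bit values, which the shift tensor [[[[0, 4]]]] and the & 0x0F mask split apart. A minimal, self-contained sketch of that unpack step (illustrative only, not part of the diff):

import torch

# Pack two 4-bit values (5 and 12) into one byte: low nibble first.
packed = torch.tensor([5 | (12 << 4)], dtype=torch.int32)

# Shifting by [0, 4] and masking with 0x0F recovers both nibbles, mirroring
# the expand_quant_value computation in dequantize_blockwise_bits.
shifts = torch.tensor([[0, 4]], dtype=torch.int32)
unpacked = (packed.unsqueeze(-1) >> shifts) & 0x0F
print(unpacked.tolist())  # [[5, 12]]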

QEfficient/transformers/models/modeling_auto.py

Lines changed: 8 additions & 5 deletions
@@ -15,8 +15,9 @@
 from QEfficient.base.modeling_qeff import QEFFBaseModel, Runtime
 from QEfficient.transformers.pytorch_transforms import CBTransform, CustomOpsTransform, KVCacheTransform
 from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
-from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform
+from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
 from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig
+from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig
 from QEfficient.utils import get_qpc_dir_path, load_hf_tokenizer
 from QEfficient.utils.logging_utils import logger
 
@@ -167,10 +168,12 @@ def transform(self, **kwargs):
             self._pytorch_transforms.append(CBTransform)
 
         # Update list of pytorch transforms if the model falls in AWQ/GPTQ category
-        if hasattr(self.model.config, "quantization_config") and isinstance(
-            self.model.config.quantization_config, QEffAwqConfig
-        ):
-            self._pytorch_transforms.insert(0, AwqToMatmulNbitsTransform)
+        if hasattr(self.model.config, "quantization_config"):
+            if isinstance(self.model.config.quantization_config, QEffAwqConfig):
+                self._pytorch_transforms.insert(0, AwqToMatmulNbitsTransform)
+
+            if isinstance(self.model.config.quantization_config, QEffGPTQConfig):
+                self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform)
 
         for transform in self._pytorch_transforms:
             transform.apply(self.model)
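With this hunk, a checkpoint whose quantization_config resolves to QEffGPTQConfig gets GPTQToMatmulNbitsTransform prepended to the transform list, mirroring the existing AWQ path. A hypothetical usage sketch; the auto class name, import path, and model card below are assumptions, not part of this diff:

from QEfficient import QEFFAutoModelForCausalLM  # assumed export path

# Loading a 4-bit GPTQ checkpoint: the replaced quantizers route "gptq" to
# QEffGPTQQuantizer, and transform() then inserts GPTQToMatmulNbitsTransform
# ahead of the other PyTorch transforms.
model = QEFFAutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-GPTQ")  # example model card
model.transform()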

QEfficient/transformers/quantizers/auto.py

Lines changed: 3 additions & 3 deletions
@@ -8,10 +8,10 @@
 from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING
 
 from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
+from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer
 
-QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer}
-
-QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig}
+QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer}
+QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig, "gptq": QEffGPTQConfig}
 
 
 def with_replaced_quantizers(func):
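The two mappings above are what with_replaced_quantizers (declared just below) swaps into transformers' own registries while a wrapped call runs. A minimal sketch of that pattern, shown as an illustration rather than the repository's exact implementation:

import functools

from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

from QEfficient.transformers.quantizers.auto import (
    QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING,
    QEFF_AUTO_QUANTIZER_MAPPING,
)


def with_replaced_quantizers_sketch(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        saved = {}
        # Temporarily route "awq"/"gptq" to the QEff quantizer and config classes.
        for key, quantizer_cls in QEFF_AUTO_QUANTIZER_MAPPING.items():
            saved[key] = (AUTO_QUANTIZER_MAPPING[key], AUTO_QUANTIZATION_CONFIG_MAPPING[key])
            AUTO_QUANTIZER_MAPPING[key] = quantizer_cls
            AUTO_QUANTIZATION_CONFIG_MAPPING[key] = QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING[key]
        try:
            return func(*args, **kwargs)
        finally:
            # Restore transformers' original mappings afterwards.
            for key, (quantizer_cls, config_cls) in saved.items():
                AUTO_QUANTIZER_MAPPING[key] = quantizer_cls
                AUTO_QUANTIZATION_CONFIG_MAPPING[key] = config_cls

    return wrapper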

QEfficient/transformers/quantizers/awq.py

Lines changed: 9 additions & 59 deletions
@@ -8,25 +8,27 @@
 import torch
 import torch.nn as nn
 
+from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gemm
+
 
 class WQLinear_GEMM(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias):
+    def __init__(self, bits, group_size, in_features, out_features, bias):
         super().__init__()
 
-        if w_bit != 4:
+        if bits != 4:
             raise NotImplementedError("Only 4-bit are supported for now.")
 
         self.in_features = in_features
         self.out_features = out_features
-        self.w_bit = w_bit
+        self.bits = bits
         self.group_size = group_size if group_size != -1 else in_features
 
         # quick sanity check (make sure alignment)
         if self.in_features % self.group_size != 0:
             raise ValueError(
                 f"in_features should be perfectly divisible by group_size, got in_features = {self.in_features}, group_size = {self.group_size} while initializing WQLinear_GEMM module"
             )
-        if out_features % (32 // self.w_bit) != 0:
+        if out_features % (32 // self.bits) != 0:
             raise ValueError(
                 f"out_features must be perfectly divisible by number of weights packed into int32 value i.e. 8, got out_features={self.out_features}"
             )
@@ -36,14 +38,14 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias):
         self.register_buffer(
             "qweight",
             torch.zeros(
-                (in_features, out_features // (32 // self.w_bit)),
+                (in_features, out_features // (32 // self.bits)),
                 dtype=torch.int32,
             ),
         )
         self.register_buffer(
             "qzeros",
             torch.zeros(
-                (in_features // self.group_size, out_features // (32 // self.w_bit)),
+                (in_features // self.group_size, out_features // (32 // self.bits)),
                 dtype=torch.int32,
             ),
         )
@@ -70,62 +72,10 @@ def forward(self, x):
         with torch.no_grad():
             out_shape = x.shape[:-1] + (self.out_features,)
 
-            out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size)
+            out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size)
             out = torch.matmul(x.float(), out.float())
 
             out = out + self.bias if self.bias is not None else out
             out = out.reshape(out_shape)
 
         return out
-
-
-def unpack_and_reverse_weights_and_zeros(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
-    shifts = torch.arange(0, 32, bits)
-
-    # unpacking weights column-wise
-    int_weights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
-        torch.int8  # smallest dtype available
-    )
-    int_weights = int_weights.view(int_weights.shape[0], -1)
-
-    # unpacking zeros column-wise
-    int_zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
-        torch.int8  # smallest dtype available
-    )
-    int_zeros = int_zeros.view(int_zeros.shape[0], -1)
-
-    reverse_order_tensor = torch.arange(
-        int_weights.shape[-1],
-        dtype=torch.int32,
-    )
-    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
-    reverse_order_tensor = reverse_order_tensor[:, [0, 4, 1, 5, 2, 6, 3, 7]]
-    reverse_order_tensor = reverse_order_tensor.view(-1)
-
-    int_zeros = int_zeros[:, reverse_order_tensor]
-    int_weights = int_weights[:, reverse_order_tensor]
-
-    return int_weights, int_zeros
-
-
-def unpack_awq_weights(qweight, qzeros, scales, bits):
-    int_weight, int_zeros = unpack_and_reverse_weights_and_zeros(qweight, qzeros, bits)
-
-    # overflow checks
-    int_weight = torch.bitwise_and(int_weight, (2**bits) - 1)
-    int_zeros = torch.bitwise_and(int_zeros, (2**bits) - 1)
-
-    return scales, int_weight, int_zeros
-
-
-def dequantize_gemm(qweight, qzeros, scales, bits, group_size):
-    # Unpack the qweight and qzeros tensors
-    scales, int_weight, int_zeros = unpack_awq_weights(qweight, qzeros, scales, bits)
-
-    # fp16 weights
-    scales = scales.repeat_interleave(group_size, dim=0)
-    int_zeros = int_zeros.repeat_interleave(group_size, dim=0)
-
-    int_weight = (int_weight - int_zeros) * scales
-
-    return int_weight
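These three helpers are removed from awq.py; the new import at the top of the file pulls dequantize_gemm from quantizer_utils instead. For reference, the [0, 4, 1, 5, 2, 6, 3, 7] index undoes AWQ's interleaved nibble order inside each int32 — a small round-trip sketch (illustrative only):

import torch

bits = 4
awq_order = [0, 2, 4, 6, 1, 3, 5, 7]         # logical weight index stored in nibbles 0..7
values = torch.arange(8, dtype=torch.int32)  # the "true" 4-bit weights 0..7

# Pack eight 4-bit values into a single int32 column the way AWQ does.
packed = torch.zeros(1, 1, dtype=torch.int32)
for nibble, logical in enumerate(awq_order):
    packed |= values[logical] << (nibble * bits)

# Unpack with shifts 0, 4, ..., 28, as in unpack_and_reverse_weights_and_zeros.
shifts = torch.arange(0, 32, bits)
unpacked = (torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]) & 0x0F).view(1, -1)

# Re-indexing the columns with [0, 4, 1, 5, 2, 6, 3, 7] restores sequential order.
reordered = unpacked[:, torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])]
print(reordered.tolist())  # [[0, 1, 2, 3, 4, 5, 6, 7]]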
Lines changed: 73 additions & 0 deletions (new file)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import math
2+
3+
import torch
4+
from torch import nn
5+
6+
from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq
7+
8+
9+
class QuantLinearGPTQ(nn.Module):
10+
"""
11+
A quantized linear layer using GPTQ (Generalized Post-Training Quantization).
12+
This class supports only 4-bit quantization and is compatible with QuantLinearORT.
13+
14+
Research paper link- GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers (https://arxiv.org/abs/2210.17323)
15+
16+
Attributes:
17+
in_features (int): The number of input features.
18+
out_features (int): The number of output features.
19+
bits (int): The number of bits used for quantization (must be 4).
20+
act_order (None or bool): The activation order.
21+
orig_fp_weight (None or torch.Tensor): The original floating-point weights.
22+
maxq (int): The maximum quantization value.
23+
group_size (int): The group size for quantization.
24+
pack_mode (str): The packing mode, set to "GPTQ".
25+
qweight (torch.Tensor): The quantized weight tensor.
26+
qzeros (torch.Tensor): The quantized zeros tensor.
27+
scales (torch.Tensor): The scales tensor.
28+
g_idx (torch.Tensor): The group index tensor.
29+
bias (torch.Tensor or None): The bias tensor, if applicable.
30+
"""
31+
32+
def __init__(self, bits, group_size, in_features, out_features, bias):
33+
super().__init__()
34+
if bits != 4:
35+
raise NotImplementedError("Only 4 bits are supported.")
36+
self.in_features = in_features
37+
self.out_features = out_features
38+
self.bits = bits
39+
self.act_order = None
40+
self.orig_fp_weight = None
41+
self.maxq = 2**self.bits - 1
42+
self.group_size = group_size if group_size != -1 else in_features
43+
self.pack_mode = "GPTQ"
44+
45+
# For compatibility with QuantLinearORT
46+
self.register_buffer(
47+
"qweight",
48+
torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32),
49+
)
50+
self.register_buffer(
51+
"qzeros",
52+
torch.zeros((math.ceil(in_features / self.group_size), out_features // 32 * self.bits), dtype=torch.int32),
53+
)
54+
self.register_buffer(
55+
"scales",
56+
torch.zeros((math.ceil(in_features / self.group_size), out_features), dtype=torch.float16),
57+
)
58+
self.g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32)
59+
if bias:
60+
self.register_buffer(
61+
"bias",
62+
torch.zeros((out_features), dtype=torch.float16),
63+
)
64+
else:
65+
self.bias = None
66+
67+
def forward(self, x):
68+
# Only Inference supported
69+
out, _, _ = dequantize_gptq(self.qweight.T, self.qzeros, self.scales, self.bits, self.g_idx)
70+
out = torch.matmul(x.float(), out.float())
71+
out = out + self.bias if self.bias is not None else out
72+
73+
return out
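The registered buffers encode GPTQ's packing: qweight packs 32 // bits = 8 quantized weights per int32 along the in_features axis, qzeros packs 8 zero-points per int32 along the out_features axis, and there is one fp16 scale row per group. A quick shape check under assumed toy dimensions (illustrative only):

import math

bits, group_size = 4, 128
in_features, out_features = 4096, 11008  # assumed LLaMA-style projection sizes

qweight_shape = (in_features // 32 * bits, out_features)                         # 8 weights per int32 row
qzeros_shape = (math.ceil(in_features / group_size), out_features // 32 * bits)  # 8 zero-points per int32 column
scales_shape = (math.ceil(in_features / group_size), out_features)               # one scale per group and column
g_idx_len = in_features                                                          # group index for every input feature

print(qweight_shape)  # (512, 11008)
print(qzeros_shape)   # (32, 1376)
print(scales_shape)   # (32, 11008)
print(g_idx_len)      # 4096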
