diff --git a/QEfficient/base/common.py b/QEfficient/base/common.py
index df7496e15..c96d085e1 100644
--- a/QEfficient/base/common.py
+++ b/QEfficient/base/common.py
@@ -56,15 +56,7 @@ def get_hf_model_type(hf_model_path: str) -> QEFF_MODEL_TYPE:
     )
     if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
-        # FIXME: Add logic to handle if quantization config is stored in separate quant_config.json outside of config, also create a separate function for this and below lines
-        quant_config = getattr(config, "quantization_config", getattr(config, "quant_config", None))
-        if quant_config is not None:
-            if quant_config.get("quant_method", None) == "awq":
-                return QEFF_MODEL_TYPE.AWQ
-            else:
-                raise NotImplementedError(f"current model type is not yet supported {type(config)}")
-        else:
-            return QEFF_MODEL_TYPE.CAUSALLM
+        return QEFF_MODEL_TYPE.CAUSALLM
     else:
         raise NotImplementedError(f"model type {type(config)} is not yet supported")
diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py
index 25d31d5b3..24e1c8847 100644
--- a/QEfficient/base/pytorch_transforms.py
+++ b/QEfficient/base/pytorch_transforms.py
@@ -53,3 +53,35 @@ def register(cls, from_module: Type[nn.Module], to_module: Type[nn.Module]):
             FlashAttention.register(LLamaAttention, LlamaFlashAttention)
         """
         cls._module_mapping[from_module] = to_module
+
+
+class ModuleMutatorTransform(PytorchTransform):
+    """Serves as the base class for any transform that mutates a PyTorch module.
+    Mutate here means that a new PyTorch module object is initialized using info from the original
+    module, and the original module is replaced with the new one.
+
+    Raises:
+        NotImplementedError: Not meant to be used directly. Create a subclass, implement the mutate method, and assign a valid nn.Module class to the _match_class variable.
+    """
+
+    _match_class: nn.Module
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        transformed = False
+        for name, module in model.named_children():
+            if isinstance(module, cls._match_class):
+                setattr(model, name, cls.mutate(module, model))
+                transformed = True
+            else:
+                cls.apply(module)
+
+        if isinstance(model, cls._match_class):
+            model = cls.mutate(model, None)
+            transformed = True
+
+        return model, transformed
+
+    @classmethod
+    def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
+        raise NotImplementedError("Please implement your own method by inheriting this class")
diff --git a/QEfficient/customop/matmulnbits.py b/QEfficient/customop/matmulnbits.py
new file mode 100644
index 000000000..a0fad1239
--- /dev/null
+++ b/QEfficient/customop/matmulnbits.py
@@ -0,0 +1,186 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
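For reference, a minimal sketch of how ModuleMutatorTransform is meant to be subclassed; the hypothetical SwapLinearForIdentity below mirrors the pattern exercised in tests/base/test_pytorch_transforms.py further down and is illustrative only, not part of the patch.

from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMutatorTransform


class SwapLinearForIdentity(ModuleMutatorTransform):
    # Every nn.Linear child found during the recursive walk is replaced by the mutated module.
    _match_class = nn.Linear

    @classmethod
    def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
        # Build the replacement from the original module; here we simply drop the Linear.
        return nn.Identity()


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
model, transformed = SwapLinearForIdentity.apply(model)
assert transformed and isinstance(model[0], nn.Identity)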
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import math + +import torch +from torch import nn + + +class QuantLinearTorchFunction(torch.autograd.Function): + @staticmethod + def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features): + input_tuple = (x, qself_qweight, qself_scales, qself_qzeros) + input_tuple += (g_idx,) if g_idx is not None else () + return g.op( + "com.microsoft::MatMulNBits", + *input_tuple, + outputs=1, + K_i=in_features, + N_i=out_features, + bits_i=bits, + block_size_i=groupsize, + ) + + @staticmethod + def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features): + if torch.onnx.is_in_onnx_export(): + return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype, device=x.device).float() + fp_weight = dequantize_blockwise_bits( + qself_qweight, qself_scales, qself_qzeros, bits, groupsize, g_idx, in_features, out_features + )[0].float() + + return torch.matmul(x.float(), fp_weight.T.float()) + + +def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, groupsize, g_idx, rows, cols): + if bits != 4: + raise ValueError("Only bits=4 is supported for executing quantized model") + if groupsize != 128: + raise ValueError("Only groupsize=128 is supported for executing quantized model") + expand_quant_value = ( + quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device) + ) & 0x0F + expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1) + aligned_scale = scale.reshape(*quant_values.shape[:-1], 1) + if zero_point.dtype == scale.dtype: + expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1) + else: + expand_zero_point = ( + zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device) + ) & 0x0F + try: + expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1) + # FIXME: remove try-except + except RuntimeError: + expand_zero_point = expand_zero_point.reshape(quant_values.shape[0], -1, 1) + expand_zero_point = expand_zero_point[:, : quant_values.shape[1]] + if g_idx is not None and g_idx[:32].sum().item() != 0: + float_values = ( + (expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0]) + * aligned_scale[:, g_idx, 0] + ).to(scale.dtype) + else: + float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype) + float_values = float_values.reshape(cols, -1) + if rows != float_values.shape[-1]: + float_values = float_values[:, :rows] + expand_zero_point = expand_zero_point[:, :rows] + if expand_zero_point.ndim == 3: + expand_zero_point = expand_zero_point.squeeze(-1) + if aligned_scale.ndim == 3: + aligned_scale = aligned_scale.squeeze(-1) + + return float_values, expand_zero_point, aligned_scale + + +class QuantLinearORT(nn.Module): + def __init__(self, bits, groupsize, in_features, out_features, bias): + super().__init__() + if bits not in [2, 3, 4, 5, 6, 7, 8]: + raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.") + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else in_features + self.act_order = None + + q_rows = in_features // self.groupsize + self.register_buffer( + "qweight", + torch.zeros((out_features, q_rows, self.groupsize // (8 // bits)), dtype=torch.uint8), + ) 
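The helpers above assume the MatMulNBits nibble layout: each uint8 stores two 4-bit values, low nibble first, and dequantize_blockwise_bits recovers them by shifting with [0, 4] and masking with 0x0F. A tiny illustrative snippet of that shift-and-mask step (not part of the patch):

import torch

packed = torch.tensor([[0x4A]], dtype=torch.int32)   # low nibble = 0xA (10), high nibble = 0x4 (4)
shifts = torch.tensor([[[0, 4]]], dtype=torch.int32)
unpacked = (packed.unsqueeze(-1) >> shifts) & 0x0F   # same trick as dequantize_blockwise_bits
print(unpacked.reshape(-1).tolist())                 # [10, 4]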
+ self.register_buffer( + "qzeros", + torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8), + ) + self.register_buffer( + "scales", torch.zeros((math.ceil(in_features / self.groupsize) * out_features), dtype=torch.float16) + ) + self.register_buffer( + "g_idx", torch.tensor([i // self.groupsize for i in range(in_features)], dtype=torch.int32) + ) + if bias: + self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16)) + else: + self.bias = None + + def quant_weight(self, weight, scales, zeros, g_idx): + scale_zeros = zeros * scales + scale_mat = scales[g_idx] + scale_zeros_mat = scale_zeros[g_idx] + int_weight_T = torch.round(((weight + scale_zeros_mat) / scale_mat).float()).to(torch.int) + return int_weight_T + + def pack_on_device(self, int_weight, int_zeros): + if self.bits != 4: + raise ValueError("only 4bit is supported by ONNXRUNTIME for now.") + + # Order of groups + self.act_order = self.g_idx[: self.groupsize // self.bits].sum().item() != 0 + + intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte() + scales_pt = self.scales.T.to(int_weight.device) + intweight_pt = int_weight.byte() + + block_size = self.groupsize + rows, cols = intweight_pt.shape + blob_size = block_size // 2 + k_blocks = (rows + block_size - 1) // block_size + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + if pad_len > 0: + intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0) + intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0) + + # Pack zeros if they are not float + if int_zeros.dtype != self.scales.dtype: + intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4) + intzeros_pt = intzeros_pt.reshape(-1) + + # Pack weights + intweight_pt_T = int_weight.T + intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4) + intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size) + + scales_pt = scales_pt.reshape(-1) + + # Validation checks + if (self.qweight.shape != intweight_pt_T.shape) and ( + self.qzeros.shape == intzeros_pt.shape or self.qzeros.dtype != intzeros_pt.dtype + ): + raise RuntimeError("Something went wrong while packing the weights in QuantLinearORT module") + + # Assign buffers + self.scales = scales_pt.float() + self.qweight = intweight_pt_T.byte() # Convert to uint8 + if int_zeros.dtype != self.scales.dtype: + self.qzeros = intzeros_pt.byte() # Convert to uint8 + else: + self.qzeros = intzeros_pt + + def pack(self, linear, scales, zeros, g_idx=None): + layer_weight = linear.weight.data + self.scales = scales.T + self.g_idx = g_idx.clone() + int_weight = self.quant_weight(layer_weight.T, scales.T, zeros.T, g_idx) + return self.pack_on_device(int_weight, zeros.T) + + def forward(self, inputs): + out = QuantLinearTorchFunction().apply( + inputs, + self.qweight, + self.scales, + self.qzeros, + self.g_idx if self.act_order else None, + self.bits, + self.groupsize, + self.in_features, + self.out_features, + ) + out = out + self.bias if self.bias is not None else out + return out diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index fa15fa8e9..6cc653ba9 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -14,6 +14,9 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel, Runtime from 
QEfficient.transformers.pytorch_transforms import CBTransform, CustomOpsTransform, KVCacheTransform
+from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
+from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform
+from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig
 from QEfficient.utils import get_qpc_dir_path, load_hf_tokenizer
 from QEfficient.utils.logging_utils import logger
@@ -30,6 +33,11 @@ class QEFFTransformersBase(QEFFBaseModel):
     """
 
     def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwargs) -> None:
+        if hasattr(model.config, "quantization_config") and not isinstance(
+            model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values())
+        ):
+            raise AssertionError("Please use `from_pretrained` method to load quantized models")
+
         super().__init__(model)
         self.model.config.use_cache = (
             True  # Always pass use_cache = True, to get KV values as output during ONNX export
@@ -53,6 +61,7 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}\n" + self.model.__repr__()
 
     @classmethod
+    @with_replaced_quantizers
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
         """
         This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
@@ -92,6 +101,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
             logger.warning(f"Updating attn_implementation to be 'eager', got {attn_implementation}")
             kwargs.update({"attn_implementation": "eager"})
 
+        if low_cpu_mem_usage := kwargs.get("low_cpu_mem_usage", None):
+            logger.warning(f"Updating low_cpu_mem_usage to be 'False', got {low_cpu_mem_usage}")
+            kwargs.update({"low_cpu_mem_usage": False})
+
         model = QEFFAutoModelToTransformersAutoModelMap[cls.__name__].from_pretrained(
             pretrained_model_name_or_path, *args, **kwargs
         )
@@ -148,9 +161,15 @@ def transform(self, **kwargs):
         """
         if self.is_transformed:
             return
+
         if kwargs.get("full_batch_size", None):
             self._pytorch_transforms.remove(KVCacheTransform)
             self._pytorch_transforms.append(CBTransform)
+
+        # Update list of pytorch transforms if the model falls in AWQ/GPTQ category
+        if isinstance(getattr(self.model.config, "quantization_config", None), QEffAwqConfig):
+            self._pytorch_transforms.insert(0, AwqToMatmulNbitsTransform)
+
         for transform in self._pytorch_transforms:
             transform.apply(self.model)
         self.is_transformed = True
diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py
new file mode 100644
index 000000000..d259e435a
--- /dev/null
+++ b/QEfficient/transformers/quantizers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/quantizers/auto.py b/QEfficient/transformers/quantizers/auto.py
new file mode 100644
index 000000000..e9b132126
--- /dev/null
+++ b/QEfficient/transformers/quantizers/auto.py
@@ -0,0 +1,41 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
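With the hooks above in place, an AWQ checkpoint goes through the usual entry point; a minimal usage sketch (illustrative only, the checkpoint name is the one used in the tests below, export/compile steps omitted):

from QEfficient import QEFFAutoModelForCausalLM

# from_pretrained is wrapped by @with_replaced_quantizers, so transformers resolves the config
# to QEffAwqConfig and uses QEffAwqQuantizer to build WQLinear_GEMM modules on CPU.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ")

# transform() prepends AwqToMatmulNbitsTransform, rewriting WQLinear_GEMM into QuantLinearORT
# before the usual KV-cache/custom-op transforms run.
qeff_model.transform()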
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING + +from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer + +QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer} + +QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig} + + +def with_replaced_quantizers(func): + def wrapper(*args, **kwargs): + transformers_replaced_quantization_config_mapping = dict() + transformers_replaced_quantizer_mapping = dict() + + for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys(): + # Replace quantization config + transformers_replaced_quantization_config_mapping[k] = AUTO_QUANTIZATION_CONFIG_MAPPING[k] + AUTO_QUANTIZATION_CONFIG_MAPPING[k] = QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING[k] + + # Replace quantizer + transformers_replaced_quantizer_mapping[k] = AUTO_QUANTIZER_MAPPING[k] + AUTO_QUANTIZER_MAPPING[k] = QEFF_AUTO_QUANTIZER_MAPPING[k] + + # Call the function for loading quantized models here + out = func(*args, **kwargs) + + # Put back quantization config and quantizer + for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys(): + AUTO_QUANTIZATION_CONFIG_MAPPING[k] = transformers_replaced_quantization_config_mapping[k] + AUTO_QUANTIZER_MAPPING[k] = transformers_replaced_quantizer_mapping[k] + + return out + + return wrapper diff --git a/QEfficient/transformers/quantizers/awq.py b/QEfficient/transformers/quantizers/awq.py new file mode 100644 index 000000000..7875efdde --- /dev/null +++ b/QEfficient/transformers/quantizers/awq.py @@ -0,0 +1,131 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn + + +class WQLinear_GEMM(nn.Module): + def __init__(self, w_bit, group_size, in_features, out_features, bias): + super().__init__() + + if w_bit != 4: + raise NotImplementedError("Only 4-bit are supported for now.") + + self.in_features = in_features + self.out_features = out_features + self.w_bit = w_bit + self.group_size = group_size if group_size != -1 else in_features + + # quick sanity check (make sure alignment) + if self.in_features % self.group_size != 0: + raise ValueError( + f"in_features should be perfectly divisible by group_size, got in_features = {self.in_features}, group_size = {self.group_size} while initializing WQLinear_GEMM module" + ) + if out_features % (32 // self.w_bit) != 0: + raise ValueError( + f"out_features must be perfectly divisible by number of weights packed into int32 value i.e. 
8, got out_features={self.out_features}" + ) + + # For compatibility with QuantLinearORT + self.g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32) + self.register_buffer( + "qweight", + torch.zeros( + (in_features, out_features // (32 // self.w_bit)), + dtype=torch.int32, + ), + ) + self.register_buffer( + "qzeros", + torch.zeros( + (in_features // self.group_size, out_features // (32 // self.w_bit)), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (in_features // self.group_size, out_features), + dtype=torch.float16, + ), + ) + if bias: + self.register_buffer( + "bias", + torch.zeros( + (out_features), + dtype=torch.float16, + ), + ) + else: + self.bias = None + + def forward(self, x): + # Only Inference supported + with torch.no_grad(): + out_shape = x.shape[:-1] + (self.out_features,) + + out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size) + out = torch.matmul(x.float(), out.float()) + + out = out + self.bias if self.bias is not None else out + out = out.reshape(out_shape) + + return out + + +def unpack_and_reverse_weights_and_zeros(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits) + + # unpacking weights column-wise + int_weights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + int_weights = int_weights.view(int_weights.shape[0], -1) + + # unpacking zeros column-wise + int_zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + int_zeros = int_zeros.view(int_zeros.shape[0], -1) + + reverse_order_tensor = torch.arange( + int_weights.shape[-1], + dtype=torch.int32, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, [0, 4, 1, 5, 2, 6, 3, 7]] + reverse_order_tensor = reverse_order_tensor.view(-1) + + int_zeros = int_zeros[:, reverse_order_tensor] + int_weights = int_weights[:, reverse_order_tensor] + + return int_weights, int_zeros + + +def unpack_awq_weights(qweight, qzeros, scales, bits): + int_weight, int_zeros = unpack_and_reverse_weights_and_zeros(qweight, qzeros, bits) + + # overflow checks + int_weight = torch.bitwise_and(int_weight, (2**bits) - 1) + int_zeros = torch.bitwise_and(int_zeros, (2**bits) - 1) + + return scales, int_weight, int_zeros + + +def dequantize_gemm(qweight, qzeros, scales, bits, group_size): + # Unpack the qweight and qzeros tensors + scales, int_weight, int_zeros = unpack_awq_weights(qweight, qzeros, scales, bits) + + # fp16 weights + scales = scales.repeat_interleave(group_size, dim=0) + int_zeros = int_zeros.repeat_interleave(group_size, dim=0) + + int_weight = (int_weight - int_zeros) * scales + + return int_weight diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py new file mode 100644 index 000000000..49859d528 --- /dev/null +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -0,0 +1,51 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
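The reverse_order_tensor above undoes AWQ's interleaved nibble layout: assuming AWQ packs the eight 4-bit values of an int32 in slot order [0, 2, 4, 6, 1, 3, 5, 7], the columns come out of the shift-based unpacking as [c0, c2, c4, c6, c1, c3, c5, c7], and indexing with [0, 4, 1, 5, 2, 6, 3, 7] restores [c0 ... c7]. A small self-contained check of that index math (illustrative, not part of the patch):

import torch

cols = torch.arange(8)                        # original 4-bit values c0..c7
awq_order = [0, 2, 4, 6, 1, 3, 5, 7]          # assumed AWQ nibble packing order
packed = sum(int(cols[c]) << (4 * slot) for slot, c in enumerate(awq_order))

shifts = torch.arange(0, 32, 4)
unpacked = (torch.tensor([[packed]]) >> shifts) & 0x0F    # -> [[c0, c2, c4, c6, c1, c3, c5, c7]]
restored = unpacked[:, [0, 4, 1, 5, 2, 6, 3, 7]]
assert torch.equal(restored, cols.unsqueeze(0))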
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +from torch import nn + +from QEfficient.base.pytorch_transforms import ModuleMutatorTransform +from QEfficient.customop.matmulnbits import QuantLinearORT +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM, unpack_awq_weights + + +class AwqToMatmulNbitsTransform(ModuleMutatorTransform): + _match_class = WQLinear_GEMM + + @staticmethod + def unpack_and_dequantize_awq(qweight, qzeros, scales, bits, group_size): + # Unpack the qweight and qzeros tensors + scales, int_weight, int_zeros = unpack_awq_weights(qweight, qzeros, scales, bits) + + # fp16 weights + scales_expand = scales.repeat_interleave(group_size, dim=0) + int_zeros_expand = int_zeros.repeat_interleave(group_size, dim=0) + int_weight = (int_weight - int_zeros_expand) * scales_expand + + return int_weight.T, scales, int_zeros.to(torch.int32) + + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module): + fp16_weight, scales, zeros = cls.unpack_and_dequantize_awq( + original_module.qweight, + original_module.qzeros, + original_module.scales, + original_module.w_bit, + original_module.group_size, + ) + + original_module.weight = fp16_weight + new_module = QuantLinearORT( + original_module.w_bit, + original_module.group_size, + original_module.in_features, + original_module.out_features, + original_module.bias is not None, + ) + new_module.bias = original_module.bias if original_module.bias is not None else None + new_module.pack(original_module, scales.T, zeros.T, original_module.g_idx) + return new_module diff --git a/QEfficient/transformers/quantizers/quantizer_awq.py b/QEfficient/transformers/quantizers/quantizer_awq.py new file mode 100644 index 000000000..2715b9bfb --- /dev/null +++ b/QEfficient/transformers/quantizers/quantizer_awq.py @@ -0,0 +1,198 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import copy
+
+import torch
+import torch.nn as nn
+from transformers.integrations.awq import AWQ_SCALES_MAPPINGS
+from transformers.quantizers.quantizer_awq import AwqQuantizer
+from transformers.utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion
+
+from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
+from QEfficient.utils.logging_utils import logger
+
+
+class QEffAwqConfig(AwqConfig):
+    def post_init(self):
+        """
+        Safety checker that arguments are correct
+        """
+
+        if self.backend not in [AwqBackendPackingMethod.AUTOAWQ]:
+            raise ValueError(
+                f"Only quantization backend {AwqBackendPackingMethod.AUTOAWQ} is supported - not recognized backend {self.backend}"
+            )
+
+        self.version = AWQLinearVersion.from_str(self.version)
+        if self.version not in [AWQLinearVersion.GEMM]:
+            raise ValueError(
+                f"Only {AWQLinearVersion.GEMM} version is supported - not recognized version {self.version}"
+            )
+
+        if self.do_fuse or self.fuse_max_seq_len is not None:
+            raise ValueError(
+                f"fused modules are not supported, got do_fuse={self.do_fuse}, fuse_max_seq_len={self.fuse_max_seq_len}"
+            )
+
+        if self.bits != 4:
+            raise ValueError(f"Only 4-bit AWQ quantization is supported, got bits={self.bits}")
+
+
+class QEffAwqQuantizer(AwqQuantizer):
+    def __init__(self, quantization_config: QEffAwqConfig, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+    def validate_environment(self, device_map, **kwargs):
+        # No need to validate as we will always use pytorch CPU version.
+        return True
+
+    @property
+    def is_trainable(self):
+        return False
+
+    def update_torch_dtype(self, torch_dtype):
+        if torch_dtype not in [None, torch.float32]:
+            logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None")
+        return None
+
+    def _process_model_before_weight_loading(self, model, **kwargs):
+        self.modules_to_not_convert = get_keys_to_not_convert(model)
+
+        if self.quantization_config.modules_to_not_convert is not None:
+            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+
+        model, has_been_replaced = replace_linear_layer_with_awq_gemm(
+            model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
+        )
+
+        model = replace_quantization_scales(model, model.config.model_type)
+        if not has_been_replaced:
+            logger.warning(
+                "You are loading an AWQ model but no linear modules were found in your model."
+                " Please double check your model architecture, or submit an issue on github if you think this is a bug."
+ ) + + +class ScaledActivation(nn.Module): + def __init__(self, module, scales): + super().__init__() + self.act = module + self.scales = nn.Parameter(scales.data) + + def forward(self, x): + return self.act(x) / self.scales.view(1, 1, -1) + + +def replace_quantization_scales(model, model_type): + if model_type not in AWQ_SCALES_MAPPINGS: + return model + for name, module in model.named_children(): + act_name = AWQ_SCALES_MAPPINGS[model_type]["act"] + layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"] + if name == act_name and hasattr(model, layer_before_act_name): + layer_before_act = getattr(model, AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]) + size = layer_before_act.out_features + scale_like = torch.ones(size) + model._modules[name] = ScaledActivation(module, scale_like) + replace_quantization_scales(module, model_type) + return model + + +def replace_linear_layer_with_awq_gemm( + model: torch.nn.Module, + quantization_config=None, + modules_to_not_convert=None, + current_key_name=None, + has_been_replaced=False, +): + modules_to_not_convert = modules_to_not_convert if modules_to_not_convert else [] + + for name, module in model.named_children(): + current_key_name = current_key_name if current_key_name else [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + model._modules[name] = WQLinear_GEMM( + w_bit=quantization_config.bits, + group_size=quantization_config.group_size, + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + ) + has_been_replaced = True + + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + + if len(list(module.children())) > 0: + _, has_been_replaced = replace_linear_layer_with_awq_gemm( + module, + modules_to_not_convert=modules_to_not_convert, + current_key_name=current_key_name, + quantization_config=quantization_config, + has_been_replaced=has_been_replaced, + ) + + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def get_keys_to_not_convert(model): + tied_model = copy.deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model.tie_weights() + tied_params = find_tied_parameters(tied_model) + tied_keys = sum(tied_params, []) + + # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision + if len(tied_keys) == 0: + output_emb = model.get_output_embeddings() + if output_emb is not None: + list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)] + return list_last_module + + # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision + list_modules = list(model.named_parameters()) + list_last_module = [list_modules[-1][0]] + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + +def find_tied_parameters(model: 
nn.Module, named_parameters=None, prefix="", result={}): + if named_parameters is None: + named_parameters = {n: p for n, p in model.named_parameters()} + else: + for name, parameter in model.named_parameters(): + full_name = name if prefix == "" else f"{prefix}.{name}" + if full_name not in named_parameters: + # When we find one, it has to be one of the existing parameters. + for new_name, new_param in named_parameters.items(): + if new_param is parameter: + if new_name not in result: + result[new_name] = [] + result[new_name].append(full_name) + + # Once we have treated direct parameters, we move to the child modules. + for name, child in model.named_children(): + child_name = name if prefix == "" else f"{prefix}.{name}" + find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result) + + return [sorted([weight] + list(set(tied))) for weight, tied in result.items()] diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index d20a4bebd..8a9e3d1cc 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -199,7 +199,7 @@ def check_and_assign_cache_dir(local_model_dir, cache_dir): def padding_check_and_fix(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]) -> None: """ - Checks and fixes tokenizer paddding side and pad_token_id viability. + Checks and fixes tokenizer padding side and pad_token_id viability. -------- tokenizer: `Union[PreTrainedTokenizer, PreTrainedTokenizerFast]` - Pass model tokenizer to check and fix. @@ -251,7 +251,7 @@ def get_padding_shape_from_config(config, batch_size, seq_len): n_heads = config.num_attention_heads d_head = config.hidden_size // config.num_attention_heads else: - raise ValueError("Invalid model configuration: n_head/n_heads or num_key_value_heads not found.") + raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.") padding_shape = [batch_size, n_heads, seq_len, d_head] return padding_shape diff --git a/tests/base/test_pytorch_transforms.py b/tests/base/test_pytorch_transforms.py index 981977c77..764bb887a 100644 --- a/tests/base/test_pytorch_transforms.py +++ b/tests/base/test_pytorch_transforms.py @@ -9,7 +9,20 @@ import torch from torch import nn -from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMutatorTransform + + +class TestModel(nn.Module): + def __init__(self): + super().__init__() + + self.a = nn.Linear(32, 64) + self.b = nn.Linear(64, 32) + + def forward(self, x): + x = self.a(x) + x = self.b(x) + return x def test_module_mapping_transform(): @@ -19,24 +32,35 @@ def test_module_mapping_transform(): class TestTransform(ModuleMappingTransform): _module_mapping = {nn.Linear: nn.Identity} - class TestModel(nn.Module): - def __init__(self): - super().__init__() + model = TestModel() + x = torch.rand(1, 32) + y1 = model(x) + assert torch.any(y1 != x) + + model, transformed = TestTransform.apply(model) + assert transformed + y2 = model(x) + assert torch.all(y2 == x) - self.a = nn.Linear(32, 64) - self.b = nn.Linear(64, 32) - def forward(self, x): - x = self.a(x) - x = self.b(x) - return x +def test_module_mutator_transform(): + with pytest.raises(TypeError): + ModuleMutatorTransform() + + class TestTransform(ModuleMutatorTransform): + _match_class = nn.Linear + + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module): + return nn.Identity() model = TestModel() + prev_ids = [id(model.a), id(model.b)] x 
= torch.rand(1, 32) y1 = model(x) assert torch.any(y1 != x) - model, transformed = TestTransform.apply(model) assert transformed + assert not ([id(model.a), id(model.b)] == prev_ids) y2 = model(x) assert torch.all(y2 == x) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 9b630d0c2..c87a07085 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -15,7 +15,7 @@ from QEfficient.utils.constants import Constants from QEfficient.utils.device_utils import get_available_device_id from QEfficient.utils.run_utils import ApiRunner -from tests.utils import load_pytorch_model +from tests.utils import load_pytorch_model, replace_transformers_quantizers test_models = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", @@ -30,6 +30,7 @@ "wtang06/mpt-125m-c4", "hakurei/gpt-j-random-tinier", "mistralai/Mixtral-8x7B-Instruct-v0.1", + "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model ] @@ -40,6 +41,7 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): Test function to validate the model before and after KV changes on Pytorch :param model_name: Name of model. """ + replace_transformers_quantizers() if model_name == "microsoft/Phi-3-mini-4k-instruct": n_layer = 2 # test only 2 layer models else: diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index a7ded6eac..e562748aa 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -9,7 +9,10 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM +from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform from QEfficient.utils._utils import get_padding_shape_from_config from QEfficient.utils.logging_utils import logger @@ -188,3 +191,25 @@ def test_kv_cache_transform( input_len=8, logits_tolerance=logits_tolerance, ) + + +@pytest.mark.parametrize("in_features", [2048, 4096]) +@pytest.mark.parametrize("out_features", [2048, 4096]) +def test_awq_to_matmulnbits_transform(in_features, out_features): + wqlinear = WQLinear_GEMM(w_bit=4, group_size=128, in_features=in_features, out_features=out_features, bias=False) + + wqlinear.qweight = torch.randint( + low=-(2**31), high=2**31 - 1, size=(in_features, out_features // 8), dtype=torch.int32 + ) + wqlinear.qzeros = torch.randint( + low=-(2**31), high=2**31 - 1, size=(in_features // wqlinear.group_size, out_features // 8), dtype=torch.int32 + ) + wqlinear.scales = torch.rand(in_features // wqlinear.group_size, out_features, dtype=torch.float32) + + rand_data = torch.rand(4, in_features) + old_out = wqlinear(rand_data) + new_module, transformed = AwqToMatmulNbitsTransform.apply(wqlinear) + assert transformed + new_out = new_module(rand_data) + assert isinstance(new_module, QuantLinearORT) + compare_original_vs_kv_model_pt_outputs(old_out, new_out, tolerance=1e-8) diff --git a/tests/utils.py b/tests/utils.py index 5f7433969..667a30d89 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,7 +9,9 @@ import unittest from transformers import AutoModelForCausalLM +from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING 
+from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer from QEfficient.utils import hf_download from QEfficient.utils.device_utils import is_multi_qranium_setup_available @@ -48,3 +50,8 @@ def load_pytorch_model(model_config): params = sum(p.numel() for p in model_hf.parameters()) model_hf.eval() return model_hf, params + + +def replace_transformers_quantizers(): + AUTO_QUANTIZER_MAPPING.update({"awq": QEffAwqQuantizer}) + AUTO_QUANTIZATION_CONFIG_MAPPING.update({"awq": QEffAwqConfig})
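For completeness, a minimal sketch of how replace_transformers_quantizers is meant to be used when loading the AWQ reference model in the tests (illustrative only; note that, unlike the with_replaced_quantizers decorator, this helper patches the global transformers mappings and does not restore them afterwards):

from transformers import AutoModelForCausalLM

from tests.utils import replace_transformers_quantizers

# Swap in QEffAwqQuantizer/QEffAwqConfig so transformers builds WQLinear_GEMM modules on CPU.
replace_transformers_quantizers()

model_hf = AutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ")
model_hf.eval()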