From 12aeb58e6e6f184a0d23c276566beb4a40c6ab99 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 29 Aug 2025 13:14:59 -0700
Subject: [PATCH] Add hqq support for Int4TilePackedTo4dTensor
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
* Added an `Int4ChooseQParamsAlgorithm` enum with TINYGEMM and HQQ options; by default tensors use the TINYGEMM option
* Enabled the `Int4ChooseQParamsAlgorithm.HQQ` option for Int4TilePackedTo4dTensor. Instead of calling the tinygemm quant primitive ops to quantize the high precision tensor, the HQQ path quantizes with `_choose_qparams_and_quantize_affine_hqq`, which helps improve accuracy for int4 weight only quantization while still reusing the tinygemm kernel for speedup

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Accuracy test (sanity check) to make sure hqq improves accuracy:
```
sh release.sh --model_id Qwen/Qwen3-8B --quants INT4 --push_to_hub

no hqq checkpoint: https://huggingface.co/jerryzh168/Qwen3-8B-INT4-non-hqq
hqq checkpoint: https://huggingface.co/jerryzh168/Qwen3-8B-INT4

export MODEL=jerryzh168/Qwen3-8B-INT4-non-hqq
export TASK=mmlu
lm_eval --model hf --model_args pretrained=$MODEL --tasks $TASK --device cuda:0 --batch_size auto

|      Groups      |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------------------|------:|------|------|------|---|-----:|---|-----:|
|mmlu              |      2|none  |      |acc   |↑  |0.7019|±  |0.0036|
| - humanities     |      2|none  |      |acc   |↑  |0.6036|±  |0.0066|
| - other          |      2|none  |      |acc   |↑  |0.7403|±  |0.0076|
| - social sciences|      2|none  |      |acc   |↑  |0.8083|±  |0.0070|
| - stem           |      2|none  |      |acc   |↑  |0.7069|±  |0.0078|

export MODEL=jerryzh168/Qwen3-8B-INT4
lm_eval --model hf --model_args pretrained=$MODEL --tasks $TASK --device cuda:0 --batch_size auto

|      Groups      |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------------------|------:|------|------|------|---|-----:|---|-----:|
|mmlu              |      2|none  |      |acc   |↑  |0.7040|±  |0.0036|
| - humanities     |      2|none  |      |acc   |↑  |0.5962|±  |0.0065|
| - other          |      2|none  |      |acc   |↑  |0.7470|±  |0.0075|
| - social sciences|      2|none  |      |acc   |↑  |0.8177|±  |0.0069|
| - stem           |      2|none  |      |acc   |↑  |0.7114|±  |0.0078|

hqq improves the mmlu accuracy slightly.
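# Quick local sanity check of the HQQ path (a minimal sketch, not part of release.sh;
# assumes a CUDA device and `compute_error` from torchao.quantization.utils):
python - <<'EOF'
import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error

# same config the release script uses for INT4
config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="tile_packed_to_4d",
    int4_choose_qparams_algorithm="hqq",
    version=2,
)
linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")
ref = linear(x)
quantize_(linear, config)
# SQNR against the bf16 baseline; the unit test expects > 20 dB
print(compute_error(ref, linear(x)))
EOF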
``` Reviewers: Subscribers: Tasks: Tags: --- .../quantize_and_upload.py | 9 +- test/core/test_config.py | 6 ++ .../test_int4_tile_packed_to_4d_tensor.py | 31 ++++--- torchao/core/config.py | 1 + torchao/quantization/quant_api.py | 50 ++++++---- .../quantize_/workflows/__init__.py | 2 + .../int4/int4_choose_qparams_algorithm.py | 32 +++++++ .../int4/int4_tile_packed_to_4d_tensor.py | 93 +++++++++++++------ 8 files changed, 163 insertions(+), 61 deletions(-) create mode 100644 torchao/quantization/quantize_/workflows/int4/int4_choose_qparams_algorithm.py diff --git a/.github/scripts/torchao_model_releases/quantize_and_upload.py b/.github/scripts/torchao_model_releases/quantize_and_upload.py index 2351f2b3a1..1157c6a9d9 100644 --- a/.github/scripts/torchao_model_releases/quantize_and_upload.py +++ b/.github/scripts/torchao_model_releases/quantize_and_upload.py @@ -205,7 +205,7 @@ def _untie_weights_and_save_locally(model_id): _int4_quant_code = """ from torchao.quantization import Int4WeightOnlyConfig -quant_config = Int4WeightOnlyConfig(group_size=128, use_hqq=True) +quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2) quantization_config = TorchAoConfig(quant_type=quant_config) quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config) tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -627,7 +627,12 @@ def quantize_and_upload( ) quant_to_config = { "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), - "INT4": Int4WeightOnlyConfig(group_size=128, version=2), + "INT4": Int4WeightOnlyConfig( + group_size=128, + packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", + version=2, + ), "INT8-INT4": ModuleFqnToConfig( { "_default": _int8_int4_linear_config, diff --git a/test/core/test_config.py b/test/core/test_config.py index 9574c3ec76..c7c412f9b6 100644 --- a/test/core/test_config.py +++ b/test/core/test_config.py @@ -53,6 +53,12 @@ Int4WeightOnlyConfig( group_size=32, ), + Int4WeightOnlyConfig( + group_size=128, + packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", + version=2, + ), Int8DynamicActivationInt4WeightConfig( group_size=64, ), diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py index 1c0e33c960..337a9d98ad 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py @@ -15,7 +15,6 @@ ) from torchao.quantization import Int4WeightOnlyConfig, quantize_ -from torchao.quantization.quantize_.common.packing_format import PackingFormat from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import ( Int4TilePackedTo4dTensor, ) @@ -25,7 +24,14 @@ INT4_CONFIG = Int4WeightOnlyConfig( group_size=128, - packing_format=PackingFormat.TILE_PACKED_TO_4D, + packing_format="tile_packed_to_4d", + version=2, +) + +INT4_HQQ_CONFIG = Int4WeightOnlyConfig( + group_size=128, + packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", version=2, ) @@ -44,8 +50,8 @@ def setUp(self): ((2, 32, 128), 256, 128), ], ) - def test_linear(self, sizes): - config = INT4_CONFIG + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_linear(self, sizes, config): dtype = 
torch.bfloat16 device = "cuda" @@ -62,8 +68,8 @@ def test_linear(self, sizes): quantized_and_compiled = compiled_linear(input) self.assertTrue(compute_error(original, quantized_and_compiled) > 20) - def test_module_path(self): - config = INT4_CONFIG + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_module_path(self, config): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) quantize_(linear.cuda(), config) self.assertEqual( @@ -80,11 +86,11 @@ def test_module_path(self): "", ) - def test_slice(self): + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_slice(self, config): """Note: we use multiples of 1024 for both in_features and out_features so that padding does not affect the weight after slicing """ - config = INT4_CONFIG dtype = torch.bfloat16 device = "cuda" @@ -169,8 +175,8 @@ def test_slice(self): res2 = test_linear2(input2) self.assertGreater(compute_error(res_ref2, res2), 14) - def test_slice_preserves_aliasing(self): - config = INT4_CONFIG + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_slice_preserves_aliasing(self, config): l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) l.weight = torch.nn.Parameter( torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") @@ -212,8 +218,9 @@ def test_to_device(self): quantize_(linear, config) linear.to(device) - def test_slice_and_copy_similar_to_vllm(self): - self._test_slice_and_copy_similar_to_vllm(INT4_CONFIG) + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_slice_and_copy_similar_to_vllm(self, config): + self._test_slice_and_copy_similar_to_vllm(config) @parametrize("device", ["cuda"]) @parametrize("dtype", [torch.bfloat16]) diff --git a/torchao/core/config.py b/torchao/core/config.py index 26e71360e2..72a22df020 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -196,6 +196,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: "torchao.dtypes", "torchao.prototype.awq", "torchao.quantization.quantize_.common", + "torchao.quantization.quantize_.workflows", } diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e83abd3953..e505cf0c49 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -72,6 +72,7 @@ ) from torchao.quantization.quantize_.workflows import ( Float8Tensor, + Int4ChooseQParamsAlgorithm, Int4MarlinSparseTensor, Int4OpaqueTensor, Int4PreshuffledTensor, @@ -1054,27 +1055,29 @@ def _gemlite_uintx_weight_only_transform( @dataclass class Int4WeightOnlyConfig(AOBaseConfig): """ - Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using - "tensor_core_tiled" layout for speedup with tinygemm kernel - - Note: - This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm` - and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference - of quantization algorithm compared to the more traditional type of integer quantization is the following: - 1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`) - 2). 
floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`)
-        please follow the relevant code in `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`
-        to learn about how the quantization parameters are chosen and how the Tensor is quantized/dequantized for tinygemm
+    Configuration for int4 weight only quantization. Only groupwise quantization is supported
+    right now. We support version 1 and version 2, which are implemented differently but provide
+    the same functionality: in version 2 the quantization target is mainly selected by the `packing_format` arg, and in version 1 mainly by `layout`.
 
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
-        size is more fine grained, choices are [256, 128, 64, 32]
-        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
-        `use_hqq`: whether to use hqq or default quantization mode, default is False
-        `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
-        `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
-        `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
-        `packing_format`: the packing format for int4 tensor, available from version 2 and above
+        size is more fine grained, choices are [256, 128, 64, 32], used in both version 1 and 2
+        `packing_format`: the packing format for int4 tensor, used in version 2 only
+        `int4_choose_qparams_algorithm`: the choose qparams algorithm to use for int4,
+        currently TINYGEMM ("tinygemm") and HQQ ("hqq") are supported, used in version 2 only
+        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`, used in version 1 only
+        `use_hqq`: whether to use hqq or default quantization mode, default is False, used in version 1 only
+        `zero_point_domain`: data type of zero points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE], used in version 1 only
+        `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values, used in both version 1 and 2
+        `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT, used in version 1 only
+        `version`: version of the config to use, only a subset of the above args is valid for each version, default is 1, see the note below for more details
+
+    Note:
+        Int4WeightOnlyConfig currently supports both v1 (legacy) and v2.
+
+        For v2 (version = 2), only `group_size`, `packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored
+        For v1 (version = 1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid. We plan to deprecate v1 in torchao 0.15 to make this config
+        less confusing
     """
 
     group_size: int = 128
@@ -1085,6 +1088,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     preserve_zero: Optional[bool] = None
     # only used in version >= 2
     packing_format: PackingFormat = PackingFormat.PLAIN
+    int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
+        Int4ChooseQParamsAlgorithm.TINYGEMM
+    )
     version: int = 1
 
     def __post_init__(self):
@@ -1105,6 +1111,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
     group_size = config.group_size
     layout = config.layout
     use_hqq = config.use_hqq
+    int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
    zero_point_domain = config.zero_point_domain
     packing_format = config.packing_format
 
@@ -1118,6 +1125,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
 
     if config.version == 2:
         block_size = list(block_size)
+
+        if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
+            assert packing_format == PackingFormat.TILE_PACKED_TO_4D, (
+                f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {packing_format}, it's only supported by PackingFormat.TILE_PACKED_TO_4D currently"
+            )
+
         if packing_format == PackingFormat.PRESHUFFLED:
             new_weight = Int4PreshuffledTensor.from_hp(
                 weight,
@@ -1147,6 +1160,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
             new_weight = Int4TilePackedTo4dTensor.from_hp(
                 weight,
                 block_size,
+                int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
             )
             return new_weight
         else:
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index fb4c6bcc11..700019dca4 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -2,6 +2,7 @@
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
+from .int4.int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
 from .int4.int4_marlin_sparse_tensor import (
     Int4MarlinSparseTensor,
 )
@@ -33,4 +34,5 @@
     "Int4OpaqueTensor",
     "IntxUnpackedTensor",
     "IntxUnpackedToInt8Tensor",
+    "Int4ChooseQParamsAlgorithm",
 ]
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_choose_qparams_algorithm.py b/torchao/quantization/quantize_/workflows/int4/int4_choose_qparams_algorithm.py
new file mode 100644
index 0000000000..2258b3f3e2
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/int4/int4_choose_qparams_algorithm.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+
+
+# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
+# after python 3.10 is end of life (https://devguide.python.org/versions/)
+class Int4ChooseQParamsAlgorithm(str, Enum):
+    """Variants of the quantization algorithm used to calculate scale and zero_point"""
+
+    """
+    The choose_qparams algorithm native to the tinygemm kernel:
+        scale = (max_val - min_val) / float(quant_max - quant_min), where
+            max_val and min_val are the max/min for the slice of the input Tensor based on block_size
+            quant_max and quant_min are the max/min for the quantized value, e.g. 0, 15 for uint4
+        zero_point = min_val + scale * mid_point, where
+            mid_point = (quant_max + quant_min + 1) / 2
+
+    implemented in `torchao.quantization.quant_primitives._choose_qparams_affine_tinygemm`
+    """
+    TINYGEMM = "tinygemm"
+
+    """
+    The choose_qparams algorithm based on half-quadratic quantization (HQQ): https://mobiusml.github.io/hqq_blog/
+
+    implemented in `torchao.quantization.quant_primitives._choose_qparams_and_quantize_affine_hqq`
+    """
+    HQQ = "hqq"
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py
index f7237932df..6c80198b9f 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py
@@ -5,12 +5,22 @@
 # LICENSE file in the root directory of this source tree.
+
+import math
 from typing import List
 
 import torch
 
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_affine_hqq,
+    _quantize_affine_tinygemm,
+)
+from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 from torchao.utils import TorchAOBaseTensor, fill_defaults, find_multiple
 
+from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
+
 __all__ = [
     "Int4TilePackedTo4dTensor",
 ]
@@ -76,6 +86,7 @@ def from_hp(
         cls,
         hp_tensor: torch.Tensor,
         block_size: List[int],
+        int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = Int4ChooseQParamsAlgorithm.TINYGEMM,
     ):
         assert len(block_size) == hp_tensor.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {hp_tensor.ndim=}"
         )
@@ -115,34 +126,60 @@ def from_hp(
         quant_min = 0
         quant_max = 15
 
-        from torchao.quantization.quant_primitives import (
-            MappingType,
-            _choose_qparams_affine_tinygemm,
-            _quantize_affine_tinygemm,
-        )
-
-        # Calculate scale and zero_point for tinygemm
-        scale, zero_point = _choose_qparams_affine_tinygemm(
-            hp_tensor_padded,
-            mapping_type=MappingType.ASYMMETRIC,
-            block_size=tuple(block_size),
-            target_dtype=target_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            scale_dtype=hp_tensor.dtype,
-            zero_point_dtype=hp_tensor.dtype,
-        )
+        # we support two paths for constructing an Int4TilePackedTo4dTensor
+        # 1. use the [hqq](https://mobiusml.github.io/hqq_blog/) algorithm to compute
+        #    scale and zero_point, then convert to the format that's compatible with tinygemm kernels
+        # 2. don't use hqq, use the default tinygemm algorithm to compute scale and zero_point
+        #
+        # both approaches should have the same speed since both use the tinygemm kernel for the gemm
+        # 1. typically has higher accuracy compared to 2.
+ if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ: + nbits = int(math.log2(quant_max + 1)) + axis = 1 + group_size = block_size[-1] + compute_dtype = hp_tensor_padded.dtype + device = hp_tensor_padded.device + int_data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq( + hp_tensor_padded, + nbits=nbits, + group_size=group_size, + axis=axis, + compute_dtype=compute_dtype, + device=device, + verbose=False, + raw_output=False, + # raw_output=False is basically the 'convert to tinygemm zero_point version' option (add scale*midpoint) that's used in TilePackedTo4d + # note _choose_qparams_affine_tinygemm does this same thing + ) + int_data = int_data.to(target_dtype) + else: + assert ( + int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM + ), ( + f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}" + ) + # Calculate scale and zero_point for tinygemm + scale, zero_point = _choose_qparams_affine_tinygemm( + hp_tensor_padded, + mapping_type=MappingType.ASYMMETRIC, + block_size=tuple(block_size), + target_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + scale_dtype=hp_tensor.dtype, + zero_point_dtype=hp_tensor.dtype, + ) - # Quantize for tinygemm - int_data = _quantize_affine_tinygemm( - hp_tensor_padded, - block_size, - scale, - zero_point, - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - ) + # Quantize for tinygemm + int_data = _quantize_affine_tinygemm( + hp_tensor_padded, + block_size, + scale, + zero_point, + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) # Convert to packed format def quant_2d(int_data_2d): @@ -175,8 +212,6 @@ def quant_2d(int_data_2d): else None ) - from torchao.quantization.utils import pack_tinygemm_scales_and_zeros - scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) return cls(