diff --git a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
index 5c21db8c6b..456f834389 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
@@ -26,10 +26,11 @@
 )
 
 
-def get_config(group_size):
+def get_config(group_size, use_hqq):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="opaque",
+        int4_choose_qparams_algorithm="hqq" if use_hqq else "tinygemm",
     )
 
 
@@ -45,13 +46,14 @@ class TestInt4OpaqueTensor(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
+    @parametrize("use_hqq", [True, False])
+    def test_linear(self, sizes, dtype, group_size, use_hqq):
         device = "cpu"
         M, N, K = sizes
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(group_size))
+        quantize_(linear, get_config(group_size, use_hqq))
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
@@ -60,9 +62,10 @@ def test_linear(self, sizes, dtype, group_size):
         self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
     @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
-    def test_module_path(self, dtype):
+    @parametrize("use_hqq", [True, False])
+    def test_module_path(self, dtype, use_hqq):
         linear = torch.nn.Linear(128, 256, dtype=dtype)
-        quantize_(linear, get_config(group_size=128))
+        quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4OpaqueTensor'>",
         )
@@ -77,12 +80,13 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4OpaqueTensor'>",
         )
 
-    def test_activation_prescaling(self):
+    @parametrize("use_hqq", [True, False])
+    def test_activation_prescaling(self, use_hqq):
         dtype = torch.bfloat16
         input = torch.randn(1, 128, dtype=dtype)
         linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
         original_output = linear(input)
-        quantize_(linear, get_config(group_size=128))
+        quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ef4b247819..60eb3762b1 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1085,7 +1085,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
 
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32], used in both version 1 and 2
-        `packing_format`: the packing format for int4 tensor, used in version 2 only
+        `int4_packing_format`: the packing format for int4 tensor, used in version 2 only
         `int4_choose_qparams_algorithm`: variants of choose qparams algorithm to use for int4, currently support TINYGEMM ("tinygemm") and HQQ ("hqq"), used in version 2 only
         `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`, used in version 1 only
@@ -1093,7 +1093,7 @@
         `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE], used in version 1 only
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. used in both version 1 and 2
         `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT, used in version 1 only
-        `version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 1, see note for more details
+        `version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 2, see note for more details
 
     Note:
         Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2
@@ -1150,8 +1150,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
         block_size = list(block_size)
 
     if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
-        assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
-            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D curretnly"
+        assert int4_packing_format in [
+            Int4PackingFormat.TILE_PACKED_TO_4D,
+            Int4PackingFormat.OPAQUE,
+        ], (
+            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
+            f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D and Int4PackingFormat.OPAQUE currently"
         )
 
     if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
@@ -1183,6 +1187,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
         new_weight = Int4OpaqueTensor.from_hp(
             weight,
             block_size,
+            int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
         )
         return new_weight
     elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
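For reference, the dispatch above is driven entirely by the version-2 config fields documented in the docstring. A minimal usage sketch (illustrative, not part of the patch), assuming `Int4WeightOnlyConfig` and `quantize_` are imported from `torchao.quantization` as the tests above do:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# version 2 config: the "opaque" packing format targets CPU tinygemm kernels,
# and the qparams algorithm can now be "hqq" in addition to the default "tinygemm"
config = Int4WeightOnlyConfig(
    group_size=128,
    int4_packing_format="opaque",
    int4_choose_qparams_algorithm="hqq",
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cpu")
quantize_(linear, config)  # linear.weight becomes an Int4OpaqueTensor

Any other packing format combined with the HQQ algorithm (except TILE_PACKED_TO_4D) now trips the widened assert in `_int4_weight_only_quantize_tensor`.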
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py
index f418950069..57245f55a7 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
+import math
 from typing import List, Optional
 
 import torch
@@ -12,12 +13,16 @@
 from torchao.quantization.quant_primitives import (
     MappingType,
     _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_affine_hqq,
     _quantize_affine_tinygemm,
 )
+from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 from torchao.utils import (
     TorchAOBaseTensor,
 )
 
+from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
+
 __all__ = [
     "Int4OpaqueTensor",
 ]
@@ -95,6 +100,7 @@ def from_hp(
         cls,
         w: torch.Tensor,
         block_size: List[int],
+        int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = Int4ChooseQParamsAlgorithm.TINYGEMM,
     ):
         assert w.ndim == 2 and w.device.type == "cpu", (
             f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}"
         )
@@ -111,26 +117,54 @@
         eps = 1e-6
         scale_dtype = None
         zero_point_dtype = w.dtype
-        scale, zero_point = _choose_qparams_affine_tinygemm(
-            w,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            scale_dtype,
-            zero_point_dtype,
-        )
-        int_data = _quantize_affine_tinygemm(
-            w,
-            block_size,
-            scale,
-            zero_point,
-            target_dtype,
-            quant_min,
-            quant_max,
-        )
+
+        # we support two paths for constructing an Int4OpaqueTensor:
+        # 1. use the [hqq](https://mobiusml.github.io/hqq_blog/) algorithm to compute
+        #    scale and zero_point, then convert to the format that's compatible with tinygemm kernels
+        # 2. don't use hqq, use the default tinygemm algorithm to compute scale and zero_point
+        #
+        # both approaches should have the same performance since both use the CPU tinygemm kernel for gemm;
+        # 1. typically has higher accuracy than 2.
+        if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
+            nbits = int(math.log2(quant_max + 1))
+            axis = 1
+            group_size = block_size[-1]
+            int_data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
+                w,
+                nbits=nbits,
+                group_size=group_size,
+                axis=axis,
+                compute_dtype=zero_point_dtype,
+                device=w.device,
+            )
+            int_data = int_data.to(target_dtype)
+        else:
+            assert (
+                int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM
+            ), (
+                f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}"
+            )
+
+            scale, zero_point = _choose_qparams_affine_tinygemm(
+                w,
+                mapping_type,
+                block_size,
+                target_dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype,
+                zero_point_dtype,
+            )
+            int_data = _quantize_affine_tinygemm(
+                w,
+                block_size,
+                scale,
+                zero_point,
+                target_dtype,
+                quant_min,
+                quant_max,
+            )
         assert int_data.dtype == torch.int32, (
             "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
         )
@@ -141,7 +175,6 @@ def from_hp(
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
 
-        from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
 
         return Int4OpaqueTensor(
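To substantiate the comment that the HQQ path typically recovers more accuracy than the plain tinygemm qparams search while running on the same kernels, here is a small self-contained comparison. It is illustrative only and not part of the patch: the import path `torchao.quantization` for `Int4WeightOnlyConfig`/`quantize_` is assumed (the test file's import block is not shown in this diff), and the `sqnr` helper is an ad hoc stand-in for the `compute_error` utility used by the tests.

import copy

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_


def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> float:
    # signal-to-quantization-noise ratio in dB; higher means less quantization error
    return (20 * torch.log10(ref.norm() / (ref - actual).norm())).item()


torch.manual_seed(0)
dtype = torch.bfloat16
linear = torch.nn.Linear(128, 256, dtype=dtype, device="cpu")
x = torch.randn(8, 128, dtype=dtype)
reference = linear(x)

for algorithm in ("hqq", "tinygemm"):
    candidate = copy.deepcopy(linear)
    quantize_(
        candidate,
        Int4WeightOnlyConfig(
            group_size=128,
            int4_packing_format="opaque",
            int4_choose_qparams_algorithm=algorithm,
        ),
    )
    print(f"{algorithm}: sqnr = {sqnr(reference, candidate(x)):.2f} dB")

Both runs should clear the tests' threshold of 20 dB; on typical weights the "hqq" run is expected to land somewhat higher than "tinygemm".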