From 2fba4c849d845fedf5af8b25c18859af30d25e17 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 1 Oct 2025 16:18:26 -0700 Subject: [PATCH 1/6] init --- torchao/quantization/qat/api.py | 9 ++ torchao/quantization/quant_api.py | 20 +++++ torchao/quantization/quant_primitives.py | 89 +++++++++++++++++++ .../quantize_/workflows/__init__.py | 2 + .../intx/intx_choose_qparams_algorithm.py | 23 +++++ .../intx/intx_unpacked_to_int8_tensor.py | 76 ++++++++++++---- 6 files changed, 201 insertions(+), 18 deletions(-) create mode 100644 torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index 1287126bac..6cc4ab0fb8 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import copy from dataclasses import dataclass from enum import Enum from typing import Any, List, Optional, Tuple @@ -232,6 +233,7 @@ def _qat_config_transform( # Optionally pass custom scales and zero points to base config handler # This is only for range learning and only applies to weights kwargs = {} + has_custom_scale_and_zero_point = False weight_config = module.weight_fake_quantizer.config if ( isinstance(weight_config, IntxFakeQuantizeConfig) @@ -239,6 +241,7 @@ def _qat_config_transform( ): kwargs["custom_scale"] = module.weight_fake_quantizer.scale kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point + has_custom_scale_and_zero_point = True # Swap FakeQuantizedLinear -> nn.Linear # Swap FakeQuantizedEmbedding -> nn.Embedding @@ -253,6 +256,12 @@ def _qat_config_transform( f"Encountered unexpected module {module}, should never happen" ) if base_config is not None: + # If passing custom scales and zero points, we need to disable the choose_qparam_algorithm on the config + if has_custom_scale_and_zero_point and hasattr( + base_config, "intx_choose_qparams_algorithm" + ): + base_config = copy.deepcopy(base_config) + base_config.intx_choose_qparams_algorithm = None return _QUANTIZE_CONFIG_HANDLER[type(base_config)]( module, base_config, **kwargs ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 15caddcadc..cd7da6c614 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -78,6 +78,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + IntxChooseQParamsAlgorithm, IntxOpaqueTensor, IntxPackingFormat, IntxUnpackedToInt8Tensor, @@ -748,6 +749,7 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): `intx_packing_format`: The format to use for the packed weight tensor (version 2 only). - unpacked_to_int8: this format is the default and is intended for export applications like ExecuTorch. - opaque_torchao_auto: this format is optimized for CPU performance. + `intx_choose_qparams_algorithm`: The algorithm to use for choosing the quantization parameters. `version`: version of the config to use, only subset of above args are valid based on version, see note for more details. 
Note: @@ -766,6 +768,9 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): act_mapping_type: MappingType = MappingType.ASYMMETRIC layout: Layout = QDQLayout() intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8 + intx_choose_qparams_algorithm: IntxChooseQParamsAlgorithm = ( + IntxChooseQParamsAlgorithm.AFFINE + ) version: int = 2 @@ -830,6 +835,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor( act_mapping_type = config.act_mapping_type layout = config.layout intx_packing_format = config.intx_packing_format + intx_choose_qparams_algorithm = config.intx_choose_qparams_algorithm assert weight.dim() == 2, ( f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" @@ -868,6 +874,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor( weight_dtype, mapping_type=weight_mapping_type, activation_quantization="int8_asym_per_token", + intx_choose_qparams_algorithm=intx_choose_qparams_algorithm, custom_scale=custom_scale, custom_zero_point=custom_zero_point, ) @@ -889,6 +896,9 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor( # Version 1 assert config.version == 1 + assert intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.AFFINE, ( + "IntxChooseQParamsAlgorithm.AFFINE is the only supported algorithm for version 1" + ) warnings.warn( "Config Deprecation: version 1 of Int8DynamicActivationIntxWeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2967 for more details" ) @@ -2169,6 +2179,7 @@ class IntxWeightOnlyConfig(AOBaseConfig): - QDQLayout: this layout is designed for export to ExecuTorch.this layout represents the quantization with Q/DQ quant primitives, and is intended for export applications like ExecuTorch. `intx_packing_format`: The format to use for the packed weight tensor (version 2 only). + `intx_choose_qparams_algorithm`: The algorithm to use for choosing the quantization parameters. `version`: version of the config to use, only subset of above args are valid based on version, see note for more details. 
Note: @@ -2185,6 +2196,9 @@ class IntxWeightOnlyConfig(AOBaseConfig): scale_dtype: Optional[torch.dtype] = None layout: Layout = QDQLayout() intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8 + intx_choose_qparams_algorithm: IntxChooseQParamsAlgorithm = ( + IntxChooseQParamsAlgorithm.AFFINE + ) version: int = 2 def __post_init__(self): @@ -2202,6 +2216,7 @@ def __post_init__(self): assert self.mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, + MappingType.SYMMETRIC_NO_CLIPPING_ERR, ], ( f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" ) @@ -2220,6 +2235,7 @@ def _intx_weight_only_quantize_tensor( scale_dtype = config.scale_dtype layout = config.layout intx_packing_format = config.intx_packing_format + intx_choose_qparams_algorithm = config.intx_choose_qparams_algorithm assert weight.dim() == 2, ( f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" @@ -2247,6 +2263,7 @@ def _intx_weight_only_quantize_tensor( mapping_type=mapping_type, custom_scale=custom_scale, custom_zero_point=custom_zero_point, + intx_choose_qparams_algorithm=intx_choose_qparams_algorithm, ) if scale_dtype is not None and scale_dtype != weight.dtype: _adjust_scale_dtype_in_intx_unpacked_tensor( @@ -2258,6 +2275,9 @@ def _intx_weight_only_quantize_tensor( raise ValueError(f"Unsupported packing format: {intx_packing_format}") # Version 1 + assert config.intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.AFFINE, ( + "version 1 only supports affine algorithm" + ) assert config.version == 1 warnings.warn( "Config Deprecation: version 1 of IntxWeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2967 for more details" diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index cdfbc00c3a..4a806162ca 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -32,6 +32,7 @@ "_choose_qparams_affine_dont_preserve_zero", "_choose_qparams_affine_floatx", "_choose_qparams_and_quantize_affine_hqq", + "_choose_qparams_and_quantize_scale_only_hqq", "_choose_qparams_and_quantize_affine_qqq", "_choose_scale_float8", "_choose_qparams_gguf", @@ -2125,6 +2126,94 @@ def _choose_qparams_and_quantize_affine_hqq( return W_q, scale, zero, shape +@torch.no_grad() +def _choose_qparams_and_quantize_scale_only_hqq( + hp_tensor: torch.Tensor, + block_size: List[int], + qmin: int, + qmax: int, + *, + iters: int = 20, + stochastic: bool = False, + early_stop_tol: float = 1e-5, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Half-Quadratic Quantization (scale-only, symmetric) for 2D weights with row-wise blocks. 
+ - hp_tensor: [out, in] (bf16/fp16/fp32 accepted; promoted to fp32 internally) + - block_size: must be [1, group_size]; groups along the last dim + - qmin, qmax: integer range (e.g., -8, 7 for signed 4-bit) + Returns: + qdata: int32, same shape as hp_tensor + scale: hp_tensor.dtype, shape [out, in // group_size] (one scale per row-wise block) + """ + # --- strict interface guarantees --- + assert hp_tensor.ndim == 2, "hp_tensor must be 2D [out, in]" + assert isinstance(block_size, (list, tuple)) and len(block_size) == 2, ( + "block_size must be a 2-element list/tuple" + ) + assert block_size[0] == 1 and block_size[1] >= 1, ( + "block_size must be [1, group_size] with group_size >= 1" + ) + assert qmin < qmax, "qmin must be < qmax" + + # Promote to fp32 for stable math + compute_dtype = torch.float32 + compute_eps = torch.finfo(compute_dtype).eps + + n, k = hp_tensor.shape + group_size = int(block_size[1]) + assert k % group_size == 0, ( + f"in_features={k} must be divisible by group_size={group_size}" + ) + + def round_det(x: torch.Tensor) -> torch.Tensor: + # ties-to-even; fine for PTQ + return x.round() + + def round_stoch(x: torch.Tensor) -> torch.Tensor: + # unbiased stochastic rounding + return torch.floor(x + torch.rand_like(x)) + + _r = round_stoch if stochastic else round_det + + # Reshape Wg into [n, n_groups, group_size] + W = hp_tensor.to(compute_dtype).contiguous() + n_groups = k // group_size + Wg = W.view(n, n_groups, group_size) + + # Initialize per-block scales as max-abs / qabs + # scale.shape = [n, n_groups] + qabs = max(abs(qmin), abs(qmax)) or 1 + scale = (Wg.abs().amax(dim=2) / qabs).clamp_min(compute_eps) + prev_scale = scale.clone() + + # Iterate HQQ updates + for _ in range(max(1, iters)): + # Quantize using current scale + # Qg.shape = [n, n_groups, group_size] + Qg = _r(Wg / scale.unsqueeze(-1)).clamp(qmin, qmax) + + # Solve least-square problem min_{s} ||Wg - s * Qg||^2 and project + # solution onto positive space, or take previous value + num = (Wg * Qg).sum(dim=2, dtype=torch.float32) # [n, n_groups] + den = (Qg * Qg).sum(dim=2, dtype=torch.float32) # [n, n_groups] + scale = torch.where(den > 0, num / den, prev_scale) + scale = scale.abs().clamp_min(compute_eps) + + rel = ((scale - prev_scale).abs() / prev_scale.clamp_min(compute_eps)).max() + if rel < early_stop_tol: + break + prev_scale = scale + + # Restore shapes + qdata = Qg.view(n, k).contiguous() + + out_dtype = hp_tensor.dtype + scale = scale.to(out_dtype) + + return qdata, scale + + def _choose_qparams_affine_floatx( tensor: torch.Tensor, ebits: int, mbits: int ) -> torch.Tensor: diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 229c94c73a..4307637f8e 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -20,6 +20,7 @@ Int4Tensor, ) from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor +from .intx.intx_choose_qparams_algorithm import IntxChooseQParamsAlgorithm from .intx.intx_opaque_tensor import ( IntxOpaqueTensor, ) @@ -41,6 +42,7 @@ "Int4OpaqueTensor", "Int4ChooseQParamsAlgorithm", "Int4PackingFormat", + "IntxChooseQParamsAlgorithm", "IntxPackingFormat", "IntxUnpackedToInt8Tensor", "IntxOpaqueTensor", diff --git a/torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py b/torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py new file mode 100644 index 0000000000..125bda3757 --- 
/dev/null +++ b/torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class IntxChooseQParamsAlgorithm(str, Enum): + """Variant of quantization algorithm to calculate scale and zero_point""" + + """ + Uses `torchao.quantization.quant_primitives.choose_qparams_affine` + """ + AFFINE = "affine" + + """ + Uses `torchao.quantization.quant_primitives._choose_qparams_and_quantize_scale_only_hqq` + """ + HQQ = "hqq" diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py index 87402241dd..4c134e7472 100644 --- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py @@ -14,10 +14,14 @@ from torchao.quantization.quant_primitives import ( _DTYPE_TO_QVALUE_BOUNDS, MappingType, + _choose_qparams_and_quantize_scale_only_hqq, choose_qparams_affine, dequantize_affine, quantize_affine, ) +from torchao.quantization.quantize_.workflows.intx.intx_choose_qparams_algorithm import ( + IntxChooseQParamsAlgorithm, +) from torchao.quantization.utils import _get_per_token_block_size from torchao.utils import ( TorchAOBaseTensor, @@ -177,6 +181,9 @@ def from_hp( activation_quantization: Optional[ IntxUnpackedToInt8TensorActivationQuantization ] = None, + intx_choose_qparams_algorithm: Optional[ + IntxChooseQParamsAlgorithm + ] = IntxChooseQParamsAlgorithm.AFFINE, custom_scale: Optional[torch.Tensor] = None, custom_zero_point: Optional[torch.Tensor] = None, ): @@ -184,9 +191,36 @@ def from_hp( Create an IntxUnpackedToInt8Tensor from a high-precision tensor """ qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype] - if custom_scale is not None and custom_zero_point is not None: - scale, zero_point = custom_scale, custom_zero_point - elif custom_scale is None and custom_zero_point is None: + + if intx_choose_qparams_algorithm is not None: + assert custom_scale is None, ( + "custom_scale is not supported with intx_choose_qparams_algorithm" + ) + assert custom_zero_point is None, ( + "custom_zero_point is not supported with intx_choose_qparams_algorithm" + ) + + if intx_choose_qparams_algorithm is None: + assert custom_scale is not None, "custom_scale must be given" + assert custom_zero_point is not None, "custom_zero_point must be given" + scale = custom_scale + zero_point = custom_zero_point + qdata = quantize_affine( + hp_tensor, + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + elif intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.HQQ: + qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( + hp_tensor, block_size, qmin, qmax + ) + qdata = qdata.to(torch.int8) + zero_point = torch.zeros_like(scale, dtype=torch.int8) + elif intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.AFFINE: scale, zero_point = choose_qparams_affine( hp_tensor, mapping_type, @@ -196,19 +230,19 @@ def from_hp( quant_max=qmax, zero_point_dtype=torch.int8, ) + qdata = quantize_affine( + hp_tensor, + block_size, 
+ scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) else: raise ValueError( - "`custom_scale` and `custom_zero_point` must be both defined or both None" + f"Unsupported IntxChooseQParamsAlgorithm: {intx_choose_qparams_algorithm}" ) - qdata = quantize_affine( - hp_tensor, - block_size, - scale, - zero_point, - output_dtype=torch.int8, - quant_min=qmin, - quant_max=qmax, - ) # Reshape scale and zero_point to be compatible with block_size # This is asserted in IntxUnpackedToInt8Tensor's __init__ @@ -231,15 +265,21 @@ def from_hp( def dequantize(self): qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.target_dtype] + qdata = self.qdata # .to(torch.int32) + scale = self.scale # .to(torch.float32) + zero_point = self.zero_point # .to(torch.int32) + output_dtype = self.dtype # torch.float32 # self.dtype + dtype = torch.int8 # torch.int8 + return dequantize_affine( - self.qdata, + qdata, self.block_size, - self.scale, - self.zero_point, - torch.int8, + scale, + zero_point, + dtype, qmin, qmax, - output_dtype=self.dtype, + output_dtype=output_dtype, ) From 9290b822f35cb0beb2e075d40da7e2dcdd5a7a6c Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 1 Oct 2025 16:21:50 -0700 Subject: [PATCH 2/6] up --- .../intx/intx_unpacked_to_int8_tensor.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py index 4c134e7472..248be4aa03 100644 --- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py @@ -265,21 +265,15 @@ def from_hp( def dequantize(self): qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.target_dtype] - qdata = self.qdata # .to(torch.int32) - scale = self.scale # .to(torch.float32) - zero_point = self.zero_point # .to(torch.int32) - output_dtype = self.dtype # torch.float32 # self.dtype - dtype = torch.int8 # torch.int8 - return dequantize_affine( - qdata, + self.qdata, self.block_size, - scale, - zero_point, - dtype, + self.scale, + self.zero_point, + torch.int8, qmin, qmax, - output_dtype=output_dtype, + output_dtype=self.dtype, ) From 58913e9769fbd7605c1726550944a19e790b8fa5 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 1 Oct 2025 18:41:04 -0700 Subject: [PATCH 3/6] up --- .../intx/test_intx_unpacked_to_int8_tensor.py | 16 ++++++++++++++++ torchao/prototype/parq/quant/quant_api.py | 8 +++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py index 9284c1890e..484d1a75a5 100644 --- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py +++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py @@ -61,6 +61,22 @@ def test_linear(self): error = compute_error(original, quantized) self.assertTrue(error > 20) + def test_hqq(self): + dtype = torch.bfloat16 + device = "cpu" + config = IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + intx_choose_qparams_algorithm="hqq", + ) + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, config) + quantized = linear(input) + error = compute_error(original, quantized) + 
self.assertTrue(error > 20, f"Got error {error}") + def test_slice(self): dtype = torch.bfloat16 device = "cpu" diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py index 608fd9570e..d04344c4a5 100644 --- a/torchao/prototype/parq/quant/quant_api.py +++ b/torchao/prototype/parq/quant/quant_api.py @@ -65,9 +65,11 @@ def quantize_stretched_affine( ) -> torch.Tensor: if target_dtype in _SUB_BYTE_UINT_BOUNDS: target_dtype = torch.uint8 - assert input_float.dtype in (torch.float32, torch.float16, torch.bfloat16), ( - f"Unsupported input_float dtype: {input_float.dtype}" - ) + assert input_float.dtype in ( + torch.float32, + torch.float16, + torch.bfloat16, + ), f"Unsupported input_float dtype: {input_float.dtype}" assert len(block_size) == input_float.dim(), ( f"Got {input_float.dim()=}, {block_size=}" ) From ae46a820217dff5955c37f734c23efd2db2fa305 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 1 Oct 2025 18:41:41 -0700 Subject: [PATCH 4/6] up --- torchao/quantization/quant_primitives.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 4a806162ca..0b66a8f45e 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -2198,15 +2198,20 @@ def round_stoch(x: torch.Tensor) -> torch.Tensor: num = (Wg * Qg).sum(dim=2, dtype=torch.float32) # [n, n_groups] den = (Qg * Qg).sum(dim=2, dtype=torch.float32) # [n, n_groups] scale = torch.where(den > 0, num / den, prev_scale) - scale = scale.abs().clamp_min(compute_eps) + scale = scale.clamp_min( + compute_eps + ).abs() # project LS solution onto [eps, inf] rel = ((scale - prev_scale).abs() / prev_scale.clamp_min(compute_eps)).max() if rel < early_stop_tol: break prev_scale = scale + # Quantize using final scale + Qg = _r(Wg / scale.unsqueeze(-1)).clamp(qmin, qmax) + # Restore shapes - qdata = Qg.view(n, k).contiguous() + qdata = Qg.view(n, k).contiguous().to(torch.int32) out_dtype = hp_tensor.dtype scale = scale.to(out_dtype) From adfebdf3dd9aca6a388d568d3557839d1cf33b57 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 1 Oct 2025 18:42:47 -0700 Subject: [PATCH 5/6] up --- torchao/prototype/parq/quant/quant_api.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py index d04344c4a5..608fd9570e 100644 --- a/torchao/prototype/parq/quant/quant_api.py +++ b/torchao/prototype/parq/quant/quant_api.py @@ -65,11 +65,9 @@ def quantize_stretched_affine( ) -> torch.Tensor: if target_dtype in _SUB_BYTE_UINT_BOUNDS: target_dtype = torch.uint8 - assert input_float.dtype in ( - torch.float32, - torch.float16, - torch.bfloat16, - ), f"Unsupported input_float dtype: {input_float.dtype}" + assert input_float.dtype in (torch.float32, torch.float16, torch.bfloat16), ( + f"Unsupported input_float dtype: {input_float.dtype}" + ) assert len(block_size) == input_float.dim(), ( f"Got {input_float.dim()=}, {block_size=}" ) From be24ef2951402d54bf65ba7e8fc897bcbbb7e836 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 2 Oct 2025 10:37:46 -0700 Subject: [PATCH 6/6] up --- .../intx/test_intx_unpacked_to_int8_tensor.py | 20 +++++++++++++++++-- torchao/quantization/qat/api.py | 2 ++ torchao/quantization/quant_api.py | 2 +- .../intx/intx_choose_qparams_algorithm.py | 2 +- .../intx/intx_unpacked_to_int8_tensor.py | 2 +- 5 files changed, 23 insertions(+), 5 deletions(-) 
diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py index 484d1a75a5..f49e2b3f8d 100644 --- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py +++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py @@ -61,13 +61,29 @@ def test_linear(self): error = compute_error(original, quantized) self.assertTrue(error > 20) - def test_hqq(self): + def test_hqq_intx_weight_only_config(self): dtype = torch.bfloat16 device = "cpu" config = IntxWeightOnlyConfig( weight_dtype=torch.int4, granularity=PerGroup(32), - intx_choose_qparams_algorithm="hqq", + intx_choose_qparams_algorithm="hqq_scale_only", + ) + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, config) + quantized = linear(input) + error = compute_error(original, quantized) + self.assertTrue(error > 20, f"Got error {error}") + + def test_hqq_int8_dyn_act_intx_weight_config(self): + dtype = torch.bfloat16 + device = "cpu" + config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), + intx_choose_qparams_algorithm="hqq_scale_only", ) input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index 6cc4ab0fb8..beb2a76ad9 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy +import logging from dataclasses import dataclass from enum import Enum from typing import Any, List, Optional, Tuple @@ -260,6 +261,7 @@ def _qat_config_transform( if has_custom_scale_and_zero_point and hasattr( base_config, "intx_choose_qparams_algorithm" ): + logging.debug("Disabling intx_choose_qparams_algorithm") base_config = copy.deepcopy(base_config) base_config.intx_choose_qparams_algorithm = None return _QUANTIZE_CONFIG_HANDLER[type(base_config)]( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index cd7da6c614..d426832f5a 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2218,7 +2218,7 @@ def __post_init__(self): MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR, ], ( - f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" + f"mapping_type must be MappingType.ASYMMETRIC, MappingType.SYMMETRIC, or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.mapping_type}" ) diff --git a/torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py b/torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py index 125bda3757..7e1f459ee0 100644 --- a/torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py +++ b/torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py @@ -20,4 +20,4 @@ class IntxChooseQParamsAlgorithm(str, Enum): """ Uses `torchao.quantization.quant_primitives._choose_qparams_and_quantize_scale_only_hqq` """ - HQQ = "hqq" + HQQ_SCALE_ONLY = "hqq_scale_only" diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py index 248be4aa03..0da7c9f65e 
100644 --- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py @@ -214,7 +214,7 @@ def from_hp( quant_min=qmin, quant_max=qmax, ) - elif intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.HQQ: + elif intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.HQQ_SCALE_ONLY: qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( hp_tensor, block_size, qmin, qmax )
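
For readers skimming PATCH 1 and PATCH 4: the new `_choose_qparams_and_quantize_scale_only_hqq` primitive alternates between quantizing each row-wise group with the current scale and refitting that scale by least squares (s = <W, Q> / <Q, Q>), then re-quantizes once with the final scale. Below is a condensed, self-contained sketch of that loop, assuming deterministic rounding only and dropping the stochastic-rounding and early-stop options; it is an illustration, not the torchao implementation itself.

```python
import torch


@torch.no_grad()
def hqq_scale_only_sketch(w: torch.Tensor, group_size: int,
                          qmin: int = -8, qmax: int = 7, iters: int = 20):
    # w: [out, in]; one positive scale per row-wise group of `group_size` columns
    n, k = w.shape
    eps = torch.finfo(torch.float32).eps
    wg = w.float().contiguous().view(n, k // group_size, group_size)

    # init: max-abs scale per group, as in the patch
    scale = (wg.abs().amax(dim=-1) / max(abs(qmin), abs(qmax))).clamp_min(eps)

    for _ in range(iters):
        # quantize with the current scale
        q = (wg / scale.unsqueeze(-1)).round().clamp(qmin, qmax)
        # least-squares refit per group: argmin_s ||wg - s * q||^2, projected to [eps, inf)
        num = (wg * q).sum(dim=-1)
        den = (q * q).sum(dim=-1)
        scale = torch.where(den > 0, num / den, scale).abs().clamp_min(eps)

    # quantize once more with the final scale (the fix made in PATCH 4)
    q = (wg / scale.unsqueeze(-1)).round().clamp(qmin, qmax)
    return q.view(n, k).to(torch.int32), scale.to(w.dtype)
```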
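
An end-to-end sketch of how the new knob is exercised, mirroring `test_hqq_intx_weight_only_config` from PATCH 6; import paths are assumed from torchao's public API, and the plain string `"hqq_scale_only"` works because `IntxChooseQParamsAlgorithm` is a str-backed enum.

```python
import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

# 4-bit weight-only quantization with per-group scales chosen by scale-only HQQ
config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    intx_choose_qparams_algorithm="hqq_scale_only",
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, config)  # replaces the weight with an IntxUnpackedToInt8Tensor
out = linear(torch.randn(1, 128, dtype=torch.bfloat16))
```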