From 8d9e77a978c9174fe0d852ce0fe02b3ec26556d0 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Thu, 4 Sep 2025 13:00:22 -0700
Subject: [PATCH 1/6] Fix int8_unpacked for xnnpack export

---
 .../intx/test_intx_unpacked_to_int8_tensor.py | 19 ++++----
 .../intx/intx_unpacked_to_int8_tensor.py      | 44 ++++++++++++++++---
 2 files changed, 46 insertions(+), 17 deletions(-)

diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
index 60148cbe81..ee85937cf6 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
@@ -169,17 +169,14 @@ def test_export_int8_dyn_act_intx_weight_config(self):
         exported_results = exported.module()(activations)
         self.assertTrue(torch.allclose(eager_results, exported_results))

-        expected_lines = [
-            "torch.ops.torchao.choose_qparams_affine.default",
-            "torch.ops.torchao.quantize_affine.default",
-            "torch.ops.torchao.dequantize_affine.default",
-            "torch.ops.torchao.dequantize_affine.default",
-            "torch.ops.aten.linear.default",
-        ]
-        for line in expected_lines:
-            count = 1
-            if line == "torch.ops.torchao.dequantize_affine.default":
-                count = 2
+        expected_counts = {
+            "torch.ops.torchao.choose_qparams_affine.default": 1,
+            "torch.ops.torchao.quantize_affine.default": 1,
+            "torch.ops.torchao.dequantize_affine.default": 2,
+            "torch.ops.aten.linear.default": 1,
+            "torch.ops.aten.reshape.default": 0,
+        }
+        for line, count in expected_counts.items():
             FileCheck().check_count(line, count, exactly=True).run(
                 exported.graph_module.code
             )
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index 400e842967..b9b4b282af 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -31,6 +31,43 @@

 _FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]

+def _fake_apply_int8_act_asym_per_token_quant(hp_tensor):
+    target_dtype = torch.int8
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = _get_per_token_block_size(hp_tensor)
+    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
+    scale, zero_point = choose_qparams_affine(
+        hp_tensor,
+        mapping_type,
+        block_size,
+        target_dtype=target_dtype,
+        quant_min=qmin,
+        quant_max=qmax,
+        zero_point_dtype=torch.int8,
+    )
+    qdata = quantize_affine(
+        hp_tensor,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype=torch.int8,
+        quant_min=qmin,
+        quant_max=qmax,
+    )
+    dequantized_affine = dequantize_affine(
+        qdata,
+        block_size,
+        scale,
+        zero_point,
+        torch.int8,
+        qmin,
+        qmax,
+        output_dtype=hp_tensor.dtype,
+    )
+    return dequantized_affine
+
+
+

 class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
     """
@@ -221,12 +258,7 @@ def _(func, types, args, kwargs):

     # Apply dynamic activation quant
     if weight_tensor.apply_int8_act_asym_per_token_quant:
-        input_tensor = IntxUnpackedToInt8Tensor.from_hp(
-            hp_tensor=input_tensor,
-            block_size=_get_per_token_block_size(input_tensor),
-            target_dtype=torch.int8,
-            mapping_type=MappingType.ASYMMETRIC,
-        ).dequantize()
+        input_tensor = _fake_apply_int8_act_asym_per_token_quant(input_tensor)
     weight_tensor = weight_tensor.dequantize()

     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

From 75134f33c7fbd24a2df4048c107d594e76bb2ee3 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Thu, 4 Sep 2025 13:45:31 -0700
Subject: [PATCH 2/6] up

---
 .../intx/test_intx_unpacked_to_int8_tensor.py | 40 ++++++++++++++++++-
 .../intx/intx_unpacked_to_int8_tensor.py      |  9 +++--
 2 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
index ee85937cf6..a8a76a709f 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
@@ -27,7 +27,7 @@
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quantize_.common import PackingFormat
 from torchao.quantization.utils import compute_error
-from torchao.utils import torch_version_at_least
+from torchao.utils import torch_version_at_least, unwrap_tensor_subclass


 @unittest.skipIf(not torch_version_at_least("2.7.0"), "Need pytorch 2.7+")
@@ -181,6 +181,44 @@ def test_export_int8_dyn_act_intx_weight_config(self):
                 exported.graph_module.code
             )

+    def test_export_int8_dyn_act_intx_weight_config_with_unwrap(self):
+        layers = [
+            torch.nn.Linear(512, 256, bias=False),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(1, 512, dtype=torch.float32)
+
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(64),
+                weight_mapping_type=MappingType.SYMMETRIC,
+                packing_format=PackingFormat.UNPACKED_TO_INT8,
+                version=2,
+            ),
+        )
+        eager_results = model(activations)
+
+        unwrap_tensor_subclass(model)
+
+        exported = torch.export.export(model, (activations,))
+
+        exported_results = exported.module()(activations)
+        self.assertTrue(torch.allclose(eager_results, exported_results))
+
+        expected_counts = {
+            "torch.ops.torchao.choose_qparams_affine.default": 1,
+            "torch.ops.torchao.quantize_affine.default": 1,
+            "torch.ops.torchao.dequantize_affine.default": 2,
+            "torch.ops.aten.linear.default": 1,
+            "torch.ops.aten.reshape.default": 0,
+        }
+        for line, count in expected_counts.items():
+            FileCheck().check_count(line, count, exactly=True).run(
+                exported.graph_module.code
+            )
+
     def test_serialization_int8_dyn_act_intx_weight_config(self):
         layers = [
             torch.nn.Linear(512, 256),
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index b9b4b282af..68488f20d9 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -31,6 +31,7 @@
 _FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]

+
 def _fake_apply_int8_act_asym_per_token_quant(hp_tensor):
     target_dtype = torch.int8
@@ -67,8 +68,6 @@ def _fake_apply_int8_act_asym_per_token_quant(hp_tensor):
     return dequantized_affine


-
-
 class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
     """
     intx quantization with unpacked format. Subbyte quantized data is represented as int8.
@@ -150,8 +149,10 @@ def __init__(
         for i in range(len(block_size)):
             assert qdata.shape[i] % block_size[i] == 0
             n_blocks.append(qdata.shape[i] // block_size[i])
-        scale = scale.reshape(*n_blocks)
-        zero_point = zero_point.reshape(*n_blocks)
+
+        # Assert shapes
+        assert scale.shape == tuple(n_blocks)
+        assert zero_point.shape == tuple(n_blocks)

         assert dtype in _FLOAT_TYPES, (
             f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}"

From f6b9888c4c1c19d603317fceec4773a0ca9b07fc Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Thu, 4 Sep 2025 14:17:42 -0700
Subject: [PATCH 3/6] up

---
 .../intx/intx_unpacked_to_int8_tensor.py | 80 ++++++++++---------
 1 file changed, 42 insertions(+), 38 deletions(-)

diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index 68488f20d9..a65988c01a 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -32,42 +32,6 @@
 _FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]


-def _fake_apply_int8_act_asym_per_token_quant(hp_tensor):
-    target_dtype = torch.int8
-    mapping_type = MappingType.ASYMMETRIC
-    block_size = _get_per_token_block_size(hp_tensor)
-    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
-    scale, zero_point = choose_qparams_affine(
-        hp_tensor,
-        mapping_type,
-        block_size,
-        target_dtype=target_dtype,
-        quant_min=qmin,
-        quant_max=qmax,
-        zero_point_dtype=torch.int8,
-    )
-    qdata = quantize_affine(
-        hp_tensor,
-        block_size,
-        scale,
-        zero_point,
-        output_dtype=torch.int8,
-        quant_min=qmin,
-        quant_max=qmax,
-    )
-    dequantized_affine = dequantize_affine(
-        qdata,
-        block_size,
-        scale,
-        zero_point,
-        torch.int8,
-        qmin,
-        qmax,
-        output_dtype=hp_tensor.dtype,
-    )
-    return dequantized_affine
-
-
 class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
     """
     intx quantization with unpacked format. Subbyte quantized data is represented as int8.
@@ -151,8 +115,12 @@ def __init__(
             n_blocks.append(qdata.shape[i] // block_size[i])

         # Assert shapes
-        assert scale.shape == tuple(n_blocks)
-        assert zero_point.shape == tuple(n_blocks)
+        assert scale.shape == tuple(n_blocks), (
+            f"Expected scale to have shape {n_blocks} (inferred from block_size={block_size}), but got {scale.shape}"
+        )
+        assert zero_point.shape == tuple(n_blocks), (
+            f"Expected zero_point to have shape {n_blocks} (inferred from block_size={block_size}), but got {zero_point.shape}"
+        )

         assert dtype in _FLOAT_TYPES, (
             f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}"
@@ -245,6 +213,42 @@ def dequantize(self):
         )


+def _fake_apply_int8_act_asym_per_token_quant(hp_tensor):
+    target_dtype = torch.int8
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = _get_per_token_block_size(hp_tensor)
+    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
+    scale, zero_point = choose_qparams_affine(
+        hp_tensor,
+        mapping_type,
+        block_size,
+        target_dtype=target_dtype,
+        quant_min=qmin,
+        quant_max=qmax,
+        zero_point_dtype=torch.int8,
+    )
+    qdata = quantize_affine(
+        hp_tensor,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype=torch.int8,
+        quant_min=qmin,
+        quant_max=qmax,
+    )
+    dequantized_affine = dequantize_affine(
+        qdata,
+        block_size,
+        scale,
+        zero_point,
+        torch.int8,
+        qmin,
+        qmax,
+        output_dtype=hp_tensor.dtype,
+    )
+    return dequantized_affine
+
+
 implements = IntxUnpackedToInt8Tensor.implements

From e3de2203951976fa5c3d194b0806fda1a9028233 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Thu, 4 Sep 2025 15:43:48 -0700
Subject: [PATCH 4/6] up

---
 torchao/quantization/quant_api.py             |  2 +-
 .../workflows/intx/intx_opaque_tensor.py      |  5 +-
 .../intx/intx_unpacked_to_int8_tensor.py      | 55 +++++++++++++------
 3 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 682d07a2b1..453b04e4e0 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -833,7 +833,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
         block_size,
         weight_dtype,
         mapping_type=weight_mapping_type,
-        apply_int8_act_asym_per_token_quant=True,
+        activation_quantization="int8_asym_per_token",
     )
     if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
         _adjust_scale_dtype_in_intx_unpacked_tensor(
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
index a0808eaf19..a34ed3391a 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
@@ -15,6 +15,7 @@
 from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH
 from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
     IntxUnpackedToInt8Tensor,
+    ActivationQuantization,
 )
 from torchao.utils import (
     TorchAOBaseTensor,
@@ -144,7 +145,9 @@ def from_intx_unpacked_to_int8_tensor(
         compute_target = ComputeTarget[compute_target.upper()]

         # Extract data from IntxUnpackedToInt8Tensor
-        assert tensor.apply_int8_act_asym_per_token_quant
+        assert (
+            tensor.activation_quantization == ActivationQuantization.INT8_ASYM_PER_TOKEN
+        )
         qdata, scale, zero_point = tensor.qdata, tensor.scale, tensor.zero_point
         bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype]
         dtype = tensor.dtype
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index a65988c01a..1756a8d73a 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -5,8 +5,8 @@
 # LICENSE file in the root directory of this source tree.


-from typing import List, Tuple
-
+from typing import List, Tuple, Optional, Union
+import enum
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing

@@ -32,6 +32,14 @@
 _FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]


+class ActivationQuantization(enum.Enum):
+    """
+    This applies int8 asymmetric activation quantization per token.
+    """
+
+    INT8_ASYM_PER_TOKEN = "int8_asym_per_token"
+
+
 class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
     """
     intx quantization with unpacked format. Subbyte quantized data is represented as int8.
@@ -55,7 +63,7 @@ class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
         target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
         block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
         dtype: the dtype of the dequantized Tensor
-        apply_int8_act_asym_per_token_quant: bool, whether to apply activation quantization to the dequantized Tensor during linear. Use False for weight-only quantization
+        activation_quantization: Optional[ActivationQuantization] = None, kind of activation quantization to apply. Default is None, which means weight-only quantization
     """

     tensor_data_names = ["qdata", "scale", "zero_point"]
@@ -63,7 +71,7 @@
         "target_dtype",
         "block_size",
         "dtype",
-        "apply_int8_act_asym_per_token_quant",
+        "activation_quantization",
     ]

     def __new__(
@@ -74,7 +82,7 @@ def __new__(
         target_dtype,
         block_size,
         dtype,
-        apply_int8_act_asym_per_token_quant,
+        activation_quantization,
     ):
         kwargs = {}
         kwargs["device"] = qdata.device
@@ -91,7 +99,7 @@ def __init__(
         target_dtype,
         block_size,
         dtype,
-        apply_int8_act_asym_per_token_quant,
+        activation_quantization,
     ):
         super().__init__()
         assert qdata.dtype == torch.int8, (
@@ -132,10 +140,10 @@ def __init__(

         self.target_dtype = target_dtype
         self.block_size = block_size
-        self.apply_int8_act_asym_per_token_quant = apply_int8_act_asym_per_token_quant
+        self.activation_quantization = activation_quantization

     def _quantization_type(self):
-        return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}, apply_int8_act_asym_per_token_quant={self.apply_int8_act_asym_per_token_quant}"
+        return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}, activation_quantization={self.activation_quantization}"

     def _has_float_zero_point(self) -> bool:
         return self.zero_point.dtype in _FLOAT_TYPES
@@ -154,7 +162,7 @@ def to(self, *args, **kwargs):
             self.target_dtype,
             self.block_size,
             dtype,
-            self.apply_int8_act_asym_per_token_quant,
+            self.activation_quantization,
         )

     @classmethod
@@ -165,11 +173,18 @@ def from_hp(
         target_dtype: torch.dtype,
         *,
         mapping_type: MappingType = MappingType.SYMMETRIC,
-        apply_int8_act_asym_per_token_quant: bool = False,
+        activation_quantization: Optional[Union[ActivationQuantization, str]] = None,
     ):
         """
         Create an IntxUnpackedToInt8Tensor from a high-precision tensor
         """
+        if activation_quantization is not None and isinstance(
+            activation_quantization, str
+        ):
+            activation_quantization = ActivationQuantization[
+                activation_quantization.upper()
+            ]
+
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
         scale, zero_point = choose_qparams_affine(
             hp_tensor,
@@ -196,7 +211,7 @@ def from_hp(
             target_dtype=target_dtype,
             block_size=block_size,
             dtype=hp_tensor.dtype,
-            apply_int8_act_asym_per_token_quant=apply_int8_act_asym_per_token_quant,
+            activation_quantization=activation_quantization,
         )

     def dequantize(self):
@@ -213,7 +228,7 @@ def dequantize(self):
         )


-def _fake_apply_int8_act_asym_per_token_quant(hp_tensor):
+def _apply_int8_act_asym_per_token_quant_dequant(hp_tensor):
     target_dtype = torch.int8
     mapping_type = MappingType.ASYMMETRIC
     block_size = _get_per_token_block_size(hp_tensor)
@@ -262,8 +277,16 @@ def _(func, types, args, kwargs):
     assert isinstance(weight_tensor, IntxUnpackedToInt8Tensor)

     # Apply dynamic activation quant
-    if weight_tensor.apply_int8_act_asym_per_token_quant:
-        input_tensor = _fake_apply_int8_act_asym_per_token_quant(input_tensor)
+    if weight_tensor.activation_quantization is not None:
+        if (
+            weight_tensor.activation_quantization
+            == ActivationQuantization.INT8_ASYM_PER_TOKEN
+        ):
+            input_tensor = _apply_int8_act_asym_per_token_quant_dequant(input_tensor)
+        else:
+            raise NotImplementedError(
+                f"Unsupported activation quantization: {weight_tensor.activation_quantization}"
+            )
     weight_tensor = weight_tensor.dequantize()

     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
@@ -330,7 +353,7 @@ def _(func, types, args, kwargs):
         self.target_dtype,
         new_block_size,
         self.dtype,
-        self.apply_int8_act_asym_per_token_quant,
+        self.activation_quantization,
     )

     return return_and_correct_aliasing(func, args, kwargs, new)
@@ -338,4 +361,4 @@ def _(func, types, args, kwargs):
 IntxUnpackedToInt8Tensor.__module__ = "torchao.quantization"

 # Allow a model with IntxUnpackedToInt8Tensor weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([IntxUnpackedToInt8Tensor])
+torch.serialization.add_safe_globals([IntxUnpackedToInt8Tensor, ActivationQuantization])

From 5dd71f99d57b81580b9bed6e058f9268fdabd4f5 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Thu, 4 Sep 2025 16:38:13 -0700
Subject: [PATCH 5/6] up

---
 .../intx/test_intx_unpacked_to_int8_tensor.py |  4 +--
 .../workflows/intx/intx_opaque_tensor.py      |  2 +-
 .../intx/intx_unpacked_to_int8_tensor.py      | 34 ++++++++++++-------
 3 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
index a8a76a709f..3ba31bf86a 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
@@ -23,7 +23,7 @@
     MappingType,
     quantize_,
 )
-from torchao.quantization.granularity import PerGroup
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quantize_.common import PackingFormat
 from torchao.quantization.utils import compute_error
@@ -156,7 +156,7 @@ def test_export_int8_dyn_act_intx_weight_config(self):
             model,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=torch.int4,
-                weight_granularity=PerGroup(64),
+                weight_granularity=PerAxis(0),
                 weight_mapping_type=MappingType.SYMMETRIC,
                 packing_format=PackingFormat.UNPACKED_TO_INT8,
                 version=2,
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
index a34ed3391a..20fb522cbb 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
@@ -14,8 +14,8 @@
 from torchao.experimental.op_lib_utils import _check_torchao_ops_loaded
 from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH
 from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
-    IntxUnpackedToInt8Tensor,
     ActivationQuantization,
+    IntxUnpackedToInt8Tensor,
 )
 from torchao.utils import (
     TorchAOBaseTensor,
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index 1756a8d73a..e9d79fc670 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.


-from typing import List, Tuple, Optional, Union
 import enum
+from typing import List, Optional, Tuple
+
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing

@@ -32,7 +33,7 @@
 _FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]


-class ActivationQuantization(enum.Enum):
+class IntxUnpackedToInt8TensorActivationQuantization(str, enum.Enum):
     """
     This applies int8 asymmetric activation quantization per token.
     """
@@ -63,7 +64,7 @@ class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
         target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
         block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
         dtype: the dtype of the dequantized Tensor
-        activation_quantization: Optional[ActivationQuantization] = None, kind of activation quantization to apply. Default is None, which means weight-only quantization
+        activation_quantization: Optional[IntxUnpackedToInt8TensorActivationQuantization] = None, kind of activation quantization to apply. Default is None, which means weight-only quantization
     """

     tensor_data_names = ["qdata", "scale", "zero_point"]
@@ -173,18 +174,13 @@ def from_hp(
         target_dtype: torch.dtype,
         *,
         mapping_type: MappingType = MappingType.SYMMETRIC,
-        activation_quantization: Optional[Union[ActivationQuantization, str]] = None,
+        activation_quantization: Optional[
+            IntxUnpackedToInt8TensorActivationQuantization
+        ] = None,
     ):
         """
         Create an IntxUnpackedToInt8Tensor from a high-precision tensor
         """
-        if activation_quantization is not None and isinstance(
-            activation_quantization, str
-        ):
-            activation_quantization = ActivationQuantization[
-                activation_quantization.upper()
-            ]
-
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
         scale, zero_point = choose_qparams_affine(
             hp_tensor,
@@ -204,6 +200,16 @@ def from_hp(
             quant_min=qmin,
             quant_max=qmax,
         )
+
+        # Reshape scale and zero_point to be compatible with block_size
+        # This is asserted in IntxUnpackedToInt8Tensor's __init__
+        n_blocks = []
+        for i in range(len(block_size)):
+            assert qdata.shape[i] % block_size[i] == 0
+            n_blocks.append(qdata.shape[i] // block_size[i])
+        scale = scale.reshape(*n_blocks)
+        zero_point = zero_point.reshape(*n_blocks)
+
         return IntxUnpackedToInt8Tensor(
             qdata=qdata,
             scale=scale,
@@ -280,7 +286,7 @@ def _(func, types, args, kwargs):
     if weight_tensor.activation_quantization is not None:
         if (
             weight_tensor.activation_quantization
-            == ActivationQuantization.INT8_ASYM_PER_TOKEN
+            == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
         ):
             input_tensor = _apply_int8_act_asym_per_token_quant_dequant(input_tensor)
         else:
@@ -361,4 +367,6 @@ def _(func, types, args, kwargs):
 IntxUnpackedToInt8Tensor.__module__ = "torchao.quantization"

 # Allow a model with IntxUnpackedToInt8Tensor weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([IntxUnpackedToInt8Tensor, ActivationQuantization])
+torch.serialization.add_safe_globals(
+    [IntxUnpackedToInt8Tensor, IntxUnpackedToInt8TensorActivationQuantization]
+)

From aa67cb37cab71613182a88b7da63bcb3ae05e7c8 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Thu, 4 Sep 2025 16:41:17 -0700
Subject: [PATCH 6/6] up

---
 .../quantize_/workflows/intx/intx_opaque_tensor.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
index 20fb522cbb..6f17a66d2f 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
@@ -14,8 +14,8 @@
 from torchao.experimental.op_lib_utils import _check_torchao_ops_loaded
 from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH
 from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
-    ActivationQuantization,
     IntxUnpackedToInt8Tensor,
+    IntxUnpackedToInt8TensorActivationQuantization,
 )
 from torchao.utils import (
     TorchAOBaseTensor,
@@ -146,7 +146,8 @@ def from_intx_unpacked_to_int8_tensor(

         # Extract data from IntxUnpackedToInt8Tensor
         assert (
-            tensor.activation_quantization == ActivationQuantization.INT8_ASYM_PER_TOKEN
+            tensor.activation_quantization
+            == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
         )
         qdata, scale, zero_point = tensor.qdata, tensor.scale, tensor.zero_point
         bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype]
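
Usage sketch (not part of the patch series above): a minimal illustration of how the reworked from_hp API reads after PATCH 5/6, assuming the module path and signature shown in the diffs; the weight shape, the (1, 64) group size, and the torch.int4 target dtype are placeholder choices for the example, and torch.int4 assumes a PyTorch build that exposes that dtype.

    import torch
    from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
        IntxUnpackedToInt8Tensor,
        IntxUnpackedToInt8TensorActivationQuantization,
    )

    # Placeholder high-precision weight for a 512 -> 256 linear layer.
    weight = torch.randn(256, 512, dtype=torch.float32)

    # Weight-only quantization: activation_quantization defaults to None.
    weight_only = IntxUnpackedToInt8Tensor.from_hp(
        weight,
        block_size=(1, 64),
        target_dtype=torch.int4,
    )

    # Dynamic activation quantization: pass the enum member directly; the string
    # coercion accepted by the PATCH 4/6 revision was removed in PATCH 5/6.
    dyn_act_weight = IntxUnpackedToInt8Tensor.from_hp(
        weight,
        block_size=(1, 64),
        target_dtype=torch.int4,
        activation_quantization=(
            IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
        ),
    )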