diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
index 60148cbe81..3ba31bf86a 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
@@ -23,11 +23,11 @@
     MappingType,
     quantize_,
 )
-from torchao.quantization.granularity import PerGroup
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quantize_.common import PackingFormat
 from torchao.quantization.utils import compute_error
-from torchao.utils import torch_version_at_least
+from torchao.utils import torch_version_at_least, unwrap_tensor_subclass
 
 
 @unittest.skipIf(not torch_version_at_least("2.7.0"), "Need pytorch 2.7+")
@@ -156,7 +156,7 @@ def test_export_int8_dyn_act_intx_weight_config(self):
             model,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=torch.int4,
-                weight_granularity=PerGroup(64),
+                weight_granularity=PerAxis(0),
                 weight_mapping_type=MappingType.SYMMETRIC,
                 packing_format=PackingFormat.UNPACKED_TO_INT8,
                 version=2,
@@ -169,17 +169,52 @@ def test_export_int8_dyn_act_intx_weight_config(self):
         exported_results = exported.module()(activations)
         self.assertTrue(torch.allclose(eager_results, exported_results))
 
-        expected_lines = [
-            "torch.ops.torchao.choose_qparams_affine.default",
-            "torch.ops.torchao.quantize_affine.default",
-            "torch.ops.torchao.dequantize_affine.default",
-            "torch.ops.torchao.dequantize_affine.default",
-            "torch.ops.aten.linear.default",
+        expected_counts = {
+            "torch.ops.torchao.choose_qparams_affine.default": 1,
+            "torch.ops.torchao.quantize_affine.default": 1,
+            "torch.ops.torchao.dequantize_affine.default": 2,
+            "torch.ops.aten.linear.default": 1,
+            "torch.ops.aten.reshape.default": 0,
+        }
+        for line, count in expected_counts.items():
+            FileCheck().check_count(line, count, exactly=True).run(
+                exported.graph_module.code
+            )
+
+    def test_export_int8_dyn_act_intx_weight_config_with_unwrap(self):
+        layers = [
+            torch.nn.Linear(512, 256, bias=False),
         ]
-        for line in expected_lines:
-            count = 1
-            if line == "torch.ops.torchao.dequantize_affine.default":
-                count = 2
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(1, 512, dtype=torch.float32)
+
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(64),
+                weight_mapping_type=MappingType.SYMMETRIC,
+                packing_format=PackingFormat.UNPACKED_TO_INT8,
+                version=2,
+            ),
+        )
+        eager_results = model(activations)
+
+        unwrap_tensor_subclass(model)
+
+        exported = torch.export.export(model, (activations,))
+
+        exported_results = exported.module()(activations)
+        self.assertTrue(torch.allclose(eager_results, exported_results))
+
+        expected_counts = {
+            "torch.ops.torchao.choose_qparams_affine.default": 1,
+            "torch.ops.torchao.quantize_affine.default": 1,
+            "torch.ops.torchao.dequantize_affine.default": 2,
+            "torch.ops.aten.linear.default": 1,
+            "torch.ops.aten.reshape.default": 0,
+        }
+        for line, count in expected_counts.items():
             FileCheck().check_count(line, count, exactly=True).run(
                 exported.graph_module.code
             )
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 682d07a2b1..453b04e4e0 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -833,7 +833,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
         block_size,
         weight_dtype,
         mapping_type=weight_mapping_type,
-        apply_int8_act_asym_per_token_quant=True,
+        activation_quantization="int8_asym_per_token",
     )
     if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
         _adjust_scale_dtype_in_intx_unpacked_tensor(
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
index a0808eaf19..6f17a66d2f 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
@@ -15,6 +15,7 @@
 from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH
 from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
     IntxUnpackedToInt8Tensor,
+    IntxUnpackedToInt8TensorActivationQuantization,
 )
 from torchao.utils import (
     TorchAOBaseTensor,
@@ -144,7 +145,10 @@ def from_intx_unpacked_to_int8_tensor(
         compute_target = ComputeTarget[compute_target.upper()]
 
         # Extract data from IntxUnpackedToInt8Tensor
-        assert tensor.apply_int8_act_asym_per_token_quant
+        assert (
+            tensor.activation_quantization
+            == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
+        )
         qdata, scale, zero_point = tensor.qdata, tensor.scale, tensor.zero_point
         bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype]
         dtype = tensor.dtype
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index 400e842967..e9d79fc670 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import List, Tuple
+import enum
+from typing import List, Optional, Tuple
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -32,6 +33,14 @@
 _FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]
 
 
+class IntxUnpackedToInt8TensorActivationQuantization(str, enum.Enum):
+    """
+    This applies int8 asymmetric activation quantization per token.
+    """
+
+    INT8_ASYM_PER_TOKEN = "int8_asym_per_token"
+
+
 class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
     """
     intx quantization with unpacked format. Subbyte quantized data is represented as int8.
@@ -55,7 +64,7 @@ class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
         target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
         block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
         dtype: the dtype of the dequantized Tensor
-        apply_int8_act_asym_per_token_quant: bool, whether to apply activation quantization to the dequantized Tensor during linear. Use False for weight-only quantization
+        activation_quantization: Optional[IntxUnpackedToInt8TensorActivationQuantization] = None, kind of activation quantization to apply. Default is None, which means weight-only quantization
     """
 
     tensor_data_names = ["qdata", "scale", "zero_point"]
@@ -63,7 +72,7 @@
         "target_dtype",
         "block_size",
         "dtype",
-        "apply_int8_act_asym_per_token_quant",
+        "activation_quantization",
     ]
 
     def __new__(
@@ -74,7 +83,7 @@ def __new__(
        target_dtype,
        block_size,
        dtype,
-       apply_int8_act_asym_per_token_quant,
+       activation_quantization,
     ):
         kwargs = {}
         kwargs["device"] = qdata.device
@@ -91,7 +100,7 @@ def __init__(
         target_dtype,
         block_size,
         dtype,
-        apply_int8_act_asym_per_token_quant,
+        activation_quantization,
     ):
         super().__init__()
         assert qdata.dtype == torch.int8, (
@@ -113,8 +122,14 @@ def __init__(
         for i in range(len(block_size)):
             assert qdata.shape[i] % block_size[i] == 0
             n_blocks.append(qdata.shape[i] // block_size[i])
-        scale = scale.reshape(*n_blocks)
-        zero_point = zero_point.reshape(*n_blocks)
+
+        # Assert shapes
+        assert scale.shape == tuple(n_blocks), (
+            f"Expected scale to have shape {n_blocks} (inferred from block_size={block_size}), but got {scale.shape}"
+        )
+        assert zero_point.shape == tuple(n_blocks), (
+            f"Expected zero_point to have shape {n_blocks} (inferred from block_size={block_size}), but got {zero_point.shape}"
+        )
 
         assert dtype in _FLOAT_TYPES, (
             f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}"
@@ -126,10 +141,10 @@
 
         self.target_dtype = target_dtype
         self.block_size = block_size
-        self.apply_int8_act_asym_per_token_quant = apply_int8_act_asym_per_token_quant
+        self.activation_quantization = activation_quantization
 
     def _quantization_type(self):
-        return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}, apply_int8_act_asym_per_token_quant={self.apply_int8_act_asym_per_token_quant}"
+        return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}, activation_quantization={self.activation_quantization}"
 
     def _has_float_zero_point(self) -> bool:
         return self.zero_point.dtype in _FLOAT_TYPES
@@ -148,7 +163,7 @@ def to(self, *args, **kwargs):
             self.target_dtype,
             self.block_size,
             dtype,
-            self.apply_int8_act_asym_per_token_quant,
+            self.activation_quantization,
         )
 
     @classmethod
@@ -159,7 +174,9 @@ def from_hp(
         target_dtype: torch.dtype,
         *,
         mapping_type: MappingType = MappingType.SYMMETRIC,
-        apply_int8_act_asym_per_token_quant: bool = False,
+        activation_quantization: Optional[
+            IntxUnpackedToInt8TensorActivationQuantization
+        ] = None,
     ):
         """
         Create an IntxUnpackedToInt8Tensor from a high-precision tensor
@@ -183,6 +200,16 @@ def from_hp(
            quant_min=qmin,
            quant_max=qmax,
        )
+
+        # Reshape scale and zero_point to be compatible with block_size
+        # This is asserted in IntxUnpackedToInt8Tensor's __init__
+        n_blocks = []
+        for i in range(len(block_size)):
+            assert qdata.shape[i] % block_size[i] == 0
+            n_blocks.append(qdata.shape[i] // block_size[i])
+        scale = scale.reshape(*n_blocks)
+        zero_point = zero_point.reshape(*n_blocks)
+
        return IntxUnpackedToInt8Tensor(
            qdata=qdata,
            scale=scale,
@@ -190,7 +217,7 @@ def from_hp(
             target_dtype=target_dtype,
             block_size=block_size,
             dtype=hp_tensor.dtype,
-            apply_int8_act_asym_per_token_quant=apply_int8_act_asym_per_token_quant,
+            activation_quantization=activation_quantization,
         )
 
     def dequantize(self):
@@ -207,6 +234,42 @@ def dequantize(self):
         )
 
 
+def _apply_int8_act_asym_per_token_quant_dequant(hp_tensor):
+    target_dtype = torch.int8
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = _get_per_token_block_size(hp_tensor)
+    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
+    scale, zero_point = choose_qparams_affine(
+        hp_tensor,
+        mapping_type,
+        block_size,
+        target_dtype=target_dtype,
+        quant_min=qmin,
+        quant_max=qmax,
+        zero_point_dtype=torch.int8,
+    )
+    qdata = quantize_affine(
+        hp_tensor,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype=torch.int8,
+        quant_min=qmin,
+        quant_max=qmax,
+    )
+    dequantized_affine = dequantize_affine(
+        qdata,
+        block_size,
+        scale,
+        zero_point,
+        torch.int8,
+        qmin,
+        qmax,
+        output_dtype=hp_tensor.dtype,
+    )
+    return dequantized_affine
+
+
 implements = IntxUnpackedToInt8Tensor.implements
 
 
@@ -220,13 +283,16 @@ def _(func, types, args, kwargs):
     assert isinstance(weight_tensor, IntxUnpackedToInt8Tensor)
 
     # Apply dynamic activation quant
-    if weight_tensor.apply_int8_act_asym_per_token_quant:
-        input_tensor = IntxUnpackedToInt8Tensor.from_hp(
-            hp_tensor=input_tensor,
-            block_size=_get_per_token_block_size(input_tensor),
-            target_dtype=torch.int8,
-            mapping_type=MappingType.ASYMMETRIC,
-        ).dequantize()
+    if weight_tensor.activation_quantization is not None:
+        if (
+            weight_tensor.activation_quantization
+            == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
+        ):
+            input_tensor = _apply_int8_act_asym_per_token_quant_dequant(input_tensor)
+        else:
+            raise NotImplementedError(
+                f"Unsupported activation quantization: {weight_tensor.activation_quantization}"
+            )
 
     weight_tensor = weight_tensor.dequantize()
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
@@ -293,7 +359,7 @@ def _(func, types, args, kwargs):
         self.target_dtype,
         new_block_size,
         self.dtype,
-        self.apply_int8_act_asym_per_token_quant,
+        self.activation_quantization,
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
 
@@ -301,4 +367,6 @@ def _(func, types, args, kwargs):
 IntxUnpackedToInt8Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with IntxUnpackedToInt8Tensor weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([IntxUnpackedToInt8Tensor])
+torch.serialization.add_safe_globals(
+    [IntxUnpackedToInt8Tensor, IntxUnpackedToInt8TensorActivationQuantization]
+)
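
Note (editor addition, not part of the patch): a minimal usage sketch of the new activation_quantization argument. The class and enum names, import path, and from_hp keyword names are taken from the diff above; the 256x512 weight, the (1, 64) group size, and the direct torch.nn.functional.linear call are illustrative assumptions.

# Sketch only: quantize a weight to int4 with per-group scales and request
# dynamic int8 asymmetric per-token activation quantization via the new enum
# instead of the old apply_int8_act_asym_per_token_quant boolean flag.
import torch

from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
    IntxUnpackedToInt8Tensor,
    IntxUnpackedToInt8TensorActivationQuantization,
)

weight = torch.randn(256, 512, dtype=torch.float32)
qweight = IntxUnpackedToInt8Tensor.from_hp(
    weight,
    block_size=(1, 64),  # groupwise: one (scale, zero_point) per 64 input features
    target_dtype=torch.int4,
    activation_quantization=(
        IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
    ),
)

# Per the linear override in the diff, the activation is quantized and
# dequantized per token before the weight is dequantized, so this exercises
# the dynamic-activation path; omitting activation_quantization (None) would
# give weight-only quantization instead.
out = torch.nn.functional.linear(torch.randn(1, 512), qweight)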