From bb02a9e1a2bcb62be85c7115e0f210475bb15b66 Mon Sep 17 00:00:00 2001
From: "Xia, Weiwen"
Date: Mon, 24 Nov 2025 16:58:24 +0000
Subject: [PATCH] Move Int4OpaqueTensor to prototype

---
 test/prototype/test_awq.py                    |  3 +-
 .../test_int4_opaque_tensor.py                | 13 ++-
 torchao/prototype/awq/example.py              |  9 +-
 .../inference_workflow.py                     |  2 +-
 .../prototype/int4_opaque_tensor/__init__.py  |  7 ++
 .../int4_opaque_tensor/inference_workflow.py  | 88 +++++++++++++++++++
 .../int4_opaque_tensor}/int4_opaque_tensor.py |  7 +-
 torchao/quantization/__init__.py              |  2 -
 torchao/quantization/quant_api.py             | 15 +---
 .../quantize_/workflows/__init__.py           |  4 -
 .../workflows/int4/int4_packing_format.py     |  7 --
 11 files changed, 112 insertions(+), 45 deletions(-)
 rename test/{quantization/quantize_/workflows/int4 => prototype}/test_int4_opaque_tensor.py (91%)
 create mode 100644 torchao/prototype/int4_opaque_tensor/__init__.py
 create mode 100644 torchao/prototype/int4_opaque_tensor/inference_workflow.py
 rename torchao/{quantization/quantize_/workflows/int4 => prototype/int4_opaque_tensor}/int4_opaque_tensor.py (98%)

diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py
index 70bca35f90..a5dc944ae8 100644
--- a/test/prototype/test_awq.py
+++ b/test/prototype/test_awq.py
@@ -15,6 +15,7 @@
 )
 
 from torchao.prototype.awq import AWQConfig, AWQStep
+from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.utils import _is_fbgemm_gpu_genai_available, torch_version_at_least
 
@@ -76,7 +77,7 @@ def forward(self, x):
         # Note: the functionality unit test doesn't work for hqq
         Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d"),
     ],
-    "cpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")],
+    "cpu": [Int4WeightOnlyOpaqueTensorConfig(group_size=128)],
     "xpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32")],
 }
 
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py b/test/prototype/test_int4_opaque_tensor.py
similarity index 91%
rename from test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
rename to test/prototype/test_int4_opaque_tensor.py
index 456f834389..638ee001bc 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
+++ b/test/prototype/test_int4_opaque_tensor.py
@@ -15,10 +15,8 @@
     run_tests,
 )
 
-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    quantize_,
-)
+from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
+from torchao.quantization import quantize_
 from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
@@ -27,9 +25,8 @@
 
 
 def get_config(group_size, use_hqq):
-    return Int4WeightOnlyConfig(
+    return Int4WeightOnlyOpaqueTensorConfig(
         group_size=group_size,
-        int4_packing_format="opaque",
         int4_choose_qparams_algorithm="hqq" if use_hqq else "tinygemm",
     )
 
@@ -68,7 +65,7 @@ def test_module_path(self, dtype, use_hqq):
         quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4OpaqueTensor'>",
+            "<class 'torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor'>",
         )
 
         with tempfile.NamedTemporaryFile() as f:
@@ -77,7 +74,7 @@
             state_dict = torch.load(f)
             self.assertEqual(
                 str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Int4OpaqueTensor'>",
+                "<class 'torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor'>",
             )
 
     @parametrize("use_hqq", [True, False])
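A minimal usage sketch of the relocated config, distilled from the moved test above; the layer shape and dtype are illustrative choices, not part of the patch:

    import torch

    from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
    from torchao.quantization import quantize_

    # CPU-resident module; group_size must divide the last weight dim (256 here)
    linear = torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cpu")
    quantize_(linear, Int4WeightOnlyOpaqueTensorConfig(group_size=128))

As get_config shows, passing int4_choose_qparams_algorithm="hqq" selects HQQ instead of the default tinygemm algorithm.
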
diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py
index 2750c42b3a..da55ee96da 100644
--- a/torchao/prototype/awq/example.py
+++ b/torchao/prototype/awq/example.py
@@ -17,6 +17,7 @@
 from torchao.prototype.awq import (
     AWQConfig,
 )
+from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 
 
@@ -259,9 +260,7 @@ def quantize_and_eval(
                 group_size=group_size, int4_packing_format="plain_int32"
             )
         elif device == "cpu":
-            base_config = Int4WeightOnlyConfig(
-                group_size=group_size, int4_packing_format="opaque"
-            )
+            base_config = Int4WeightOnlyOpaqueTensorConfig(group_size=group_size)
         else:
             assert False, "Unsupported device: {}".format(device)
         print(f"running {quant} prepare and calibrate")
@@ -301,9 +300,7 @@ def quantize_and_eval(
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
         elif device == "cpu":
-            base_config = Int4WeightOnlyConfig(
-                group_size=group_size, int4_packing_format="opaque"
-            )
+            base_config = Int4WeightOnlyOpaqueTensorConfig(group_size=group_size)
         else:
             assert False, "Unsupported device: {}".format(device)
         quantize_(model, base_config)
diff --git a/torchao/prototype/float8_opaque_tensor/inference_workflow.py b/torchao/prototype/float8_opaque_tensor/inference_workflow.py
index 163c8066df..ceef4f5903 100644
--- a/torchao/prototype/float8_opaque_tensor/inference_workflow.py
+++ b/torchao/prototype/float8_opaque_tensor/inference_workflow.py
@@ -55,7 +55,7 @@ class Float8DynamicActivationFloat8WeightOpaqueTensorConfig(AOBaseConfig):
 
     def __post_init__(self):
         torch._C._log_api_usage_once(
-            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+            "torchao.prototype.float8_opaque_tensor.Float8DynamicActivationFloat8WeightOpaqueTensorConfig"
         )
         activation_granularity, weight_granularity = (
             Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
diff --git a/torchao/prototype/int4_opaque_tensor/__init__.py b/torchao/prototype/int4_opaque_tensor/__init__.py
new file mode 100644
index 0000000000..29352457cb
--- /dev/null
+++ b/torchao/prototype/int4_opaque_tensor/__init__.py
@@ -0,0 +1,7 @@
+from .inference_workflow import Int4WeightOnlyOpaqueTensorConfig
+from .int4_opaque_tensor import Int4OpaqueTensor
+
+__all__ = [
+    "Int4OpaqueTensor",
+    "Int4WeightOnlyOpaqueTensorConfig",
+]
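The package root re-exports both the tensor subclass and its config, so downstream code imports them from one place, mirroring the float8_opaque_tensor layout touched above:

    from torchao.prototype.int4_opaque_tensor import (
        Int4OpaqueTensor,
        Int4WeightOnlyOpaqueTensorConfig,
    )
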
diff --git a/torchao/prototype/int4_opaque_tensor/inference_workflow.py b/torchao/prototype/int4_opaque_tensor/inference_workflow.py
new file mode 100644
index 0000000000..d1a92d2b2f
--- /dev/null
+++ b/torchao/prototype/int4_opaque_tensor/inference_workflow.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass
+
+import torch
+
+import torchao
+from torchao.core.config import AOBaseConfig
+
+logger = logging.getLogger(__name__)
+import types
+
+from torchao.quantization.quant_api import _linear_extra_repr
+from torchao.quantization.quantize_.workflows import (
+    Int4ChooseQParamsAlgorithm,
+)
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+
+from .int4_opaque_tensor import Int4OpaqueTensor
+
+
+@dataclass
+class Int4WeightOnlyOpaqueTensorConfig(AOBaseConfig):
+    """
+    Configuration for int4 weight-only quantization; only groupwise quantization is supported right now.
+
+    Args:
+        `group_size`: granularity of quantization; smaller sizes are more fine-grained. Choices are [256, 128, 64, 32].
+        `int4_choose_qparams_algorithm`: which choose-qparams algorithm to use for int4; currently TINYGEMM ("tinygemm") and HQQ ("hqq") are supported.
+        `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
+    """
+
+    group_size: int = 128
+    int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
+        Int4ChooseQParamsAlgorithm.TINYGEMM
+    )
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.prototype.int4_opaque_tensor.Int4WeightOnlyOpaqueTensorConfig"
+        )
+
+
+def _int4_weight_only_opaque_tensor_quantize(weight, config):
+    group_size = config.group_size
+    int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
+
+    if weight.shape[-1] % group_size != 0:
+        logger.info(
+            f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
+        )
+        return weight
+
+    # block_size is 1 along every leading dim and group_size along the last dim,
+    # i.e. groupwise quantization over the last dimension
+    block_size = [1] * (weight.ndim - 1) + [group_size]
+
+    new_weight = Int4OpaqueTensor.from_hp(
+        weight,
+        block_size,
+        int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
+    )
+    return new_weight
+
+
+@register_quantize_module_handler(Int4WeightOnlyOpaqueTensorConfig)
+def _int4_weight_only_transform(
+    module: torch.nn.Module, config: Int4WeightOnlyOpaqueTensorConfig
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    assert hasattr(module, "weight"), (
+        "applying int4 weight only quant requires module to have weight attribute"
+        + f" but {module} does not have one"
+    )
+    new_weight = _int4_weight_only_opaque_tensor_quantize(module.weight, config)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py b/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py
similarity index 98%
rename from torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py
rename to torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py
index 708635014a..976f219167 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py
+++ b/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py
@@ -16,13 +16,14 @@
     _choose_qparams_and_quantize_affine_hqq,
     _quantize_affine_tinygemm,
 )
+from torchao.quantization.quantize_.workflows import (
+    Int4ChooseQParamsAlgorithm,
+)
 from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 from torchao.utils import (
     TorchAOBaseTensor,
 )
 
-from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
-
 __all__ = [
     "Int4OpaqueTensor",
 ]
@@ -241,7 +242,7 @@ def _(func, types, args, kwargs):
     return y.to(orig_dtype)
 
 
-Int4OpaqueTensor.__module__ = "torchao.quantization"
+Int4OpaqueTensor.__module__ = "torchao.prototype.int4_opaque_tensor"
 
 # Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True`
 torch.serialization.add_safe_globals([Int4OpaqueTensor])
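Since the rename above re-points Int4OpaqueTensor.__module__ at the prototype package while keeping the add_safe_globals registration, checkpoints round-trip exactly as the moved test checks. A sketch, assuming `linear` was quantized as in the earlier example:

    import tempfile

    import torch

    with tempfile.NamedTemporaryFile() as f:
        torch.save(linear.state_dict(), f)
        f.seek(0)
        # loads even with weights_only=True, thanks to add_safe_globals
        state_dict = torch.load(f)
    assert "torchao.prototype.int4_opaque_tensor" in str(type(state_dict["weight"]))
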
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 8dd6410597..a1ca6b0b94 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -94,7 +94,6 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4MarlinSparseTensor,
-    Int4OpaqueTensor,
     Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
@@ -173,7 +172,6 @@
     "IntxUnpackedToInt8Tensor",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
-    "Int4OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index c3ab192d8e..bf08691cd8 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -77,7 +77,6 @@
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
-    Int4OpaqueTensor,
     Int4PackingFormat,
     Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
@@ -1163,12 +1162,9 @@ def _int4_weight_only_quantize_tensor(weight, config):
     block_size = list(block_size)
 
     if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
-        assert int4_packing_format in [
-            Int4PackingFormat.TILE_PACKED_TO_4D,
-            Int4PackingFormat.OPAQUE,
-        ], (
+        assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
             f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
-            f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D and Int4PackingFormat.OPAQUE currently"
+            f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
         )
 
     if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
@@ -1196,13 +1192,6 @@
             block_size,
         )
         return new_weight
-    elif int4_packing_format == Int4PackingFormat.OPAQUE:
-        new_weight = Int4OpaqueTensor.from_hp(
-            weight,
-            block_size,
-            int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
-        )
-        return new_weight
     elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
         new_weight = Int4TilePackedTo4dTensor.from_hp(
             weight,
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 4307637f8e..c1d1ae3f74 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -6,9 +6,6 @@
 from .int4.int4_marlin_sparse_tensor import (
     Int4MarlinSparseTensor,
 )
-from .int4.int4_opaque_tensor import (
-    Int4OpaqueTensor,
-)
 from .int4.int4_packing_format import Int4PackingFormat
 from .int4.int4_plain_int32_tensor import (
     Int4PlainInt32Tensor,
@@ -39,7 +36,6 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
-    "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",
     "Int4PackingFormat",
     "IntxChooseQParamsAlgorithm",
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py
index b5d988ef4a..de06583d38 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py
@@ -48,10 +48,3 @@ class Int4PackingFormat(str, Enum):
     for simplication of Int4TilePackedTo4dTensor API
     """
     TILE_PACKED_TO_4D = "tile_packed_to_4d"
-
-    """
-    Opaque packing format that's used for tensors that does not have a predefined packing format
-    (that may be decided on hardware, tensor shape, library availability etc.) and it's not
-    needed for the rest of the system to understand the specific format that's adopted.
-    """
-    OPAQUE = "opaque"
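
Migration implied by the removals above: Int4PackingFormat.OPAQUE no longer exists, so CPU callers switch configs rather than packing formats (HQQ for the opaque layout likewise moves behind the new config's int4_choose_qparams_algorithm). A before/after sketch:

    # before this patch (no longer supported):
    # from torchao.quantization import Int4WeightOnlyConfig
    # config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")

    # after this patch:
    from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig

    config = Int4WeightOnlyOpaqueTensorConfig(group_size=128)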