From 715ec749c965a85db9ed606239a7f0cf95b2a720 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 5 Sep 2025 12:40:08 -0700 Subject: [PATCH] Move packing format used by int4 to int4_packing_format.py Summary: We found that there is not much reuse of packing formats, so we now plan to define a packing format for each dtype (int4, float8, intx) instead of having a global packing_format that's used by all the tensors; this reduces the interference between different dtype configs. This doesn't change the tensor subclasses, so there are no BC changes for the tensor subclasses. For v2 of Int4WeightOnlyConfig it breaks BC, but we don't have any official models saved with this config yet, so it's fine. We also didn't add BC testing for this since the config is not finalized yet; we'll add that later. Test Plan: Regression tests: python test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py python test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py python test/quantization/quantize_/workflows/int4/test_int4_tensor.py python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py python test/core/test_config.py python test/integration/test_load_and_run_checkpoint.py Reviewers: Subscribers: Tasks: Tags: --- test/core/test_config.py | 4 +- .../int4/test_int4_marlin_sparse_tensor.py | 2 +- .../workflows/int4/test_int4_opaque_tensor.py | 2 +- .../int4/test_int4_plain_int32_tensor.py | 2 +- .../int4/test_int4_preshuffled_tensor.py | 4 +- .../workflows/int4/test_int4_tensor.py | 11 +++- .../test_int4_tile_packed_to_4d_tensor.py | 4 +- torchao/quantization/quant_api.py | 35 ++++++------ .../quantize_/common/packing_format.py | 23 +------- .../quantize_/workflows/__init__.py | 2 + .../workflows/int4/int4_packing_format.py | 57 +++++++++++++++++++ 11 files changed, 97 insertions(+), 49 deletions(-) create mode 100644 torchao/quantization/quantize_/workflows/int4/int4_packing_format.py diff --git a/test/core/test_config.py b/test/core/test_config.py index c7c412f9b6..0bf975fa3b 100644 --- a/test/core/test_config.py +++ b/test/core/test_config.py @@ -26,6 +26,7 @@ from torchao.quantization.quant_api import ( FbgemmConfig, Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, GemliteUIntXWeightOnlyConfig, @@ -49,13 +50,14 @@ weight_dtype=torch.float8_e4m3fn, ), UIntXWeightOnlyConfig(dtype=torch.uint1), + Float8DynamicActivationInt4WeightConfig(), Int4DynamicActivationInt4WeightConfig(), Int4WeightOnlyConfig( group_size=32, ), Int4WeightOnlyConfig( group_size=128, - packing_format="tile_packed_to_4d", + int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2, ), diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py index cc8f10faba..d6961dfa23 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -26,7 +26,7 @@ BF16_ACT_CONFIG = Int4WeightOnlyConfig( group_size=128, - packing_format="marlin_sparse", + int4_packing_format="marlin_sparse", version=2, ) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py index 58ec391038..0b3e84fb77 100644 --- 
a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py @@ -28,7 +28,7 @@ def get_config(group_size): return Int4WeightOnlyConfig( group_size=group_size, - packing_format="opaque", + int4_packing_format="opaque", version=2, ) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py index d7d793685e..728ebd880a 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py @@ -28,7 +28,7 @@ def get_config(group_size): return Int4WeightOnlyConfig( group_size=group_size, - packing_format="plain_int32", + int4_packing_format="plain_int32", version=2, ) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py index 01ef99ae96..4760f75257 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py @@ -29,13 +29,13 @@ BF16_ACT_CONFIG = Int4WeightOnlyConfig( group_size=128, - packing_format="preshuffled", + int4_packing_format="preshuffled", version=2, ) # only 128 group_size is supported FP8_ACT_CONFIG = Float8DynamicActivationInt4WeightConfig( - packing_format="preshuffled", + int4_packing_format="preshuffled", ) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py index a72d3b1d2c..a971db609e 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py @@ -17,17 +17,24 @@ from torchao.quantization.quantize_.common import SupportsActivationPreScaling from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase -from torchao.utils import is_sm_at_least_90, torch_version_at_least +from torchao.utils import ( + _is_fbgemm_genai_gpu_available, + is_sm_at_least_90, + torch_version_at_least, +) @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") +@unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" +) class TestInt4Tensor(TorchAOIntegrationTestCase): def setUp(self): self.config = Int4WeightOnlyConfig( group_size=128, - packing_format="plain", + int4_packing_format="plain", version=2, ) self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py index 337a9d98ad..64519e327a 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py @@ -24,13 +24,13 @@ INT4_CONFIG = Int4WeightOnlyConfig( group_size=128, - packing_format="tile_packed_to_4d", + int4_packing_format="tile_packed_to_4d", version=2, ) INT4_HQQ_CONFIG = Int4WeightOnlyConfig( group_size=128, - packing_format="tile_packed_to_4d", + int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2, ) 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9a138dc9d1..37776cb06b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -75,6 +75,7 @@ Int4ChooseQParamsAlgorithm, Int4MarlinSparseTensor, Int4OpaqueTensor, + Int4PackingFormat, Int4PlainInt32Tensor, Int4PreshuffledTensor, Int4Tensor, @@ -1075,7 +1076,7 @@ class Int4WeightOnlyConfig(AOBaseConfig): Note: Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2 - For v2 (version = 2), only `group_size`, `packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored + For v2 (version = 2), only `group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid; all other args will be ignored For v1 (version = 1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid, we plan to deprecate v1 in torchao 0.15 to make this config less confusing """ @@ -1087,7 +1088,7 @@ class Int4WeightOnlyConfig(AOBaseConfig): set_inductor_config: bool = True preserve_zero: Optional[bool] = None # only used in version >= 2 - packing_format: PackingFormat = PackingFormat.PLAIN + int4_packing_format: Int4PackingFormat = Int4PackingFormat.PLAIN int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = ( Int4ChooseQParamsAlgorithm.TINYGEMM ) @@ -1113,7 +1114,7 @@ def _int4_weight_only_quantize_tensor(weight, config): use_hqq = config.use_hqq int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm zero_point_domain = config.zero_point_domain - packing_format = config.packing_format + int4_packing_format = config.int4_packing_format if weight.shape[-1] % group_size != 0: logger.info( @@ -1127,42 +1128,42 @@ def _int4_weight_only_quantize_tensor(weight, config): block_size = list(block_size) if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ: - assert packing_format == PackingFormat.TILE_PACKED_TO_4D, ( - f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {packing_format}, it's only supported by PackingFormat.TILE_PACKED_TO_4D curretnly" + assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, ( + f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently" ) - if packing_format == PackingFormat.PRESHUFFLED: + if int4_packing_format == Int4PackingFormat.PRESHUFFLED: new_weight = Int4PreshuffledTensor.from_hp( weight, block_size, activation_dtype=torch.bfloat16, ) return new_weight - elif packing_format == PackingFormat.PLAIN: + elif int4_packing_format == Int4PackingFormat.PLAIN: new_weight = Int4Tensor.from_hp( weight, block_size, ) return new_weight - elif packing_format == PackingFormat.PLAIN_INT32: + elif int4_packing_format == Int4PackingFormat.PLAIN_INT32: new_weight = Int4PlainInt32Tensor.from_hp( weight, block_size, ) return new_weight - elif packing_format == PackingFormat.MARLIN_SPARSE: + elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE: new_weight = Int4MarlinSparseTensor.from_hp( weight, block_size, ) return new_weight - elif packing_format == PackingFormat.OPAQUE: + elif int4_packing_format == Int4PackingFormat.OPAQUE: new_weight = Int4OpaqueTensor.from_hp( weight, block_size, ) return new_weight - elif packing_format == PackingFormat.TILE_PACKED_TO_4D: + elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D: new_weight = 
Int4TilePackedTo4dTensor.from_hp( weight, block_size, @@ -1170,7 +1171,7 @@ def _int4_weight_only_quantize_tensor(weight, config): ) return new_weight else: - raise ValueError(f"Unsupported packing format: {packing_format}") + raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}") assert config.version == 1 @@ -1254,10 +1255,10 @@ class Float8DynamicActivationInt4WeightConfig(AOBaseConfig): and above and no benefits of making it bigger) Args: - `packing_format`: how the weight is packed, only preshuffled is supported + `int4_packing_format`: how the weight is packed, only preshuffled is supported """ - packing_format: PackingFormat = "preshuffled" + int4_packing_format: Int4PackingFormat = "preshuffled" @register_quantize_module_handler(Float8DynamicActivationInt4WeightConfig) @@ -1268,10 +1269,10 @@ def _float8_dynamic_activation_int4_weight_transform( "applying int8 weight only quant requires module to have weight attribute" + " but {module} does not have one" ) - packing_format = config.packing_format + int4_packing_format = config.int4_packing_format - assert packing_format == "preshuffled", ( - f"only preshuffled packing_format supported right now, got: {packing_format}" + assert int4_packing_format == "preshuffled", ( + f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}" ) weight = module.weight group_size = 128 diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py index 94d45917b9..bedb8b5986 100644 --- a/torchao/quantization/quantize_/common/packing_format.py +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -16,7 +16,7 @@ class PackingFormat(str, Enum): """ plain means the format that quantized Tensor data lays out elements in Tensor sequentially, - for example: for a Tensor of shape (4, 6): + for example: for a Tensor of shape (4, 6): a_0_0, a_0_1, ..., a_0_5, ... a_3_0, a_3_1, ..., a_3_5 @@ -26,32 +26,11 @@ class PackingFormat(str, Enum): """ PLAIN = "plain" - """ - preshuffled is referring to the preshuffled format used by fbgemm kernels - """ - PRESHUFFLED = "preshuffled" - - """ - marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization - """ - MARLIN_SPARSE = "marlin_sparse" - """ Unpacked to int8 means the subbyte quantized data is stored as int8 """ UNPACKED_TO_INT8 = "unpacked_to_int8" - """ - plain_int32 is referring to the format used by int4 weight-only quantization. - which is a groupwise quantization format 2*int4 is store in a byte and 4*(int4*2) is stored in a int32. - """ - PLAIN_INT32 = "plain_int32" - - """ - tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization - """ - TILE_PACKED_TO_4D = "tile_packed_to_4d" - """ Opaque packing format that's used for tensors that does not have a predefined packing format (that may be decided on hardware, tensor shape, library availability etc.) 
and it's not diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 4a762adc5d..0459d230f0 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -9,6 +9,7 @@ from .int4.int4_opaque_tensor import ( Int4OpaqueTensor, ) +from .int4.int4_packing_format import Int4PackingFormat from .int4.int4_plain_int32_tensor import ( Int4PlainInt32Tensor, ) @@ -39,4 +40,5 @@ "IntxUnpackedTensor", "IntxUnpackedToInt8Tensor", "Int4ChooseQParamsAlgorithm", + "Int4PackingFormat", ] diff --git a/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py new file mode 100644 index 0000000000..b5d988ef4a --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class Int4PackingFormat(str, Enum): + """Packing format for quantized data in Int4 Tensor subclasses in torchao, represents how + the values in quantized data are packed and laid out in memory. + """ + + """ + plain means the format that quantized Tensor data lays out elements in Tensor sequentially, + for example: for a Tensor of shape (4, 6): + a_0_0, a_0_1, ..., a_0_5, + ... + a_3_0, a_3_1, ..., a_3_5 + + For example for int4, we will + pack two adjacent int4 elements into one uint8/int8 value for plain packing format + """ + PLAIN = "plain" + + """ + preshuffled is referring to the preshuffled format used by fbgemm kernels + """ + PRESHUFFLED = "preshuffled" + + """ + marlin_sparse is referring to the format used by marlin kernels, requires symmetric quantization + """ + MARLIN_SPARSE = "marlin_sparse" + + """ + plain_int32 is a format that 2 adjacent int4 values are packed in a byte and 4 such packed bytes are stored in a int32 value. + """ + PLAIN_INT32 = "plain_int32" + + """ + tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization + for a Tensor of shape (n, k), the packed weight will have dimension: + [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2], where inner_k_tiles is 8 currently + for simplication of Int4TilePackedTo4dTensor API + """ + TILE_PACKED_TO_4D = "tile_packed_to_4d" + + """ + Opaque packing format that's used for tensors that does not have a predefined packing format + (that may be decided on hardware, tensor shape, library availability etc.) and it's not + needed for the rest of the system to understand the specific format that's adopted. + """ + OPAQUE = "opaque"