diff --git a/test/core/test_config.py b/test/core/test_config.py
index c7c412f9b6..0bf975fa3b 100644
--- a/test/core/test_config.py
+++ b/test/core/test_config.py
@@ -26,6 +26,7 @@
 from torchao.quantization.quant_api import (
     FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
     Float8WeightOnlyConfig,
     FPXWeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
@@ -49,13 +50,14 @@
         weight_dtype=torch.float8_e4m3fn,
     ),
     UIntXWeightOnlyConfig(dtype=torch.uint1),
+    Float8DynamicActivationInt4WeightConfig(),
     Int4DynamicActivationInt4WeightConfig(),
     Int4WeightOnlyConfig(
         group_size=32,
     ),
     Int4WeightOnlyConfig(
         group_size=128,
-        packing_format="tile_packed_to_4d",
+        int4_packing_format="tile_packed_to_4d",
         int4_choose_qparams_algorithm="hqq",
         version=2,
     ),
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
index cc8f10faba..d6961dfa23 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
@@ -26,7 +26,7 @@
 
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="marlin_sparse",
+    int4_packing_format="marlin_sparse",
     version=2,
 )
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
index 58ec391038..0b3e84fb77 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
@@ -28,7 +28,7 @@
 def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
-        packing_format="opaque",
+        int4_packing_format="opaque",
         version=2,
     )
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
index d7d793685e..728ebd880a 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -28,7 +28,7 @@
 def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
-        packing_format="plain_int32",
+        int4_packing_format="plain_int32",
         version=2,
     )
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
index 01ef99ae96..4760f75257 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
@@ -29,13 +29,13 @@
 
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="preshuffled",
+    int4_packing_format="preshuffled",
     version=2,
 )
 
 # only 128 group_size is supported
 FP8_ACT_CONFIG = Float8DynamicActivationInt4WeightConfig(
-    packing_format="preshuffled",
+    int4_packing_format="preshuffled",
 )
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
index a72d3b1d2c..a971db609e 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
@@ -17,17 +17,24 @@
 from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
-from torchao.utils import is_sm_at_least_90, torch_version_at_least
+from torchao.utils import (
+    _is_fbgemm_genai_gpu_available,
+    is_sm_at_least_90,
+    torch_version_at_least,
+)
 
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+@unittest.skipIf(
+    not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+)
 class TestInt4Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
-            packing_format="plain",
+            int4_packing_format="plain",
             version=2,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
index 337a9d98ad..64519e327a 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
@@ -24,13 +24,13 @@
 
 INT4_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="tile_packed_to_4d",
+    int4_packing_format="tile_packed_to_4d",
     version=2,
 )
 
 INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="tile_packed_to_4d",
+    int4_packing_format="tile_packed_to_4d",
     int4_choose_qparams_algorithm="hqq",
     version=2,
 )
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 9a138dc9d1..37776cb06b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -75,6 +75,7 @@
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PackingFormat,
     Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
@@ -1075,7 +1076,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
 
     Note:
     Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2
-    For v2 (version = 2), only `group_size`, `packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored
+    For v2 (version = 2), only `group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored
     For v1 (version = 1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid, we plan to deprecate v1 in torchao 0.15 to make this config less confusing
     """
 
@@ -1087,7 +1088,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
     # only used in version >= 2
-    packing_format: PackingFormat = PackingFormat.PLAIN
+    int4_packing_format: Int4PackingFormat = Int4PackingFormat.PLAIN
     int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
         Int4ChooseQParamsAlgorithm.TINYGEMM
     )
@@ -1113,7 +1114,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
     use_hqq = config.use_hqq
     int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
     zero_point_domain = config.zero_point_domain
-    packing_format = config.packing_format
+    int4_packing_format = config.int4_packing_format
 
     if weight.shape[-1] % group_size != 0:
         logger.info(
@@ -1127,42 +1128,42 @@ def _int4_weight_only_quantize_tensor(weight, config):
         block_size = list(block_size)
 
     if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
-        assert packing_format == PackingFormat.TILE_PACKED_TO_4D, (
-            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {packing_format}, it's only supported by PackingFormat.TILE_PACKED_TO_4D curretnly"
+        assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
+            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
         )
 
-    if packing_format == PackingFormat.PRESHUFFLED:
+    if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
         new_weight = Int4PreshuffledTensor.from_hp(
             weight,
             block_size,
             activation_dtype=torch.bfloat16,
         )
         return new_weight
-    elif packing_format == PackingFormat.PLAIN:
+    elif int4_packing_format == Int4PackingFormat.PLAIN:
         new_weight = Int4Tensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.PLAIN_INT32:
+    elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
         new_weight = Int4PlainInt32Tensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.MARLIN_SPARSE:
+    elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.OPAQUE:
+    elif int4_packing_format == Int4PackingFormat.OPAQUE:
         new_weight = Int4OpaqueTensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.TILE_PACKED_TO_4D:
+    elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
         new_weight = Int4TilePackedTo4dTensor.from_hp(
             weight,
             block_size,
@@ -1170,7 +1171,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     else:
-        raise ValueError(f"Unsupported packing format: {packing_format}")
+        raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}")
 
     assert config.version == 1
@@ -1254,10 +1255,10 @@ class Float8DynamicActivationInt4WeightConfig(AOBaseConfig):
     and above and no benefits of making it bigger)
 
     Args:
-        `packing_format`: how the weight is packed, only preshuffled is supported
+        `int4_packing_format`: how the weight is packed, only preshuffled is supported
     """
 
-    packing_format: PackingFormat = "preshuffled"
+    int4_packing_format: Int4PackingFormat = "preshuffled"
 
 
 @register_quantize_module_handler(Float8DynamicActivationInt4WeightConfig)
@@ -1268,10 +1269,10 @@ def _float8_dynamic_activation_int4_weight_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
-    packing_format = config.packing_format
+    int4_packing_format = config.int4_packing_format
 
-    assert packing_format == "preshuffled", (
-        f"only preshuffled packing_format supported right now, got: {packing_format}"
+    assert int4_packing_format == "preshuffled", (
+        f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}"
     )
     weight = module.weight
     group_size = 128
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 94d45917b9..bedb8b5986 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -16,7 +16,7 @@ class PackingFormat(str, Enum):
 
     """
     plain means the format that quantized Tensor data lays out elements in Tensor sequentially,
-    for example: for a Tensor of shape (4, 6):
+    for example: for a Tensor of shape (4, 6):
     a_0_0, a_0_1, ..., a_0_5,
     ...
     a_3_0, a_3_1, ..., a_3_5
@@ -26,32 +26,11 @@ class PackingFormat(str, Enum):
     """
     PLAIN = "plain"
 
-    """
-    preshuffled is referring to the preshuffled format used by fbgemm kernels
-    """
-    PRESHUFFLED = "preshuffled"
-
-    """
-    marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
-    """
-    MARLIN_SPARSE = "marlin_sparse"
-
     """
     Unpacked to int8 means the subbyte quantized data is stored as int8
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"
 
-    """
-    plain_int32 is referring to the format used by int4 weight-only quantization.
-    which is a groupwise quantization format 2*int4 is store in a byte and 4*(int4*2) is stored in a int32.
-    """
-    PLAIN_INT32 = "plain_int32"
-
-    """
-    tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization
-    """
-    TILE_PACKED_TO_4D = "tile_packed_to_4d"
-
     """
     Opaque packing format that's used for tensors that does not have a predefined packing format
     (that may be decided on hardware, tensor shape, library availability etc.) and it's not
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 4a762adc5d..0459d230f0 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -9,6 +9,7 @@
 from .int4.int4_opaque_tensor import (
     Int4OpaqueTensor,
 )
+from .int4.int4_packing_format import Int4PackingFormat
 from .int4.int4_plain_int32_tensor import (
     Int4PlainInt32Tensor,
 )
@@ -39,4 +40,5 @@
     "IntxUnpackedTensor",
     "IntxUnpackedToInt8Tensor",
     "Int4ChooseQParamsAlgorithm",
+    "Int4PackingFormat",
 ]
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py
new file mode 100644
index 0000000000..b5d988ef4a
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+
+
+# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
+# after python 3.10 is end of life (https://devguide.python.org/versions/)
+class Int4PackingFormat(str, Enum):
+    """Packing format for quantized data in Int4 Tensor subclasses in torchao, represents how
+    the values in quantized data are packed and laid out in memory.
+    """
+
+    """
+    plain means the format that quantized Tensor data lays out elements in Tensor sequentially,
+    for example: for a Tensor of shape (4, 6):
+    a_0_0, a_0_1, ..., a_0_5,
+    ...
+    a_3_0, a_3_1, ..., a_3_5
+
+    For example for int4, we will
+    pack two adjacent int4 elements into one uint8/int8 value for plain packing format
+    """
+    PLAIN = "plain"
+
+    """
+    preshuffled is referring to the preshuffled format used by fbgemm kernels
+    """
+    PRESHUFFLED = "preshuffled"
+
+    """
+    marlin_sparse is referring to the format used by marlin kernels, requires symmetric quantization
+    """
+    MARLIN_SPARSE = "marlin_sparse"
+
+    """
+    plain_int32 is a format where 2 adjacent int4 values are packed in a byte and 4 such packed bytes are stored in an int32 value.
+ """ + PLAIN_INT32 = "plain_int32" + + """ + tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization + for a Tensor of shape (n, k), the packed weight will have dimension: + [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2], where inner_k_tiles is 8 currently + for simplication of Int4TilePackedTo4dTensor API + """ + TILE_PACKED_TO_4D = "tile_packed_to_4d" + + """ + Opaque packing format that's used for tensors that does not have a predefined packing format + (that may be decided on hardware, tensor shape, library availability etc.) and it's not + needed for the rest of the system to understand the specific format that's adopted. + """ + OPAQUE = "opaque"