From 0a953dae58eeb61ba62b8aec6cbd5514068fcf45 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Fri, 10 Jan 2025 13:36:59 -0800
Subject: [PATCH] Fix torch.intx support in FakeQuantizeConfig

**Summary:** Fixes the following error when passing `torch.intx` to
`FakeQuantizeConfig`. These dtypes were introduced in PyTorch 2.6+:

```
ValueError: Unsupported dtype 'torch.int4', choose from [torch.int8, torch.uint8, <TorchAODType.INT1: 1>, <TorchAODType.INT2: 2>, <TorchAODType.INT3: 3>, <TorchAODType.INT4: 4>, <TorchAODType.INT5: 5>, <TorchAODType.INT6: 6>, <TorchAODType.INT7: 7>, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7]
```

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_torch_intx
---
 test/quantization/test_qat.py            | 21 +++++++++++++++++++
 torchao/quantization/quant_primitives.py | 26 ++++++++++++++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 642f0bd4ad..8a78b8b387 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -63,6 +63,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_6,
 )
 
 # TODO: put this in a common test utils file
@@ -1327,6 +1328,26 @@ def test_quantize_api_convert_path(self):
         baseline_out = baseline_model(*x2)
         torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower"
+    )
+    def test_fake_quantize_config_torch_intx(self):
+        """
+        Test that `FakeQuantizeConfig` works with torch.intx.
+        """
+        group_size = 16
+        config1 = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        config2 = FakeQuantizeConfig(torch.int4, group_size=group_size)
+        linear1 = FakeQuantizedLinear(32, 64, weight_config=config1)
+        linear2 = FakeQuantizedLinear(32, 64, weight_config=config2)
+        linear2.weight = linear1.weight
+        torch.manual_seed(self.SEED)
+        x = torch.randn((1, 32)).to(torch.float)
+        x2 = copy.deepcopy(x)
+        out1 = linear1(*x)
+        out2 = linear2(*x2)
+        torch.testing.assert_close(out1, out2, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()

diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index fddd21c43e..e587d4bc2b 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -18,6 +18,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
     _is_float8_type,
     _register_custom_op,
 )
@@ -162,6 +163,31 @@ class TorchAODType(Enum):
     }
 )
 
+# torch.intX available only in PyTorch 2.6+
+if TORCH_VERSION_AT_LEAST_2_6:
+    _SUB_BYTE_INT_BOUNDS.update(
+        {
+            torch.int1: (-(2**0), 2**0 - 1),
+            torch.int2: (-(2**1), 2**1 - 1),
+            torch.int3: (-(2**2), 2**2 - 1),
+            torch.int4: (-(2**3), 2**3 - 1),
+            torch.int5: (-(2**4), 2**4 - 1),
+            torch.int6: (-(2**5), 2**5 - 1),
+            torch.int7: (-(2**6), 2**6 - 1),
+        }
+    )
+    _DTYPE_TO_BIT_WIDTH.update(
+        {
+            torch.int1: 1,
+            torch.int2: 2,
+            torch.int3: 3,
+            torch.int4: 4,
+            torch.int5: 5,
+            torch.int6: 6,
+            torch.int7: 7,
+        }
+    )
+
 _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
 _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS)
 assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys()