diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 004860e329..3b57acc9a3 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -73,6 +73,8 @@
     Float8DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -1872,6 +1874,8 @@ def _test_quantize_api_against_ptq(
         base_config: AOBaseConfig,
         target_prepare_sqnr: float,
         target_convert_sqnr: float,
+        dtype: torch.dtype = torch.bfloat16,
+        module_type: str = "linear",
     ):
         """
         Test the following:
@@ -1884,22 +1888,32 @@ def _test_quantize_api_against_ptq(
             quantize_(model, base_config)
         """
         torch.manual_seed(self.SEED)
-        m = M().to(torch.bfloat16).cuda()
-        example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+
+        if module_type == "linear":
+            m = M().to(dtype).cuda()
+            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
+        elif module_type == "embedding":
+            m = M3().to(dtype).cuda()
+            example_inputs = (m.example_inputs()[0].cuda(),)
+            filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
+        else:
+            raise ValueError(f"Unknown module type {module_type}")
 
         # baseline
         m_baseline = copy.deepcopy(m)
-        quantize_(m_baseline, base_config)
+        quantize_(m_baseline, base_config, filter_fn)
         out_baseline = m_baseline(*example_inputs)
 
         # compare prepare
-        quantize_(m, QATConfig(base_config, step="prepare"))
+        quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
+        self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
 
         # compare convert
-        quantize_(m, QATConfig(base_config, step="convert"))
+        quantize_(m, QATConfig(base_config, step="convert"), filter_fn)
         out_converted = m(*example_inputs)
         convert_sqnr = compute_error(out_converted, out_baseline)
         self.assertGreaterEqual(convert_sqnr, target_convert_sqnr)
@@ -1967,6 +1981,56 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @parametrize(
+        "weight_dtype, weight_granularity, dtype",
+        [
+            (weight_dtype, weight_granularity, dtype)
+            for weight_dtype in [getattr(torch, f"int{i}") for i in range(2, 9)]
+            for weight_granularity in [PerGroup(32), PerAxis(0)]
+            for dtype in [torch.bfloat16, torch.float32]
+        ],
+    )
+    def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=weight_dtype, weight_granularity=weight_granularity
+            ),
+            target_prepare_sqnr=float("inf"),
+            target_convert_sqnr=float("inf"),
+            dtype=dtype,
+        )
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @parametrize(
+        "weight_dtype, granularity, dtype, module_type",
+        [
+            (weight_dtype, granularity, dtype, module_type)
+            for weight_dtype in [getattr(torch, f"int{i}") for i in range(2, 9)]
+            for granularity in [PerGroup(32), PerAxis(0)]
+            for dtype in [torch.bfloat16, torch.float32]
+            for module_type in ["linear", "embedding"]
+        ],
+    )
+    def test_quantize_api_intx(self, weight_dtype, granularity, dtype, module_type):
+        """
+        Test the following:
+            quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="prepare"))
+            quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            target_prepare_sqnr=float("inf"),
+            target_convert_sqnr=float("inf"),
+            dtype=dtype,
+            module_type=module_type,
+        )
+
     def test_infer_fp8_int4_config(self):
         """
         Test that fake quantize configs are correctly inferred from
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index dc86aa919f..c78b7ab3ae 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -363,6 +363,8 @@ def _infer_fake_quantize_configs(
         Float8DynamicActivationInt4WeightConfig,
         Int4WeightOnlyConfig,
         Int8DynamicActivationInt4WeightConfig,
+        Int8DynamicActivationIntxWeightConfig,
+        IntxWeightOnlyConfig,
     )
 
     if isinstance(base_config, Int8DynamicActivationInt4WeightConfig):
@@ -438,6 +440,54 @@ def _infer_fake_quantize_configs(
         else:
             act_config = None
             weight_config = NVFP4FakeQuantizeConfig(False)
+    elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
+        assert base_config.version >= 2, "Only version 2+ is supported"
+        assert base_config.intx_packing_format == "unpacked_to_int8", (
+            "Only unpacked_to_int8 is supported"
+        )
+        assert base_config.weight_dtype != torch.int1, "Only int2+ is supported"
+        assert base_config.act_mapping_type == MappingType.ASYMMETRIC, (
+            "Only asymmetric activation mapping is supported"
+        )
+        assert base_config.weight_mapping_type == MappingType.SYMMETRIC, (
+            "Only symmetric weight mapping is supported"
+        )
+        assert base_config.weight_scale_dtype is None, (
+            "Specifying weight_scale_dtype is not supported"
+        )
+
+        act_config = IntxFakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=False,
+            scale_precision=base_config.weight_scale_dtype,
+        )
+        weight_config = IntxFakeQuantizeConfig(
+            dtype=base_config.weight_dtype,
+            granularity=base_config.weight_granularity,
+            mapping_type=base_config.weight_mapping_type,
+            scale_precision=base_config.weight_scale_dtype,
+        )
+    elif isinstance(base_config, IntxWeightOnlyConfig):
+        assert base_config.version >= 2, "Only version 2+ is supported"
+        assert base_config.intx_packing_format == "unpacked_to_int8", (
+            "Only unpacked_to_int8 is supported"
+        )
+        assert base_config.mapping_type == MappingType.SYMMETRIC, (
+            "Only symmetric mapping is supported"
+        )
+        assert base_config.weight_dtype != torch.int1, "Only int2+ is supported"
+        assert base_config.scale_dtype is None, (
+            "Specifying scale_dtype is not supported"
+        )
+
+        act_config = None
+        weight_config = IntxFakeQuantizeConfig(
+            dtype=base_config.weight_dtype,
+            granularity=base_config.granularity,
+            mapping_type=base_config.mapping_type,
+            scale_precision=base_config.scale_dtype,
+        )
     else:
         raise ValueError("Unexpected base config: %s" % base_config)
     return (act_config, weight_config)
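
For context, a minimal usage sketch of the workflow these changes enable is below. The model, shapes, and import paths are illustrative assumptions and are not part of the diff; they may differ slightly across torchao versions.

```python
# Sketch: QAT fake quantization driven by an intx PTQ base config.
import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import QATConfig

# Toy model; any module tree with nn.Linear (or nn.Embedding) works.
model = torch.nn.Sequential(torch.nn.Linear(512, 512)).to(torch.bfloat16)

# PTQ base config that the converted QAT model should match numerically.
# Per the assertions in _infer_fake_quantize_configs, the config must use
# version >= 2 and intx_packing_format="unpacked_to_int8"; set these
# explicitly if they are not the defaults in your torchao version.
base_config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32))

# Step 1: insert fake-quantized modules, then fine-tune as usual.
quantize_(model, QATConfig(base_config, step="prepare"))
# ... training loop goes here ...

# Step 2: swap fake quantization for the real quantized representation.
quantize_(model, QATConfig(base_config, step="convert"))
```

The tests above check that this convert step reproduces the PTQ baseline exactly (SQNR of inf) for both linear and embedding modules, across int2 through int8 weight dtypes and per-group or per-axis granularity.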