diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 0a7a94af1c..94245e17d1 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -50,7 +50,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1985,7 +1985,7 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, e4m3_dtype)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, e4m3_dtype)
 
@@ -2008,7 +2008,7 @@ def test_infer_int4_weight_only_config(self):
         base_config = Int4WeightOnlyConfig(version=2)
         (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         self.assertIsNone(act_config)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, torch.bfloat16)
 
@@ -2102,7 +2102,7 @@ def test_fbgemm_fp8_int4_preshuffled_primitives(self):
         """
         Compare numerics between:
           (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
-          (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+          (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize,
@@ -2184,7 +2184,7 @@ def test_fbgemm_int4_weight_only_primitives(self):
         """
         Compare numerics between:
           (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
-          (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+          (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize_zp,
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index 892fcd8d8b..c39aa4c817 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -78,9 +78,8 @@ def __post_init__(self):
         )
 
 
-# TODO: rename this config, it actually works for both plain and preshuffled
 @dataclass
-class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
+class Int4WeightFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
     Config for int4 weight fake quantization that targets the numerics in the following
     preshuffled kernel: torch.ops.fbgemm.f8i4bf16_shuffled
@@ -393,7 +392,7 @@ def _infer_fake_quantize_configs(
             raise ValueError(
                 f"Packing format must be one of {supported_packing_formats}"
             )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
             group_size=128,
             activation_dtype=torch.bfloat16,
         )
@@ -436,7 +435,7 @@ def _infer_fake_quantize_configs(
             dtype=e4m3_dtype,
             granularity=PerRow(),
         )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
             group_size=128,
             activation_dtype=e4m3_dtype,
         )
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 8c21ecf5cc..09e3fa1e59 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -35,7 +35,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .utils import (
@@ -68,8 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
 
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
-        elif isinstance(config, Int4WeightPreshuffledFakeQuantizeConfig):
-            return Int4WeightPreshuffledFakeQuantizer(config)
+        elif isinstance(config, Int4WeightFakeQuantizeConfig):
+            return Int4WeightFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
         elif isinstance(config, NVFP4FakeQuantizeConfig):
@@ -103,8 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
-# TODO: rename this, it also works for plain Int4Tensor
-class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
+class Int4WeightFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying int4 fake quantization to a weight tensor,
     targeting the following FBGEMM kernels:
@@ -113,12 +112,10 @@ class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
         torch.ops.fbgemm.bf16i4bf16_rowwise
     """
 
-    def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
+    def __init__(self, config: Int4WeightFakeQuantizeConfig):
         super().__init__()
         self.config = config
-        torch._C._log_api_usage_once(
-            "torchao.quantization.qat.Int4WeightPreshuffledFakeQuantizer"
-        )
+        torch._C._log_api_usage_once("torchao.quantization.qat.Int4WeightFakeQuantizer")
 
     def forward(self, w: torch.Tensor) -> torch.Tensor:
         if self.config.activation_dtype == torch.float8_e4m3fn:
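
The rename follows the two TODOs removed above: the config and quantizer cover both plain and preshuffled int4 kernels, so "Preshuffled" is dropped from the names. For reviewers, here is a minimal usage sketch, not part of the diff, showing the renamed classes in use. It assumes only what the diff shows: the import paths, the `group_size`/`activation_dtype` constructor arguments passed by `_infer_fake_quantize_configs`, and the `FakeQuantizerBase.from_config` dispatch; the weight shape is illustrative.

```python
# Sketch (not part of the diff): exercising the renamed QAT classes.
import torch

from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizerBase

# Same arguments that _infer_fake_quantize_configs passes in the diff above.
config = Int4WeightFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,
)

# from_config now dispatches to Int4WeightFakeQuantizer
# (formerly Int4WeightPreshuffledFakeQuantizer).
fake_quantizer = FakeQuantizerBase.from_config(config)

# Fake-quantize a weight tensor; the illustrative in_features (512) is
# chosen to divide evenly by group_size (128), as group-wise quantization requires.
w = torch.randn(256, 512, dtype=torch.bfloat16)
w_fq = fake_quantizer(w)
```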