diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 506cec9dea..e530babdb9 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -1178,6 +1178,13 @@ def __init__(self): assert isinstance(m.nested.linear.weight, AffineQuantizedTensor) assert isinstance(m.linear1.weight, AffineQuantizedTensor) + def test_fqn_config_module_config_and_fqn_config_both_specified(self): + with self.assertRaises(ValueError): + FqnToConfig( + fqn_to_config={"test": Float8WeightOnlyConfig()}, + module_fqn_to_config={"test2": Float8WeightOnlyConfig()}, + ) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 09c2edcd9f..425af1feb9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2466,6 +2466,15 @@ class FqnToConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") + if ( + len(self.fqn_to_config) > 0 + and len(self.module_fqn_to_config) > 0 + and self.fqn_to_config != self.module_fqn_to_config + ): + raise ValueError( + "`fqn_to_config` and `module_fqn_to_config` are both specified and are not equal!" + ) + # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0: self.fqn_to_config = self.module_fqn_to_config @@ -2479,6 +2488,18 @@ def __post_init__(self): "Config Deprecation: _default is deprecated and will no longer be supported in a future release. Please see https://github.com/pytorch/ao/issues/3229 for more details." ) + def __str__(self): + return "\n".join( + [ + "FqnToConfig({", + *( + f" '{key}':\n {value}," + for key, value in self.fqn_to_config.items() + ), + "})", + ] + ) + # maintain BC ModuleFqnToConfig = FqnToConfig