diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst
index e0cacab667..7bf1961a23 100644
--- a/docs/source/api_ref_qat.rst
+++ b/docs/source/api_ref_qat.rst
@@ -42,8 +42,6 @@ Legacy QAT APIs
     :toctree: generated/
     :nosignatures:
 
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
diff --git a/test/prototype/test_embedding.py b/test/prototype/test_embedding.py
index 7d020920a7..dc65eec24f 100644
--- a/test/prototype/test_embedding.py
+++ b/test/prototype/test_embedding.py
@@ -18,10 +18,9 @@
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int4WeightOnlyEmbeddingQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
@@ -257,7 +256,7 @@ def test_identical_to_IntxWeightOnlyConfig(
         ],
         name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
     ):
         # ASYMMETRIC in QAT is very different that PTQ configs
@@ -288,12 +287,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         )
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
+            QATConfig(weight_config=weight_config, step="prepare"),
             embedding_filter,
         )
         prepared_out = model(indices)
 
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
@@ -355,7 +354,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
         prepared_out = model(indices)
 
         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
diff --git a/test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py b/test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py
index 224e745ac4..c5dd4dba07 100644
--- a/test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py
+++ b/test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py
@@ -20,10 +20,9 @@
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int8DynActInt4WeightQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt4WeightConfig,
@@ -499,7 +498,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self,
         weight_dtype,
         group_size,
@@ -545,7 +544,11 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
 
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+            QATConfig(
+                activation_config=activation_config,
+                weight_config=weight_config,
+                step="prepare",
+            ),
         )
         try:
             prepared_out = model(activations)
@@ -555,7 +558,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
                 return
             raise e
 
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
@@ -606,7 +609,7 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer(
         prepared_out = model(activations)
 
         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index f523cb091c..03de34a8d7 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -9,7 +9,6 @@
 
 import copy
 import unittest
-import warnings
 from typing import List, Type
 
 import torch
@@ -39,8 +38,6 @@
 )
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
     initialize_fake_quantizers,
@@ -1718,95 +1715,6 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
-    def test_legacy_quantize_api_e2e(self):
-        """
-        Test that the following two APIs are numerically equivalent:
-
-        New API:
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
-
-        Old API:
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-            quantize_(model, Int8DynamicActivationInt4WeightConfig())
-        """
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        baseline_model = copy.deepcopy(m)
-
-        # Baseline prepare
-        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
-        quantize_(baseline_model, old_qat_config)
-
-        # QATConfig prepare
-        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
-        quantize_(m, QATConfig(base_config, step="prepare"))
-
-        # Compare prepared values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-        # Baseline convert
-        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(baseline_model, base_config)
-
-        # quantize_ convert
-        quantize_(m, QATConfig(base_config, step="convert"))
-
-        # Compare converted values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-    def test_qat_api_deprecation(self):
-        """
-        Test that the appropriate deprecation warning is logged exactly once per class.
-        """
-        from torchao.quantization.qat import (
-            FakeQuantizeConfig,
-            FakeQuantizer,
-            from_intx_quantization_aware_training,
-            intx_quantization_aware_training,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            IntXQuantizationAwareTrainingConfig: (),
-            FromIntXQuantizationAwareTrainingConfig: (),
-            intx_quantization_aware_training: (),
-            from_intx_quantization_aware_training: (),
-            FakeQuantizeConfig: (torch.int8, "per_channel"),
-            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-            # Each call should trigger the warning only once
-            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-            for w in _warnings:
-                self.assertIn(
-                    "is deprecated and will be removed in a future release",
-                    str(w.message),
-                )
-
     def test_qat_api_convert_no_quantization(self):
         """
         Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index aa19aa1890..3474fbd668 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -64,7 +64,6 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    intx_quantization_aware_training,
    quantize_,
     swap_conv2d_1x1_to_linear,
 )
@@ -119,7 +118,6 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "intx_quantization_aware_training",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py
index c2f1d6f8d7..220d014ea8 100644
--- a/torchao/quantization/prototype/qat/api.py
+++ b/torchao/quantization/prototype/qat/api.py
@@ -1,6 +1,6 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FakeQuantizeConfig,
+    IntxFakeQuantizeConfig as FakeQuantizeConfig,
 )
 
 __all__ = [
diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
index 4218c763e2..9c38c141f5 100644
--- a/torchao/quantization/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -1,25 +1,19 @@
 from .api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
-    from_intx_quantization_aware_training,
     initialize_fake_quantizers,
-    intx_quantization_aware_training,
 )
 from .embedding import (
     FakeQuantizedEmbedding,
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 from .fake_quantize_config import (
-    FakeQuantizeConfig,
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
@@ -50,11 +44,4 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
-    # for BC
-    "FakeQuantizer",
-    "FakeQuantizeConfig",
-    "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "intx_quantization_aware_training",
-    "IntXQuantizationAwareTrainingConfig",
 ]
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 551a6d5da0..9759883af6 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -21,13 +21,11 @@
 
 from .embedding import FakeQuantizedEmbedding
 from .fake_quantize_config import (
-    FakeQuantizeConfig,  # noqa: F401, for BC
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
     _infer_fake_quantize_configs,
 )
 from .linear import FakeQuantizedLinear
-from .utils import _log_deprecation_warning
 
 
 class QATStep(str, Enum):
@@ -288,119 +286,6 @@ def _qat_config_transform(
         return module
 
 
-@dataclass
-class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
-    """
-    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
-
-    Config for applying fake quantization to a `torch.nn.Module`.
-    to be used with :func:`~torchao.quantization.quant_api.quantize_`.
-
-    Example usage::
-
-        from torchao.quantization import quantize_
-        from torchao.quantization.qat import IntxFakeQuantizeConfig
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(
-            torch.int4, group_size=32, is_symmetric=True,
-        )
-        quantize_(
-            model,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
-
-    Note: If the config is applied on a module that is not
-    `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
-    `torch.nn.Embedding` with an activation config, then we will raise
-    ValueError as these are not supported.
-    """
-
-    activation_config: Optional[FakeQuantizeConfigBase] = None
-    weight_config: Optional[FakeQuantizeConfigBase] = None
-
-    def __post_init__(self):
-        _log_deprecation_warning(self)
-
-
-# for BC
-class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
-    pass
-
-
-@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
-def _intx_quantization_aware_training_transform(
-    module: torch.nn.Module,
-    config: IntXQuantizationAwareTrainingConfig,
-) -> torch.nn.Module:
-    mod = module
-    activation_config = config.activation_config
-    weight_config = config.weight_config
-
-    if isinstance(mod, torch.nn.Linear):
-        return FakeQuantizedLinear.from_linear(
-            mod,
-            activation_config,
-            weight_config,
-        )
-    elif isinstance(mod, torch.nn.Embedding):
-        if activation_config is not None:
-            raise ValueError(
-                "Activation fake quantization is not supported for embedding"
-            )
-        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-    else:
-        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
-
-
-@dataclass
-class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
-    """
-    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
-
-    Config for converting a model with fake quantized modules,
-    such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
-    and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
-    back to model with the original, corresponding modules without
-    fake quantization. This should be used with
-    :func:`~torchao.quantization.quant_api.quantize_`.
-
-    Example usage::
-
-        from torchao.quantization import quantize_
-        quantize_(
-            model_with_fake_quantized_linears,
-            FromIntXQuantizationAwareTrainingConfig(),
-        )
-    """
-
-    def __post_init__(self):
-        _log_deprecation_warning(self)
-
-
-# for BC
-class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
-    pass
-
-
-@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
-def _from_intx_quantization_aware_training_transform(
-    mod: torch.nn.Module,
-    config: FromIntXQuantizationAwareTrainingConfig,
-) -> torch.nn.Module:
-    """
-    If the given module is a fake quantized module, return the original
-    corresponding version of the module without fake quantization.
-    """
-    if isinstance(mod, FakeQuantizedLinear):
-        return mod.to_linear()
-    elif isinstance(mod, FakeQuantizedEmbedding):
-        return mod.to_embedding()
-    else:
-        return mod
-
-
 class ComposableQATQuantizer(TwoStepQuantizer):
     """
     Composable quantizer that users can use to apply multiple QAT quantizers easily.
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index 3a1c7c78f1..e4131bc5ee 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -34,8 +34,6 @@
 from torchao.quantization.quantize_.workflows import Int4PackingFormat
 from torchao.utils import _is_float8_type
 
-from .utils import _log_deprecation_warning
-
 
 class FakeQuantizeConfigBase(abc.ABC):
     """
@@ -201,14 +199,6 @@ def __init__(
         if is_dynamic and range_learning:
             raise ValueError("`is_dynamic` is not compatible with `range_learning`")
 
-        self.__post_init__()
-
-    def __post_init__(self):
-        """
-        For deprecation only, can remove after https://github.com/pytorch/ao/issues/2630.
-        """
-        pass
-
     def _get_granularity(
         self,
         granularity: Union[Granularity, str, None],
@@ -334,16 +324,6 @@ def __setattr__(self, name: str, value: Any):
         super().__setattr__(name, value)
 
 
-# For BC
-class FakeQuantizeConfig(IntxFakeQuantizeConfig):
-    """
-    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizeConfig` instead.
-    """
-
-    def __post_init__(self):
-        _log_deprecation_warning(self)
-
-
 def _infer_fake_quantize_configs(
     base_config: AOBaseConfig,
 ) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]:
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 595dafaba8..e18fb8b5b5 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -41,7 +41,6 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _log_deprecation_warning,
 )
 
 
@@ -330,14 +329,3 @@ def _maybe_update_qparams_for_range_learning(self) -> None:
         zero_point = _Round.apply(zero_point)
         zero_point = torch.clamp(zero_point, qmin, qmax)
         self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
-
-
-# For BC
-class FakeQuantizer(IntxFakeQuantizer):
-    """
-    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizer` instead.
-    """
-
-    def __init__(self, config: FakeQuantizeConfigBase):
-        super().__init__(config)
-        _log_deprecation_warning(self)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 139b14cf3f..f9170702f1 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -120,9 +120,6 @@
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
 from .quant_primitives import (
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
@@ -147,7 +144,6 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize_",
-    "intx_quantization_aware_training",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",