diff --git a/README.md b/README.md
index b3b922f1d2..ad3e0b6f97 100644
--- a/README.md
+++ b/README.md
@@ -261,7 +261,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
 
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 7a5eaa725f..577ca6789a 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -10,6 +10,7 @@
 import gc
 import tempfile
 import unittest
+import warnings
 from pathlib import Path
 
 import torch
@@ -786,6 +787,56 @@ def test_int4wo_cuda_serialization(self):
         # load state_dict in cuda
         model.load_state_dict(sd, assign=True)
 
+    def test_config_deprecation(self):
+        """
+        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
+        """
+        from torchao.quantization import (
+            float8_dynamic_activation_float8_weight,
+            float8_static_activation_float8_weight,
+            float8_weight_only,
+            fpx_weight_only,
+            gemlite_uintx_weight_only,
+            int4_dynamic_activation_int4_weight,
+            int4_weight_only,
+            int8_dynamic_activation_int4_weight,
+            int8_dynamic_activation_int8_weight,
+            int8_weight_only,
+            uintx_weight_only,
+        )
+
+        # Reset deprecation warning state, otherwise we won't log warnings here
+        warnings.resetwarnings()
+
+        # Map from deprecated API to the args needed to instantiate it
+        deprecated_apis_to_args = {
+            float8_dynamic_activation_float8_weight: (),
+            float8_static_activation_float8_weight: (torch.randn(3),),
+            float8_weight_only: (),
+            fpx_weight_only: (3, 2),
+            gemlite_uintx_weight_only: (),
+            int4_dynamic_activation_int4_weight: (),
+            int4_weight_only: (),
+            int8_dynamic_activation_int4_weight: (),
+            int8_dynamic_activation_int8_weight: (),
+            int8_weight_only: (),
+            uintx_weight_only: (torch.uint4,),
+        }
+
+        with warnings.catch_warnings(record=True) as _warnings:
+            # Call each deprecated API twice
+            for cls, args in deprecated_apis_to_args.items():
+                cls(*args)
+                cls(*args)
+
+            # Each call should trigger the warning only once
+            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
+            for w in _warnings:
+                self.assertIn(
+                    "is deprecated and will be removed in a future release",
+                    str(w.message),
+                )
+
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
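The test above leans on the `warnings` machinery's default de-duplication: after `warnings.resetwarnings()`, an identical warning raised from the same call site is only reported once, which is why each deprecated API is called twice but only one record per API is expected. A minimal standalone sketch of that behavior (not part of the diff, and assuming a default Python warnings configuration, i.e. no `-W always`):

```python
import warnings

warnings.resetwarnings()  # clear all filters so the "default" (report-once) action applies

with warnings.catch_warnings(record=True) as caught:
    for _ in range(2):
        # Same message, same file/line: the second occurrence is suppressed
        warnings.warn("`int4_weight_only` is deprecated and will be removed in a future release.")

print(len(caught))  # 1
```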
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 0d72c16687..87e011c57b 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -64,9 +64,21 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
+    fpx_weight_only,
+    gemlite_uintx_weight_only,
+    int4_dynamic_activation_int4_weight,
+    int4_weight_only,
+    int8_dynamic_activation_int4_weight,
+    int8_dynamic_activation_int8_semi_sparse_weight,
+    int8_dynamic_activation_int8_weight,
+    int8_weight_only,
     intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
+    uintx_weight_only,
 )
 from .quant_primitives import (
     MappingType,
@@ -117,7 +129,19 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
+    "int4_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight",
+    "int4_weight_only",
+    "int8_weight_only",
     "intx_quantization_aware_training",
+    "float8_weight_only",
+    "float8_dynamic_activation_float8_weight",
+    "float8_static_activation_float8_weight",
+    "uintx_weight_only",
+    "fpx_weight_only",
+    "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
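With these re-exports in place, both spellings remain importable from `torchao.quantization`; only the deprecated function names emit a warning. A sketch of the intended call sites, assuming a build with this patch applied:

```python
from torchao.quantization import Int4WeightOnlyConfig, int4_weight_only

# Preferred: construct the config class directly (no warning)
config = Int4WeightOnlyConfig(group_size=32)

# Deprecated: warns once, then returns an equivalent config instance
legacy_config = int4_weight_only(group_size=32)
assert isinstance(legacy_config, Int4WeightOnlyConfig)
```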
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ac16c37d5d..ae8210a41a 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -96,6 +96,7 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
+    _ConfigDeprecationWrapper,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -144,7 +145,18 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize_",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight",
+    "int4_weight_only",
+    "int8_weight_only",
     "intx_quantization_aware_training",
+    "float8_weight_only",
+    "uintx_weight_only",
+    "fpx_weight_only",
+    "gemlite_uintx_weight_only",
+    "float8_dynamic_activation_float8_weight",
+    "float8_static_activation_float8_weight",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -491,7 +503,7 @@ def quantize_(
         # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
         # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
         # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile)
-        from torchao.quantization.quant_api import int4_weight_only
+        from torchao.quantization.quant_api import Int4WeightOnlyConfig
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -613,6 +625,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
+    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
+)
+
+
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module,
@@ -978,6 +996,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
+    "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
+)
+
+
 @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
 def _int4_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1035,6 +1059,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
+    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
+)
+
+
 @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
 def _gemlite_uintx_weight_only_transform(
     module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1112,6 +1142,11 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
 
 
+# for BC
+# TODO maybe change other callsites
+int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
+
+
 def _int4_weight_only_quantize_tensor(weight, config):
     # TODO(future PR): perhaps move this logic to a different file, to keep the API
     # file clean of implementation details
@@ -1323,6 +1358,10 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
 
 
+# for BC
+int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
+
+
 def _int8_weight_only_quantize_tensor(weight, config):
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
@@ -1480,6 +1519,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
+    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
+)
+
+
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     layout = config.layout
     act_mapping_type = config.act_mapping_type
@@ -1585,6 +1630,12 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
 
 
+# for BC
+float8_weight_only = _ConfigDeprecationWrapper(
+    "float8_weight_only", Float8WeightOnlyConfig
+)
+
+
 def _float8_weight_only_quant_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
@@ -1743,6 +1794,12 @@ def __post_init__(self):
         self.granularity = [activation_granularity, weight_granularity]
 
 
+# for BC
+float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
+    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
+)
+
+
 def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -1918,6 +1975,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
+    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
+)
+
+
 @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2000,6 +2063,12 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
 
 
+# for BC
+uintx_weight_only = _ConfigDeprecationWrapper(
+    "uintx_weight_only", UIntXWeightOnlyConfig
+)
+
+
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
 def _uintx_weight_only_transform(
     module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2278,6 +2347,10 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
 
 
+# for BC
+fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
+
+
 @register_quantize_module_handler(FPXWeightOnlyConfig)
 def _fpx_weight_only_transform(
     module: torch.nn.Module, config: FPXWeightOnlyConfig
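Each wrapper keeps the old function name usable with `quantize_` while steering users toward the config class. A sketch of both call styles, assuming a build with this patch applied (the model and layer sizes here are arbitrary):

```python
import torch.nn as nn

from torchao.quantization import Int8WeightOnlyConfig, int8_weight_only, quantize_

# New style: pass the config instance directly (no warning)
model = nn.Sequential(nn.Linear(32, 64))
quantize_(model, Int8WeightOnlyConfig())

# Old style still works, because int8_weight_only(...) returns the same
# config type, but it now emits the deprecation warning first
model2 = nn.Sequential(nn.Linear(32, 64))
quantize_(model2, int8_weight_only())
```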
diff --git a/torchao/utils.py b/torchao/utils.py
index 9dfebfb6fb..5af3e00cfa 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -12,7 +12,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Type
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -433,6 +433,25 @@ def __eq__(self, other):
 TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
 
 
+class _ConfigDeprecationWrapper:
+    """
+    A deprecation wrapper that directs users from a deprecated "config function"
+    (e.g. `int4_weight_only`) to the replacement config class.
+    """
+
+    def __init__(self, deprecated_name: str, config_cls: Type):
+        self.deprecated_name = deprecated_name
+        self.config_cls = config_cls
+
+    def __call__(self, *args, **kwargs):
+        warnings.warn(
+            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
+            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
+            f"    quantize_(model, {self.config_cls.__name__}(...))"
+        )
+        return self.config_cls(*args, **kwargs)
+
+
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
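For reference, a self-contained sketch of how `_ConfigDeprecationWrapper` behaves, using a hypothetical `MyConfig` dataclass rather than a real torchao config class:

```python
import warnings
from dataclasses import dataclass

from torchao.utils import _ConfigDeprecationWrapper  # added by this patch


@dataclass
class MyConfig:
    # Hypothetical config class, standing in for e.g. Int4WeightOnlyConfig
    group_size: int = 128


# The wrapper is callable like the old function and forwards args to the class
my_config = _ConfigDeprecationWrapper("my_config", MyConfig)

warnings.resetwarnings()
with warnings.catch_warnings(record=True) as caught:
    cfg = my_config(group_size=32)

assert isinstance(cfg, MyConfig) and cfg.group_size == 32
assert "`my_config` is deprecated" in str(caught[0].message)
```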