diff --git a/README.md b/README.md index 9330900300..a1d474ae02 100644 --- a/README.md +++ b/README.md @@ -258,7 +258,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow -1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))` +1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))` 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index d7d2b4a5b4..2b3538195e 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -10,7 +10,6 @@ import gc import tempfile import unittest -import warnings from pathlib import Path import torch @@ -847,56 +846,6 @@ def test_int4wo_cuda_serialization(self): # load state_dict in cuda model.load_state_dict(sd, assign=True) - def test_config_deprecation(self): - """ - Test that old config functions like `int4_weight_only` trigger deprecation warnings. - """ - from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, - ) - - # Reset deprecation warning state, otherwise we won't log warnings here - warnings.resetwarnings() - - # Map from deprecated API to the args needed to instantiate it - deprecated_apis_to_args = { - float8_dynamic_activation_float8_weight: (), - float8_static_activation_float8_weight: (torch.randn(3)), - float8_weight_only: (), - fpx_weight_only: (3, 2), - gemlite_uintx_weight_only: (), - int4_dynamic_activation_int4_weight: (), - int4_weight_only: (), - int8_dynamic_activation_int4_weight: (), - int8_dynamic_activation_int8_weight: (), - int8_weight_only: (), - uintx_weight_only: (torch.uint4,), - } - - with warnings.catch_warnings(record=True) as _warnings: - # Call each deprecated API twice - for cls, args in deprecated_apis_to_args.items(): - cls(*args) - cls(*args) - - # Each call should trigger the warning only once - self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) - for w in _warnings: - self.assertIn( - "is deprecated and will be removed in a future release", - str(w.message), - ) - common_utils.instantiate_parametrized_tests(TestQuantFlow) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index c8774e9426..aa19aa1890 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -64,21 +64,9 @@ PlainLayout, TensorCoreTiledLayout, UIntXWeightOnlyConfig, - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_semi_sparse_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, intx_quantization_aware_training, quantize_, swap_conv2d_1x1_to_linear, - uintx_weight_only, ) from .quant_primitives import ( MappingType, @@ -131,19 +119,7 @@ "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", - "int4_dynamic_activation_int4_weight", - "int8_dynamic_activation_int4_weight", - "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", - "int4_weight_only", - "int8_weight_only", "intx_quantization_aware_training", - "float8_weight_only", - "float8_dynamic_activation_float8_weight", - "float8_static_activation_float8_weight", - "uintx_weight_only", - "fpx_weight_only", - "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", "Int4DynamicActivationInt4WeightConfig", "Int8DynamicActivationInt4WeightConfig", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 70484f1e5b..139b14cf3f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -96,7 +96,6 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -148,18 +147,7 @@ "autoquant", "_get_subclass_inserter", "quantize_", - "int8_dynamic_activation_int4_weight", - "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", - "int4_weight_only", - "int8_weight_only", "intx_quantization_aware_training", - "float8_weight_only", - "uintx_weight_only", - "fpx_weight_only", - "gemlite_uintx_weight_only", - "float8_dynamic_activation_float8_weight", - "float8_static_activation_float8_weight", "Int8DynActInt4WeightQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", @@ -507,7 +495,7 @@ def quantize_( # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile - from torchao.quantization.quant_api import int4_weight_only + from torchao.quantization.quant_api import Int4WeightOnlyConfig m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1)) @@ -629,12 +617,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) def _int8_dynamic_activation_int4_weight_transform( module: torch.nn.Module, @@ -1000,12 +982,6 @@ def __post_init__(self): ) -# for bc -int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) def _int4_dynamic_activation_int4_weight_transform( module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig @@ -1063,12 +1039,6 @@ def __post_init__(self): ) -# for BC -gemlite_uintx_weight_only = _ConfigDeprecationWrapper( - "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig -) - - @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) def _gemlite_uintx_weight_only_transform( module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig @@ -1146,11 +1116,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") -# for BC -# TODO maybe change other callsites -int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig) - - def _int4_weight_only_quantize_tensor(weight, config): # TODO(future PR): perhaps move this logic to a different file, to keep the API # file clean of implementation details @@ -1362,10 +1327,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") -# for BC -int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig) - - def _int8_weight_only_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -1523,12 +1484,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig -) - - def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): layout = config.layout act_mapping_type = config.act_mapping_type @@ -1634,12 +1589,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") -# for BC -float8_weight_only = _ConfigDeprecationWrapper( - "float8_weight_only", Float8WeightOnlyConfig -) - - def _float8_weight_only_quant_tensor(weight, config): if config.version == 1: warnings.warn( @@ -1798,12 +1747,6 @@ def __post_init__(self): self.granularity = [activation_granularity, weight_granularity] -# for bc -float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper( - "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig -) - - def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype @@ -1979,12 +1922,6 @@ def __post_init__(self): ) -# for bc -float8_static_activation_float8_weight = _ConfigDeprecationWrapper( - "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig -) - - @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) def _float8_static_activation_float8_weight_transform( module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig @@ -2067,12 +2004,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig") -# for BC -uintx_weight_only = _ConfigDeprecationWrapper( - "uintx_weight_only", UIntXWeightOnlyConfig -) - - @register_quantize_module_handler(UIntXWeightOnlyConfig) def _uintx_weight_only_transform( module: torch.nn.Module, config: UIntXWeightOnlyConfig @@ -2351,10 +2282,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig") -# for BC -fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig) - - @register_quantize_module_handler(FPXWeightOnlyConfig) def _fpx_weight_only_transform( module: torch.nn.Module, config: FPXWeightOnlyConfig diff --git a/torchao/utils.py b/torchao/utils.py index 5af3e00cfa..9dfebfb6fb 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -12,7 +12,7 @@ from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable, Optional, Type +from typing import Any, Callable, Optional import torch import torch.nn.utils.parametrize as parametrize @@ -433,25 +433,6 @@ def __eq__(self, other): TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") -class _ConfigDeprecationWrapper: - """ - A deprecation wrapper that directs users from a deprecated "config function" - (e.g. `int4_weight_only`) to the replacement config class. - """ - - def __init__(self, deprecated_name: str, config_cls: Type): - self.deprecated_name = deprecated_name - self.config_cls = config_cls - - def __call__(self, *args, **kwargs): - warnings.warn( - f"`{self.deprecated_name}` is deprecated and will be removed in a future release. " - f"Please use `{self.config_cls.__name__}` instead. Example usage:\n" - f" quantize_(model, {self.config_cls.__name__}(...))" - ) - return self.config_cls(*args, **kwargs) - - """ Helper function for implementing aten op or torch function dispatch and dispatching to these implementations.