2 changes: 1 addition & 1 deletion README.md
@@ -258,7 +258,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu

We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow

-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference

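For context on the README change above, here is a minimal usage sketch of the renamed API. The model, shapes, and device/dtype setup are illustrative assumptions; only the `FPXWeightOnlyConfig(3, 2)` call is taken from the diff, and fp6 inference assumes a CUDA device.

```python
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.quant_api import FPXWeightOnlyConfig

# Illustrative model; fp6 (3 exponent bits, 2 mantissa bits) targets fp16 weights
model = nn.Sequential(nn.Linear(64, 256), nn.Linear(256, 64)).cuda().half()

# Old spelling removed by this PR: quantize_(model, fpx_weight_only(3, 2))
quantize_(model, FPXWeightOnlyConfig(3, 2))
```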
51 changes: 0 additions & 51 deletions test/quantization/test_quant_api.py
@@ -10,7 +10,6 @@
import gc
import tempfile
import unittest
-import warnings
from pathlib import Path

import torch
@@ -847,56 +846,6 @@ def test_int4wo_cuda_serialization(self):
        # load state_dict in cuda
        model.load_state_dict(sd, assign=True)

-    def test_config_deprecation(self):
-        """
-        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
-        """
-        from torchao.quantization import (
-            float8_dynamic_activation_float8_weight,
-            float8_static_activation_float8_weight,
-            float8_weight_only,
-            fpx_weight_only,
-            gemlite_uintx_weight_only,
-            int4_dynamic_activation_int4_weight,
-            int4_weight_only,
-            int8_dynamic_activation_int4_weight,
-            int8_dynamic_activation_int8_weight,
-            int8_weight_only,
-            uintx_weight_only,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            float8_dynamic_activation_float8_weight: (),
-            float8_static_activation_float8_weight: (torch.randn(3),),
-            float8_weight_only: (),
-            fpx_weight_only: (3, 2),
-            gemlite_uintx_weight_only: (),
-            int4_dynamic_activation_int4_weight: (),
-            int4_weight_only: (),
-            int8_dynamic_activation_int4_weight: (),
-            int8_dynamic_activation_int8_weight: (),
-            int8_weight_only: (),
-            uintx_weight_only: (torch.uint4,),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-            # Each API should trigger the warning only once
-            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-            for w in _warnings:
-                self.assertIn(
-                    "is deprecated and will be removed in a future release",
-                    str(w.message),
-                )
-

common_utils.instantiate_parametrized_tests(TestQuantFlow)

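The deleted test above relied on Python's standard warning de-duplication rather than anything torchao-specific. A standalone sketch of that warn-once behavior (the function name is a placeholder, not a torchao API):

```python
import warnings


def my_deprecated_fn():
    # Same message shape as the removed _ConfigDeprecationWrapper
    warnings.warn(
        "`my_deprecated_fn` is deprecated and will be removed in a future release."
    )


# Clearing the filters also invalidates the per-module "already warned"
# registry, mirroring the resetwarnings() call in the deleted test.
warnings.resetwarnings()

with warnings.catch_warnings(record=True) as caught:
    my_deprecated_fn()
    my_deprecated_fn()

# Under the default filter action, the second identical warning from the
# same line is suppressed, so exactly one record is captured.
assert len(caught) == 1
```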
24 changes: 0 additions & 24 deletions torchao/quantization/__init__.py
@@ -64,21 +64,9 @@
    PlainLayout,
    TensorCoreTiledLayout,
    UIntXWeightOnlyConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_static_activation_float8_weight,
-    float8_weight_only,
-    fpx_weight_only,
-    gemlite_uintx_weight_only,
-    int4_dynamic_activation_int4_weight,
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
    intx_quantization_aware_training,
    quantize_,
    swap_conv2d_1x1_to_linear,
-    uintx_weight_only,
)
from .quant_primitives import (
    MappingType,
@@ -131,19 +119,7 @@
"ALL_AUTOQUANT_CLASS_LIST",
# top level API - manual
"quantize_",
"int4_dynamic_activation_int4_weight",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"intx_quantization_aware_training",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"uintx_weight_only",
"fpx_weight_only",
"gemlite_uintx_weight_only",
"swap_conv2d_1x1_to_linear",
"Int4DynamicActivationInt4WeightConfig",
"Int8DynamicActivationInt4WeightConfig",
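After this change only the class-based configs remain in the package namespace, so the old functional names fail at import time rather than warning. A quick illustrative check:

```python
from torchao.quantization import Int4WeightOnlyConfig, quantize_  # still exported

try:
    from torchao.quantization import int4_weight_only  # removed above
except ImportError:
    print("int4_weight_only is gone; construct Int4WeightOnlyConfig instead")
```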
75 changes: 1 addition & 74 deletions torchao/quantization/quant_api.py
@@ -96,7 +96,6 @@
    to_weight_tensor_with_linear_activation_quantization_metadata,
)
from torchao.utils import (
-    _ConfigDeprecationWrapper,
    is_MI300,
    is_sm_at_least_89,
    is_sm_at_least_90,
@@ -148,18 +147,7 @@
"autoquant",
"_get_subclass_inserter",
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"intx_quantization_aware_training",
"float8_weight_only",
"uintx_weight_only",
"fpx_weight_only",
"gemlite_uintx_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"Int8DynActInt4WeightQuantizer",
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
"ModuleFqnToConfig",
@@ -507,7 +495,7 @@ def quantize_(
        # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
        # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
        # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile)
-       from torchao.quantization.quant_api import int4_weight_only
+       from torchao.quantization.quant_api import Int4WeightOnlyConfig

        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
        quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
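A runnable version of the docstring example above, with illustrative device/dtype setup added (the docstring omits it; the int4 tinygemm path assumes CUDA and bfloat16):

```python
import torch
import torch.nn as nn

from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
m = m.to(device="cuda", dtype=torch.bfloat16)  # illustrative setup
quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
```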
@@ -629,12 +617,6 @@ def __post_init__(self):
        )


-# for BC
-int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
-)
-
-
@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
def _int8_dynamic_activation_int4_weight_transform(
    module: torch.nn.Module,
@@ -1000,12 +982,6 @@ def __post_init__(self):
        )


-# for bc
-int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
-)
-
-
@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
def _int4_dynamic_activation_int4_weight_transform(
    module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1063,12 +1039,6 @@ def __post_init__(self):
        )


-# for BC
-gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
-    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
-)
-
-
@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
def _gemlite_uintx_weight_only_transform(
    module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1146,11 +1116,6 @@ def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")


# for BC
# TODO maybe change other callsites
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)


def _int4_weight_only_quantize_tensor(weight, config):
# TODO(future PR): perhaps move this logic to a different file, to keep the API
# file clean of implementation details
@@ -1362,10 +1327,6 @@ def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")


-# for BC
-int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
-
-
def _int8_weight_only_quantize_tensor(weight, config):
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
@@ -1523,12 +1484,6 @@ def __post_init__(self):
        )


-# for BC
-int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
-)
-
-
def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
    layout = config.layout
    act_mapping_type = config.act_mapping_type
@@ -1634,12 +1589,6 @@ def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")


-# for BC
-float8_weight_only = _ConfigDeprecationWrapper(
-    "float8_weight_only", Float8WeightOnlyConfig
-)
-
-
def _float8_weight_only_quant_tensor(weight, config):
    if config.version == 1:
        warnings.warn(
@@ -1798,12 +1747,6 @@ def __post_init__(self):
        self.granularity = [activation_granularity, weight_granularity]


-# for bc
-float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
-)
-
-
def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
    activation_dtype = config.activation_dtype
    weight_dtype = config.weight_dtype
@@ -1979,12 +1922,6 @@ def __post_init__(self):
        )


-# for bc
-float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
-)
-
-
@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
def _float8_static_activation_float8_weight_transform(
    module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2067,12 +2004,6 @@ def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")


-# for BC
-uintx_weight_only = _ConfigDeprecationWrapper(
-    "uintx_weight_only", UIntXWeightOnlyConfig
-)
-
-
@register_quantize_module_handler(UIntXWeightOnlyConfig)
def _uintx_weight_only_transform(
    module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2351,10 +2282,6 @@ def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")


-# for BC
-fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
-
-
@register_quantize_module_handler(FPXWeightOnlyConfig)
def _fpx_weight_only_transform(
    module: torch.nn.Module, config: FPXWeightOnlyConfig
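Each deleted shim sat next to a `@register_quantize_module_handler(...)` registration, which survives and keeps the config classes wired to their transforms. A simplified, generic sketch of that registry pattern (a stand-in for illustration, not torchao's actual implementation):

```python
from typing import Callable, Dict, Type

import torch

# Map each config class to the function that applies it to a module
_HANDLERS: Dict[Type, Callable] = {}


def register_handler(config_cls: Type) -> Callable:
    def decorator(fn: Callable) -> Callable:
        _HANDLERS[config_cls] = fn
        return fn
    return decorator


class MyConfig:  # toy config for illustration
    bits = 4


@register_handler(MyConfig)
def _my_transform(module: torch.nn.Module, config: MyConfig) -> torch.nn.Module:
    return module  # a real handler would swap in quantized weights


def apply_config(module: torch.nn.Module, config) -> torch.nn.Module:
    # Dispatch on the config's type, as quantize_ does with its registry
    return _HANDLERS[type(config)](module, config)
```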
21 changes: 1 addition & 20 deletions torchao/utils.py
@@ -12,7 +12,7 @@
from functools import reduce
from importlib.metadata import version
from math import gcd
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable, Optional

import torch
import torch.nn.utils.parametrize as parametrize
@@ -433,25 +433,6 @@ def __eq__(self, other):
TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")


-class _ConfigDeprecationWrapper:
-    """
-    A deprecation wrapper that directs users from a deprecated "config function"
-    (e.g. `int4_weight_only`) to the replacement config class.
-    """
-
-    def __init__(self, deprecated_name: str, config_cls: Type):
-        self.deprecated_name = deprecated_name
-        self.config_cls = config_cls
-
-    def __call__(self, *args, **kwargs):
-        warnings.warn(
-            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
-            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
-            f"  quantize_(model, {self.config_cls.__name__}(...))"
-        )
-        return self.config_cls(*args, **kwargs)
-
-
"""
Helper function for implementing aten op or torch function dispatch
and dispatching to these implementations.
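With `_ConfigDeprecationWrapper` gone, old functional calls fail outright instead of warning. The migration is one-to-one (names from this diff; the module, device, and dtype below are illustrative, and float8 weight-only assumes recent CUDA hardware):

```python
import torch.nn as nn

from torchao.quantization.quant_api import Float8WeightOnlyConfig, quantize_

model = nn.Linear(128, 128).cuda().bfloat16()  # placeholder module

# Before this PR: quantize_(model, float8_weight_only())  # warned, still worked
quantize_(model, Float8WeightOnlyConfig())
```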