2 changes: 0 additions & 2 deletions docs/source/api_ref_qat.rst
@@ -42,8 +42,6 @@ Legacy QAT APIs
     :toctree: generated/
     :nosignatures:
 
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
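Reviewer note: a minimal sketch of the migration this PR completes, mirroring the updated tests below. The `model` and the config values are illustrative assumptions, not part of this diff:

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

model = nn.Sequential(nn.Linear(32, 32))  # hypothetical model for illustration

# Previously: quantize_(model, IntXQuantizationAwareTrainingConfig(act_config, weight_config))
act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
quantize_(
    model,
    QATConfig(
        activation_config=act_config,
        weight_config=weight_config,
        step="prepare",
    ),
)

# ... fine-tune with fake quantization ...

# Previously: quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, QATConfig(step="convert"))
```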
11 changes: 5 additions & 6 deletions test/prototype/test_embedding.py
@@ -18,10 +18,9 @@
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int4WeightOnlyEmbeddingQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
@@ -257,7 +256,7 @@ def test_identical_to_IntxWeightOnlyConfig(
         ],
         name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
     ):
         # ASYMMETRIC in QAT is very different that PTQ configs
@@ -288,12 +287,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         )
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
+            QATConfig(weight_config=weight_config, step="prepare"),
             embedding_filter,
         )
         prepared_out = model(indices)
 
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
@@ -355,7 +354,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
         prepared_out = model(indices)
 
         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
(file name not captured in this view)
@@ -20,10 +20,9 @@
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int8DynActInt4WeightQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt4WeightConfig,
@@ -499,7 +498,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self,
         weight_dtype,
         group_size,
@@ -545,7 +544,11 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
 
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+            QATConfig(
+                activation_config=activation_config,
+                weight_config=weight_config,
+                step="prepare",
+            ),
         )
         try:
             prepared_out = model(activations)
@@ -555,7 +558,7 @@
                 return
             raise e
 
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
@@ -606,7 +609,7 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer(
         prepared_out = model(activations)
 
         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
92 changes: 0 additions & 92 deletions test/quantization/test_qat.py
@@ -9,7 +9,6 @@
 
 import copy
 import unittest
-import warnings
 from typing import List, Type
 
 import torch
@@ -39,8 +38,6 @@
 )
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
     initialize_fake_quantizers,
@@ -1718,95 +1715,6 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
-    def test_legacy_quantize_api_e2e(self):
-        """
-        Test that the following two APIs are numerically equivalent:
-
-        New API:
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
-
-        Old API:
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-            quantize_(model, Int8DynamicActivationInt4WeightConfig())
-        """
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        baseline_model = copy.deepcopy(m)
-
-        # Baseline prepare
-        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
-        quantize_(baseline_model, old_qat_config)
-
-        # QATConfig prepare
-        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
-        quantize_(m, QATConfig(base_config, step="prepare"))
-
-        # Compare prepared values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-        # Baseline convert
-        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(baseline_model, base_config)
-
-        # quantize_ convert
-        quantize_(m, QATConfig(base_config, step="convert"))
-
-        # Compare converted values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-    def test_qat_api_deprecation(self):
-        """
-        Test that the appropriate deprecation warning is logged exactly once per class.
-        """
-        from torchao.quantization.qat import (
-            FakeQuantizeConfig,
-            FakeQuantizer,
-            from_intx_quantization_aware_training,
-            intx_quantization_aware_training,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            IntXQuantizationAwareTrainingConfig: (),
-            FromIntXQuantizationAwareTrainingConfig: (),
-            intx_quantization_aware_training: (),
-            from_intx_quantization_aware_training: (),
-            FakeQuantizeConfig: (torch.int8, "per_channel"),
-            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-            # Each call should trigger the warning only once
-            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-            for w in _warnings:
-                self.assertIn(
-                    "is deprecated and will be removed in a future release",
-                    str(w.message),
-                )
-
     def test_qat_api_convert_no_quantization(self):
         """
         Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.
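The deleted `test_legacy_quantize_api_e2e` above documented the numerical equivalence that justifies this cleanup; its new-API half survives as the supported path. A condensed, hedged sketch of that path (the toy model stands in for the suite's `M()`):

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig

model = nn.Sequential(nn.Linear(64, 64))  # stand-in for the test suite's M()

# A single base PTQ config drives both QAT steps
base_config = Int8DynamicActivationInt4WeightConfig(group_size=16)
quantize_(model, QATConfig(base_config, step="prepare"))

out = model(torch.randn(2, 64))  # fake-quantized forward during training

quantize_(model, QATConfig(base_config, step="convert"))  # real quantization
```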
2 changes: 0 additions & 2 deletions torchao/quantization/__init__.py
@@ -64,7 +64,6 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
 )
@@ -119,7 +118,6 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "intx_quantization_aware_training",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
2 changes: 1 addition & 1 deletion torchao/quantization/prototype/qat/api.py
@@ -1,6 +1,6 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FakeQuantizeConfig,
+    IntxFakeQuantizeConfig as FakeQuantizeConfig,
 )
 
 __all__ = [
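A quick way to see what the one-line alias above preserves; a sketch, assuming both import paths remain available after this PR:

```python
# The deprecated prototype path still exposes the old name,
# now bound to the renamed class via the alias in this diff.
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig

assert FakeQuantizeConfig is IntxFakeQuantizeConfig
```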
13 changes: 0 additions & 13 deletions torchao/quantization/qat/__init__.py
@@ -1,25 +1,19 @@
 from .api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
-    from_intx_quantization_aware_training,
     initialize_fake_quantizers,
-    intx_quantization_aware_training,
 )
 from .embedding import (
     FakeQuantizedEmbedding,
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 from .fake_quantize_config import (
-    FakeQuantizeConfig,
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
@@ -50,11 +44,4 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
-    # for BC
-    "FakeQuantizer",
-    "FakeQuantizeConfig",
-    "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "intx_quantization_aware_training",
-    "IntXQuantizationAwareTrainingConfig",
 ]
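With the BC exports dropped from the package imports and `__all__` above, the legacy names should no longer resolve. A sketch of the expected post-PR behavior (assumed, not verified in this diff):

```python
# Expected after this PR: the legacy config classes are gone from the
# public QAT namespace, so callers must migrate to QATConfig.
try:
    from torchao.quantization.qat import IntXQuantizationAwareTrainingConfig
except ImportError:
    print("IntXQuantizationAwareTrainingConfig removed; use QATConfig(step='prepare')")
```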