74 changes: 69 additions & 5 deletions test/quantization/test_qat.py
@@ -73,6 +73,8 @@
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
)
from torchao.quantization.quant_primitives import (
MappingType,
@@ -1872,6 +1874,8 @@ def _test_quantize_api_against_ptq(
base_config: AOBaseConfig,
target_prepare_sqnr: float,
target_convert_sqnr: float,
dtype: torch.dtype = torch.bfloat16,
module_type: str = "linear",
):
"""
Test the following:
@@ -1884,22 +1888,32 @@
quantize_(model, base_config)
"""
torch.manual_seed(self.SEED)
m = M().to(torch.bfloat16).cuda()
example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)

if module_type == "linear":
m = M().to(dtype).cuda()
example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
elif module_type == "embedding":
m = M3().to(dtype).cuda()
example_inputs = (m.example_inputs()[0].cuda(),)
filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
else:
raise ValueError(f"Unknown module type {module_type}")

# baseline
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)
quantize_(m_baseline, base_config, filter_fn)
out_baseline = m_baseline(*example_inputs)

# compare prepare
quantize_(m, QATConfig(base_config, step="prepare"))
quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
out_prepared = m(*example_inputs)
prepare_sqnr = compute_error(out_prepared, out_baseline)

self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)

# compare convert
quantize_(m, QATConfig(base_config, step="convert"))
quantize_(m, QATConfig(base_config, step="convert"), filter_fn)
out_converted = m(*example_inputs)
convert_sqnr = compute_error(out_converted, out_baseline)
self.assertGreaterEqual(convert_sqnr, target_convert_sqnr)
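
A note on the metric used in this helper: compute_error reports an SQNR in decibels, so the target_prepare_sqnr / target_convert_sqnr values of float("inf") passed by the tests below require the QAT prepare and convert outputs to match the PTQ baseline exactly. A minimal sketch of the idea (illustrative only; not necessarily torchao's exact compute_error implementation):

```python
import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means closer,
    # and an exact match (out == ref) yields float("inf").
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm((ref - out).float())
    return 20.0 * torch.log10(signal / noise)
```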
@@ -1967,6 +1981,56 @@ def test_quantize_api_int8_int4(self):
target_convert_sqnr=float("inf"),
)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@parametrize(
"weight_dtype, weight_granularity, dtype",
[
(weight_dtype, weight_granularity, dtype)
for weight_dtype in [getattr(torch, f"int{i}") for i in range(2, 9)]
for weight_granularity in [PerGroup(32), PerAxis(0)]
for dtype in [torch.bfloat16, torch.float32]
],
)
def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
"""
Test the following:
quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="prepare"))
quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="convert"))
"""
self._test_quantize_api_against_ptq(
Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype, weight_granularity=weight_granularity
),
target_prepare_sqnr=float("inf"),
target_convert_sqnr=float("inf"),
dtype=dtype,
)
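
For context, the flow this test exercises looks roughly like the sketch below. The toy model, the int4/group-size-32 choice, and the fine-tuning placeholder are illustrative assumptions rather than part of this PR (the tests above run on CUDA in bfloat16 or float32); the config arguments and the QATConfig prepare/convert calls mirror the test.

```python
import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16)
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)

# Step 1: swap in fake-quantized modules for QAT ("prepare").
quantize_(model, QATConfig(base_config, step="prepare"))

# Step 2: fine-tune the model as usual (omitted here).

# Step 3: replace fake quantization with the real PTQ quantization ("convert").
quantize_(model, QATConfig(base_config, step="convert"))
```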

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@parametrize(
"weight_dtype, granularity, dtype, module_type",
[
(weight_dtype, granularity, dtype, module_type)
for weight_dtype in [getattr(torch, f"int{i}") for i in range(2, 9)]
for granularity in [PerGroup(32), PerAxis(0)]
for dtype in [torch.bfloat16, torch.float32]
for module_type in ["linear", "embedding"]
],
)
def test_quantize_api_intx(self, weight_dtype, granularity, dtype, module_type):
"""
Test the following:
quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="prepare"))
quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="convert"))
"""
self._test_quantize_api_against_ptq(
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
target_prepare_sqnr=float("inf"),
target_convert_sqnr=float("inf"),
dtype=dtype,
module_type=module_type,
)
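
The embedding branch above works by passing a filter_fn to quantize_ so that only nn.Embedding modules are swapped. A hedged sketch of the same flow outside the test harness; the toy embedding size is an assumption, while the config arguments, QATConfig steps, and filter_fn mirror the test and the helper above.

```python
import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerAxis
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Embedding(1024, 64))
base_config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0))

# Quantize only embedding modules, as _test_quantize_api_against_ptq does
# for module_type="embedding".
embedding_only = lambda m, fqn: isinstance(m, torch.nn.Embedding)

quantize_(model, QATConfig(base_config, step="prepare"), embedding_only)
# ... fine-tune ...
quantize_(model, QATConfig(base_config, step="convert"), embedding_only)
```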

def test_infer_fp8_int4_config(self):
"""
Test that fake quantize configs are correctly inferred from
50 changes: 50 additions & 0 deletions torchao/quantization/qat/fake_quantize_config.py
@@ -363,6 +363,8 @@ def _infer_fake_quantize_configs(
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
)

if isinstance(base_config, Int8DynamicActivationInt4WeightConfig):
@@ -438,6 +440,54 @@ def _infer_fake_quantize_configs(
else:
act_config = None
weight_config = NVFP4FakeQuantizeConfig(False)
elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
assert base_config.version >= 2, "Only version 2+ is supported"
assert base_config.intx_packing_format == "unpacked_to_int8", (
"Only unpacked_to_int8 is supported"
)
assert base_config.weight_dtype != torch.int1, "Only int2+ is supported"
assert base_config.act_mapping_type == MappingType.ASYMMETRIC, (
"Only asymmetric activation mapping is supported"
)
assert base_config.weight_mapping_type == MappingType.SYMMETRIC, (
"Only symmetric weight mapping is supported"
)
assert base_config.weight_scale_dtype is None, (
"Specifying weight_scale_dtype is not supported"
)

act_config = IntxFakeQuantizeConfig(
torch.int8,
"per_token",
is_symmetric=False,
scale_precision=base_config.weight_scale_dtype,
)
weight_config = IntxFakeQuantizeConfig(
dtype=base_config.weight_dtype,
granularity=base_config.weight_granularity,
mapping_type=base_config.weight_mapping_type,
scale_precision=base_config.weight_scale_dtype,
)
elif isinstance(base_config, IntxWeightOnlyConfig):
assert base_config.version >= 2, "Only version 2+ is supported"
assert base_config.intx_packing_format == "unpacked_to_int8", (
"Only unpacked_to_int8 is supported"
)
assert base_config.mapping_type == MappingType.SYMMETRIC, (
"Only symmetric mapping is supported"
)
assert base_config.weight_dtype != torch.int1, "Only int2+ is supported"
assert base_config.scale_dtype is None, (
"Specifying scale_dtype is not supported"
)

act_config = None
weight_config = IntxFakeQuantizeConfig(
dtype=base_config.weight_dtype,
granularity=base_config.granularity,
mapping_type=base_config.mapping_type,
scale_precision=base_config.scale_dtype,
)
else:
raise ValueError("Unexpected base config: %s" % base_config)
return (act_config, weight_config)
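
To make the two new branches concrete: IntxWeightOnlyConfig maps to no activation fake quantization plus a symmetric intx weight fake-quantize config, while Int8DynamicActivationIntxWeightConfig additionally gets per-token asymmetric int8 activation fake quantization. A rough usage sketch, assuming the function takes only the base config (as the hunk header suggests) and that the configs' defaults satisfy the version and packing-format asserts above, as the new tests imply:

```python
import torch
from torchao.quantization import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat.fake_quantize_config import _infer_fake_quantize_configs

# Weight-only: no activation fake quantization, symmetric intx weight fake quantization.
act_fq, weight_fq = _infer_fake_quantize_configs(
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32))
)
assert act_fq is None

# Dynamic int8 activations + intx weights: per-token asymmetric int8 activation
# fake quantization on top of the intx weight fake quantization.
act_fq, weight_fq = _infer_fake_quantize_configs(
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    )
)
```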