test/quantization/pt2e/test_quantize_pt2e.py: 97 additions & 0 deletions
@@ -9,6 +9,7 @@
# ruff: noqa: F841


import copy
import unittest

import torch
@@ -21,6 +22,7 @@
    weight_observer_range_neg_127_to_127,
)
from torch.fx import Node
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
)
@@ -1630,6 +1632,101 @@ def forward(self, x):
                if key != FROM_NODE_KEY:
                    self.assertEqual(n.meta[key], weight_meta[key])

    def test_constant_folding_pass(self):
        from torchao.quantization import (
            MappingType,
            PerGroup,
            PerToken,
        )
        from torchao.quantization.pt2e._affine_quantization import (
            AffineQuantizedMinMaxObserver,
        )
        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
        from torchao.quantization.pt2e.quantizer import (
            QuantizationAnnotation,
            QuantizationSpec,
            Quantizer,
        )

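        # Example backend quantizer: annotates every aten.linear with affine
        # quantization specs built from AffineQuantizedMinMaxObserver.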
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.linear.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, torch.fx.Node)
                        weight = node.args[1]
                        assert isinstance(weight, torch.fx.Node)

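                        # Per-token symmetric uint8 quantization for the activation input.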
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=None,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
                                # TODO: maybe align the arg name here
                                target_dtype=torch.uint8,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerToken(),
                            ),
                        )

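                        # Per-group (group_size=128) symmetric uint8 quantization for the weight.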
                        weight_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=None,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
                                target_dtype=torch.uint8,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerGroup(group_size=128),
                            ),
                        )
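                        # Attach the annotation so prepare_pt2e inserts observers for
                        # both the activation and the weight of this linear op.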
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                            },
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

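        # Single-linear model; in_features=128 matches the weight group size above.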
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 20)

            def forward(self, x):
                return self.linear(x)

        example_inputs = (torch.randn(5, 128),)
        model = M()
        quantizer = BackendAQuantizer()
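        # Export to an ATen-level graph, then insert observers per the annotations.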
        m = torch.export.export(model.eval(), example_inputs, strict=True).module()
        m = prepare_pt2e(m, quantizer)
        # Calibration
        m(*example_inputs)
        # Get the quantized model
        m_fold = copy.deepcopy(m)
        m_fold = convert_pt2e(m_fold, fold_quantize=True)

        # With folding, the graph should only contain frozen params and no linear_weight
        FileCheck().check("_frozen_param0").check_not("linear_weight").run(m_fold.code)

        m_not_fold = copy.deepcopy(m)
        m_not_fold = convert_pt2e(m_not_fold, fold_quantize=False)

        # Without folding, the graph should contain no frozen params and still contain linear_weight
        FileCheck().check_not("_frozen_param0").check("linear_weight").run(
            m_not_fold.code
        )

    def test_save_load(self):
        """Test save/load a quantized model"""
        m = self._get_pt2e_quantized_linear()