[quant][pt2e][be] Rename qnnpack quantizer to xnnpack quantizer (#105551)

Summary: As titled: rename QNNPackQuantizer to XNNPACKQuantizer, and the qnnpack_quantizer module to xnnpack_quantizer.

Test Plan: sandcastle CI and OSS CI

Reviewed By: andrewor14

Differential Revision: D47422894

Pull Request resolved: #105551
Approved by: https://github.com/andrewor14
jerryzh168 authored and pytorchmergebot committed Jul 20, 2023
1 parent c6653b6 commit dff4e03
Showing 7 changed files with 54 additions and 58 deletions.
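
For callers, the rename is a straight name swap. The following minimal sketch of the import change is assembled from the paths touched in this commit; module locations in later PyTorch releases may differ.

# Old names (removed in this commit):
# from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer
# from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
#     get_symmetric_quantization_config,
# )

# New names (added in this commit):
from torch.ao.quantization.pt2e.quantizer import XNNPACKQuantizer
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)

# Usage is unchanged apart from the class name.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
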
72 changes: 34 additions & 38 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -21,7 +21,7 @@
EmbeddingQuantizer,
FixedQParamsQuantizationSpec,
OperatorConfig,
QNNPackQuantizer,
XNNPACKQuantizer,
QuantizationAnnotation,
QuantizationSpec,
Quantizer,
@@ -30,7 +30,7 @@
from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
ComposableQuantizer,
)
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import (
@@ -291,7 +291,7 @@ def _verify_symmetric_qnnpack_qat_numerics(
# PT2 export

model_pt2e = copy.deepcopy(model)
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(
is_per_channel=is_per_channel, is_qat=True
@@ -354,7 +354,7 @@ def _verify_symmetric_qnnpack_qat_graph(
with fake quantizes inserted into the correct places.
# TODO: also verify that metadata is copied over to the new nodes.
"""
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel, is_qat=True)
)
@@ -991,8 +991,8 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)

def test_qnnpack_quantizer_conv(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_conv(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(1, 3, 5, 5),)
@@ -1017,8 +1017,8 @@ def test_qnnpack_quantizer_conv(self):
node_list,
)

def test_qnnpack_quantizer_linear(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_linear(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = TestHelperModules.TwoLinearModule().eval()
@@ -1047,8 +1047,8 @@ def test_qnnpack_quantizer_linear(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_conv_linear_no_permute(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_conv_linear_no_permute(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
node_occurrence = {
@@ -1072,8 +1072,8 @@ def test_qnnpack_quantizer_conv_linear_no_permute(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_conv_linear(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_conv_linear(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)

@@ -1097,8 +1097,8 @@ def test_qnnpack_quantizer_conv_linear(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_linear_with_dynamic_shape(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_linear_with_dynamic_shape(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = TestHelperModules.TwoLinearModule().eval()
@@ -1125,8 +1125,8 @@ def test_qnnpack_quantizer_linear_with_dynamic_shape(self):
export_with_dynamic_shape=True,
)

def test_qnnpack_quantizer_obs_sharing_ops(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_obs_sharing_ops(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m = TestHelperModules.Conv2dWithObsSharingOps().eval()
@@ -1157,7 +1157,7 @@ def test_qnnpack_quantizer_obs_sharing_ops(self):
self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)

def test_propagate_annotation(self):
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m = TestHelperModules.Conv2dPropAnnotaton().eval()
@@ -1196,8 +1196,8 @@ def test_propagate_annotation(self):
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

def test_qnnpack_quantizer_dynamic_linear(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_dynamic_linear(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
@@ -1239,8 +1239,8 @@ def test_qnnpack_quantizer_dynamic_linear(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_dynamic_linear_with_conv(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_dynamic_linear_with_conv(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=True
)
@@ -1280,12 +1280,12 @@ def test_qnnpack_quantizer_dynamic_linear_with_conv(self):
)

def test_composable_quantizer_linear_conv(self):
dynamic_quantizer = QNNPackQuantizer()
dynamic_quantizer = XNNPACKQuantizer()
operator_config_dynamic = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=True
)
dynamic_quantizer.set_global(operator_config_dynamic)
static_quantizer = QNNPackQuantizer()
static_quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
static_quantizer.set_global(operator_config)
# Note that dynamic quantization must be applied first here.
@@ -1348,7 +1348,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
def get_supported_operators(cls) -> List[OperatorConfig]:
pass

quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
bad_quantizer = BadQuantizer()
@@ -1472,12 +1472,12 @@ def test_embedding_conv_linear_quantization(self):
example_inputs = (indices,)

embedding_quantizer = EmbeddingQuantizer()
dynamic_quantizer = QNNPackQuantizer()
dynamic_quantizer = XNNPACKQuantizer()
operator_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
dynamic_quantizer.set_global(operator_config_dynamic)
static_quantizer = QNNPackQuantizer()
static_quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
static_quantizer.set_global(operator_config)
composed_quantizer = ComposableQuantizer(
@@ -1705,7 +1705,7 @@ def _get_getitem_nodes(m: torch.fx.GraphModule):
(_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)

# Prepare QAT
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
)
@@ -1774,10 +1774,8 @@ def __init__(self):
def forward(self, x, y):
return x + y

import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = M().eval()

@@ -1799,10 +1797,8 @@ def __init__(self):
def forward(self, x, y):
return x + y

import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = M().eval()

@@ -1880,7 +1876,7 @@ def forward(self, input_tensor, hidden_tensor):
aten_graph=True,
tracing_mode="real",
)
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
)
@@ -1944,7 +1940,7 @@ def forward(self, input_tensor, hidden_tensor):
aten_graph=True,
tracing_mode="real",
)
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
)
@@ -1972,7 +1968,7 @@ def test_resnet18_with_quantizer_api(self):
aten_graph=True,
)

quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m = prepare_pt2e(m, quantizer)
2 changes: 1 addition & 1 deletion torch/_dynamo/skipfiles.py
@@ -140,7 +140,7 @@ def _module_dir(m: types.ModuleType):
# `torch.ao.quantization._pt2e`, which interferes with memory profiling
FILENAME_ALLOWLIST |= {
_module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
_module_dir(torch) + "ao/quantization/pt2e/quantizer/qnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/quantizer/xnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
}
2 changes: 1 addition & 1 deletion torch/ao/quantization/pt2e/_propagate_annotation.py
@@ -9,7 +9,7 @@


def _is_share_obs_or_fq_op(op: Callable) -> bool:
# TODO: remove some of these ops in qnnpack_quantizer
# TODO: remove some of these ops in xnnpack_quantizer
return op in [
torch.ops.aten.hardtanh.default,
torch.ops.aten.mean.default,
4 changes: 2 additions & 2 deletions torch/ao/quantization/pt2e/quantizer/__init__.py
@@ -1,4 +1,4 @@
from .qnnpack_quantizer import QNNPackQuantizer
from .xnnpack_quantizer import XNNPACKQuantizer
from .quantizer import (
DerivedQuantizationSpec,
EdgeOrNode,
@@ -25,7 +25,7 @@
"QuantizationConfig",
"EmbeddingQuantizer",
"Quantizer",
"QNNPackQuantizer",
"XNNPACKQuantizer",
"QuantizationSpecBase",
"QuantizationSpec",
"FixedQParamsQuantizationSpec",
@@ -19,7 +19,7 @@
get_weight_qspec,
get_bias_qspec,
)
from .qnnpack_quantizer import (
from .xnnpack_quantizer import (
_is_annotated,
)
from torch.ao.quantization.observer import (
torch/ao/quantization/pt2e/quantizer/{qnnpack_quantizer.py → xnnpack_quantizer.py}
@@ -10,6 +10,15 @@
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)

from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

@@ -23,15 +32,6 @@
get_output_act_qspec,
get_weight_qspec,
)
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

from torch.fx import Node
@@ -50,7 +50,7 @@


__all__ = [
"QNNPackQuantizer",
"XNNPACKQuantizer",
"get_symmetric_quantization_config",
]

@@ -224,7 +224,7 @@ def _is_annotated(nodes: List[Node]):
return annotated


class QNNPackQuantizer(Quantizer):
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()

def __init__(self):
@@ -259,13 +259,13 @@ def get_supported_operator_for_quantization_config(
return ops
return []

def set_global(self, quantization_config: QuantizationConfig) -> QNNPackQuantizer:
def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
self.global_config = quantization_config
return self

def set_config_for_operator_type(
self, operator_type: str, quantization_config: QuantizationConfig
) -> QNNPackQuantizer:
) -> XNNPACKQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self

4 changes: 2 additions & 2 deletions torch/ao/quantization/quantize_pt2e.py
@@ -28,7 +28,7 @@
SharedQuantizationSpec,
DerivedQuantizationSpec,
QuantizationAnnotation,
QNNPackQuantizer,
XNNPACKQuantizer,
EmbeddingQuantizer,
ComposableQuantizer,
)
@@ -38,7 +38,7 @@
get_output_act_qspec,
get_weight_qspec,
)
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import ( # noqa: F401
get_symmetric_quantization_config,
)
from torch.ao.quantization.backend_config import BackendConfig
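
Taken together, the tests in this commit exercise the following end-to-end flow. This is a hedged sketch based on the calls shown in the diff (torchdynamo.export, set_global, prepare_pt2e); convert_pt2e is assumed from the same quantize_pt2e module and does not appear in the changed lines.

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization.pt2e.quantizer import XNNPACKQuantizer
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


example_inputs = (torch.randn(4, 8), torch.randn(4, 8))
m = M().eval()

# Export to an ATen-level graph, matching the export calls used in the tests.
m, _ = torchdynamo.export(m, *example_inputs, aten_graph=True)

# Configure the renamed quantizer with a global symmetric config.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# Insert observers, calibrate, then convert.
# Note: convert_pt2e is assumed here; it is not part of this diff.
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)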