[quant][pt2e][be] Rename qnnpack quantizer to xnnpack quantizer #105551

Closed
wants to merge 1 commit into from
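The change below is a mechanical rename of `QNNPackQuantizer` to `XNNPACKQuantizer` (and of the `qnnpack_quantizer` module to `xnnpack_quantizer`). As a reference, here is a minimal sketch of how the renamed quantizer is driven, mirroring the updated tests in this diff (dynamo export, `set_global`, `prepare_pt2e`); the toy module, the input shape, and the closing `convert_pt2e` call are illustrative assumptions, not part of this change:

```python
import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization.pt2e.quantizer import XNNPACKQuantizer
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class ToyConv(torch.nn.Module):
    # Hypothetical stand-in for the TestHelperModules used in the tests.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)


model = ToyConv().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)

# Export to an ATen-level graph, as the tests in this PR do.
m, _guards = torchdynamo.export(
    model, *example_inputs, aten_graph=True, tracing_mode="real"
)

# Only the quantizer names change in this PR; the workflow itself is unchanged.
quantizer = XNNPACKQuantizer()  # previously QNNPackQuantizer
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = prepare_pt2e(m, quantizer)
m(*example_inputs)   # calibration pass
m = convert_pt2e(m)  # assumed available alongside prepare_pt2e; not shown in this diff
```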
72 changes: 34 additions & 38 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -21,7 +21,7 @@
EmbeddingQuantizer,
FixedQParamsQuantizationSpec,
OperatorConfig,
QNNPackQuantizer,
XNNPACKQuantizer,
QuantizationAnnotation,
QuantizationSpec,
Quantizer,
@@ -30,7 +30,7 @@
from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
ComposableQuantizer,
)
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import (
@@ -291,7 +291,7 @@ def _verify_symmetric_qnnpack_qat_numerics(
# PT2 export

model_pt2e = copy.deepcopy(model)
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(
is_per_channel=is_per_channel, is_qat=True
@@ -354,7 +354,7 @@ def _verify_symmetric_qnnpack_qat_graph(
with fake quantizes inserted into the correct places.
# TODO: also verify that metadata is copied over to the new nodes.
"""
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel, is_qat=True)
)
@@ -991,8 +991,8 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)

def test_qnnpack_quantizer_conv(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_conv(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(1, 3, 5, 5),)
@@ -1017,8 +1017,8 @@ def test_qnnpack_quantizer_conv(self):
node_list,
)

def test_qnnpack_quantizer_linear(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_linear(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = TestHelperModules.TwoLinearModule().eval()
@@ -1047,8 +1047,8 @@ def test_qnnpack_quantizer_linear(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_conv_linear_no_permute(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_conv_linear_no_permute(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
node_occurrence = {
@@ -1072,8 +1072,8 @@ def test_qnnpack_quantizer_conv_linear_no_permute(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_conv_linear(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_conv_linear(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)

@@ -1097,8 +1097,8 @@ def test_qnnpack_quantizer_conv_linear(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_linear_with_dynamic_shape(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_linear_with_dynamic_shape(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = TestHelperModules.TwoLinearModule().eval()
@@ -1125,8 +1125,8 @@ def test_qnnpack_quantizer_linear_with_dynamic_shape(self):
export_with_dynamic_shape=True,
)

def test_qnnpack_quantizer_obs_sharing_ops(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_obs_sharing_ops(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m = TestHelperModules.Conv2dWithObsSharingOps().eval()
@@ -1157,7 +1157,7 @@ def test_qnnpack_quantizer_obs_sharing_ops(self):
self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)

def test_propagate_annotation(self):
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m = TestHelperModules.Conv2dPropAnnotaton().eval()
@@ -1196,8 +1196,8 @@ def test_propagate_annotation(self):
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

def test_qnnpack_quantizer_dynamic_linear(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_dynamic_linear(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
@@ -1239,8 +1239,8 @@ def test_qnnpack_quantizer_dynamic_linear(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_dynamic_linear_with_conv(self):
quantizer = QNNPackQuantizer()
def test_xnnpack_quantizer_dynamic_linear_with_conv(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=True
)
@@ -1280,12 +1280,12 @@ def test_qnnpack_quantizer_dynamic_linear_with_conv(self):
)

def test_composable_quantizer_linear_conv(self):
dynamic_quantizer = QNNPackQuantizer()
dynamic_quantizer = XNNPACKQuantizer()
operator_config_dynamic = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=True
)
dynamic_quantizer.set_global(operator_config_dynamic)
static_quantizer = QNNPackQuantizer()
static_quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
static_quantizer.set_global(operator_config)
# Note that dynamic quantization must be applied first here.
@@ -1348,7 +1348,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
def get_supported_operators(cls) -> List[OperatorConfig]:
pass

quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
bad_quantizer = BadQuantizer()
@@ -1472,12 +1472,12 @@ def test_embedding_conv_linear_quantization(self):
example_inputs = (indices,)

embedding_quantizer = EmbeddingQuantizer()
dynamic_quantizer = QNNPackQuantizer()
dynamic_quantizer = XNNPACKQuantizer()
operator_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
dynamic_quantizer.set_global(operator_config_dynamic)
static_quantizer = QNNPackQuantizer()
static_quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
static_quantizer.set_global(operator_config)
composed_quantizer = ComposableQuantizer(
@@ -1705,7 +1705,7 @@ def _get_getitem_nodes(m: torch.fx.GraphModule):
(_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)

# Prepare QAT
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
)
@@ -1774,10 +1774,8 @@ def __init__(self):
def forward(self, x, y):
return x + y

import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = M().eval()

@@ -1799,10 +1797,8 @@ def __init__(self):
def forward(self, x, y):
return x + y

import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = M().eval()

@@ -1880,7 +1876,7 @@ def forward(self, input_tensor, hidden_tensor):
aten_graph=True,
tracing_mode="real",
)
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
)
@@ -1944,7 +1940,7 @@ def forward(self, input_tensor, hidden_tensor):
aten_graph=True,
tracing_mode="real",
)
quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
)
@@ -1972,7 +1968,7 @@ def test_resnet18_with_quantizer_api(self):
aten_graph=True,
)

quantizer = QNNPackQuantizer()
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m = prepare_pt2e(m, quantizer)
2 changes: 1 addition & 1 deletion torch/_dynamo/skipfiles.py
@@ -140,7 +140,7 @@ def _module_dir(m: types.ModuleType):
# `torch.ao.quantization._pt2e`, which interferes with memory profiling
FILENAME_ALLOWLIST |= {
_module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
_module_dir(torch) + "ao/quantization/pt2e/quantizer/qnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/quantizer/xnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
}
2 changes: 1 addition & 1 deletion torch/ao/quantization/pt2e/_propagate_annotation.py
@@ -9,7 +9,7 @@


def _is_share_obs_or_fq_op(op: Callable) -> bool:
# TODO: remove some of these ops in qnnpack_quantizer
# TODO: remove some of these ops in xnnpack_quantizer
return op in [
torch.ops.aten.hardtanh.default,
torch.ops.aten.mean.default,
4 changes: 2 additions & 2 deletions torch/ao/quantization/pt2e/quantizer/__init__.py
@@ -1,4 +1,4 @@
from .qnnpack_quantizer import QNNPackQuantizer
from .xnnpack_quantizer import XNNPACKQuantizer
from .quantizer import (
DerivedQuantizationSpec,
EdgeOrNode,
@@ -25,7 +25,7 @@
"QuantizationConfig",
"EmbeddingQuantizer",
"Quantizer",
"QNNPackQuantizer",
"XNNPACKQuantizer",
"QuantizationSpecBase",
"QuantizationSpec",
"FixedQParamsQuantizationSpec",
@@ -19,7 +19,7 @@
get_weight_qspec,
get_bias_qspec,
)
from .qnnpack_quantizer import (
from .xnnpack_quantizer import (
_is_annotated,
)
from torch.ao.quantization.observer import (
@@ -10,6 +10,15 @@
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)

from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

@@ -23,15 +32,6 @@
get_output_act_qspec,
get_weight_qspec,
)
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

from torch.fx import Node
@@ -50,7 +50,7 @@


__all__ = [
"QNNPackQuantizer",
"XNNPACKQuantizer",
"get_symmetric_quantization_config",
]

@@ -224,7 +224,7 @@ def _is_annotated(nodes: List[Node]):
return annotated


class QNNPackQuantizer(Quantizer):
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()

def __init__(self):
@@ -259,13 +259,13 @@ def get_supported_operator_for_quantization_config(
return ops
return []

def set_global(self, quantization_config: QuantizationConfig) -> QNNPackQuantizer:
def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
self.global_config = quantization_config
return self

def set_config_for_operator_type(
self, operator_type: str, quantization_config: QuantizationConfig
) -> QNNPackQuantizer:
) -> XNNPACKQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self

4 changes: 2 additions & 2 deletions torch/ao/quantization/quantize_pt2e.py
@@ -28,7 +28,7 @@
SharedQuantizationSpec,
DerivedQuantizationSpec,
QuantizationAnnotation,
QNNPackQuantizer,
XNNPACKQuantizer,
EmbeddingQuantizer,
ComposableQuantizer,
)
@@ -38,7 +38,7 @@
get_output_act_qspec,
get_weight_qspec,
)
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import ( # noqa: F401
get_symmetric_quantization_config,
)
from torch.ao.quantization.backend_config import BackendConfig