diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 66324f928f04..fd8cc8764f89 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -64,7 +64,9 @@
 )
 
 from torch.testing._internal.common_quantized import (
+    supported_qengines,
     override_qengines,
+    override_quantized_engine,
 )
 
 from torch.testing._internal.common_distributed import skip_if_not_multigpu
@@ -78,6 +80,7 @@
 import operator
 import unittest
 import io
+from typing import Callable
 
 class TestFuseFx(QuantizationTestCase):
     def test_fuse_conv_bn_relu(self):
@@ -2365,6 +2368,57 @@ def test_rnn(self):
                                      [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
         self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)
 
+    def _test_conv_transpose_impl(
+            self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
+        """Compare FX graph mode and eager static quantization of a transposed conv.
+
+        Runs under the 'qnnpack' quantized engine. Two
+        ``Sequential(float_cls(1, 1, 1))`` models are given the same state
+        dict; one goes through ``checkGraphModeFxOp`` (expecting a single
+        ``q_cls`` ``call_module`` node), the other is wrapped in
+        ``QuantWrapper`` and taken through eager ``prepare``/``convert``.
+        The two outputs on ``data`` must satisfy ``torch.allclose``.
+
+        Args:
+            float_cls: builds the fp32 module; invoked as ``float_cls(1, 1, 1)``.
+            q_cls: module class expected once in the converted FX graph.
+            data: input passed to the FX check, to the prepared eager model
+                before conversion, and to the converted eager model.
+        """
+        with override_quantized_engine('qnnpack'):
+            # Create fp32 versions of FX and Eager models
+            m1 = torch.nn.Sequential(float_cls(1, 1, 1))
+            m2 = torch.nn.Sequential(float_cls(1, 1, 1))
+            m2.load_state_dict(m1.state_dict())
+            m2 = torch.quantization.QuantWrapper(m2)
+            # FX graph
+            q_result1 = self.checkGraphModeFxOp(
+                m1, (data,), QuantType.STATIC,
+                expected_node_occurrence={
+                    ns.call_module(q_cls): 1,
+                })
+            # Eager
+            m2.qconfig = get_default_qconfig(torch.backends.quantized.engine)
+            m2.eval()
+            m2p = torch.quantization.prepare(m2)
+            m2p(data)
+            m2q = torch.quantization.convert(m2p)
+            q_result2 = m2q(data)
+            # verify results match
+            self.assertTrue(torch.allclose(q_result1, q_result2))
+
+    @unittest.skipUnless('qnnpack' in supported_qengines,
+                         "This Pytorch Build has not been built with or does not support QNNPACK")
+    def test_conv_transpose_1d(self):
+        self._test_conv_transpose_impl(
+            torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4))
+
+    @unittest.skipUnless('qnnpack' in supported_qengines,
+                         "This Pytorch Build has not been built with or does not support QNNPACK")
+    def test_conv_transpose_2d(self):
+        self._test_conv_transpose_impl(
+            torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))
+
 
 class TestQuantizeFxModels(QuantizationTestCase):
     def _test_model_impl(
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index a1e601332d4a..46fbed74bdc8 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -537,6 +537,8 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
     torch._ops.ops.quantized.instance_norm:
     ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
 }
+@register_quant_pattern(torch.nn.ConvTranspose1d)
+@register_quant_pattern(torch.nn.ConvTranspose2d)
 @register_quant_pattern(torch.nn.ELU)
 @register_quant_pattern(torch.nn.LeakyReLU)
 @register_quant_pattern(torch.nn.Hardswish)