From ea558b21353a8ed27e864aecab796c2d0d74239f Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Mon, 28 Dec 2020 14:21:55 -0800
Subject: [PATCH] fx quant: hook up ConvTranspose{n}d (#49717)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49717

Quantization of `ConvTranspose{n}d` is supported in Eager mode. This PR
adds support for FX graph mode.

Note: this currently only works with `qnnpack`, because per-channel
weights are not supported by quantized conv transpose. Until that is
fixed, a future PR should throw an error when someone tries to quantize
a ConvTranspose model with per-channel weight observers.

Test Plan:
```
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_1d
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_2d
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25674636

fbshipit-source-id: b6948156123ed55db77e6337bea10db956215ae6
---
 test/quantization/test_quantize_fx.py        | 39 +++++++++++++++++++
 .../quantization/fx/quantization_patterns.py |  2 +
 2 files changed, 41 insertions(+)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 66324f928f04..fd8cc8764f89 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -64,7 +64,9 @@
 )
 
 from torch.testing._internal.common_quantized import (
+    supported_qengines,
     override_qengines,
+    override_quantized_engine,
 )
 
 from torch.testing._internal.common_distributed import skip_if_not_multigpu
@@ -78,6 +80,7 @@
 import operator
 import unittest
 import io
+from typing import Callable
 
 class TestFuseFx(QuantizationTestCase):
     def test_fuse_conv_bn_relu(self):
@@ -2365,6 +2368,42 @@ def test_rnn(self):
             [100, -155]],
            dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
         self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)
 
+    def _test_conv_transpose_impl(
+            self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
+        with override_quantized_engine('qnnpack'):
+            # Create fp32 versions of FX and Eager models
+            m1 = torch.nn.Sequential(float_cls(1, 1, 1))
+            m2 = torch.nn.Sequential(float_cls(1, 1, 1))
+            m2.load_state_dict(m1.state_dict())
+            m2 = torch.quantization.QuantWrapper(m2)
+            # FX graph
+            q_result1 = self.checkGraphModeFxOp(
+                m1, (data,), QuantType.STATIC,
+                expected_node_occurrence={
+                    ns.call_module(q_cls): 1,
+                })
+            # Eager
+            m2.qconfig = get_default_qconfig(torch.backends.quantized.engine)
+            m2.eval()
+            m2p = torch.quantization.prepare(m2)
+            m2p(data)
+            m2q = torch.quantization.convert(m2p)
+            q_result2 = m2q(data)
+            # verify results match
+            self.assertTrue(torch.allclose(q_result1, q_result2))
+
+    @unittest.skipUnless('qnnpack' in supported_qengines,
+                         "This Pytorch Build has not been built with or does not support QNNPACK")
+    def test_conv_transpose_1d(self):
+        self._test_conv_transpose_impl(
+            torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4))
+
+    @unittest.skipUnless('qnnpack' in supported_qengines,
+                         "This Pytorch Build has not been built with or does not support QNNPACK")
+    def test_conv_transpose_2d(self):
+        self._test_conv_transpose_impl(
+            torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))
+
 class TestQuantizeFxModels(QuantizationTestCase):
     def _test_model_impl(
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index a1e601332d4a..46fbed74bdc8 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -537,6 +537,8 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
     torch._ops.ops.quantized.instance_norm: ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
 }
 
+@register_quant_pattern(torch.nn.ConvTranspose1d)
+@register_quant_pattern(torch.nn.ConvTranspose2d)
 @register_quant_pattern(torch.nn.ELU)
 @register_quant_pattern(torch.nn.LeakyReLU)
 @register_quant_pattern(torch.nn.Hardswish)
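For context, the test above exercises the FX graph mode path through the test harness (`checkGraphModeFxOp`); a minimal sketch of the user-facing flow this PR enables is shown below. It assumes the `prepare_fx`/`convert_fx` entry points from `torch.quantization.quantize_fx` as they existed at the time of this commit; the module shapes and calibration data are illustrative only.

```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# ConvTranspose quantization currently only works under qnnpack,
# since per-channel weights are not supported (see Summary above).
torch.backends.quantized.engine = 'qnnpack'

m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)).eval()
qconfig_dict = {'': get_default_qconfig('qnnpack')}

mp = prepare_fx(m, qconfig_dict)  # insert observers
mp(torch.randn(4, 1, 4, 4))       # calibrate on sample data
mq = convert_fx(mp)               # swaps in nnq.ConvTranspose2d
out = mq(torch.randn(4, 1, 4, 4))
```

Note that `get_default_qconfig('qnnpack')` uses per-tensor weight observers, which is why the default config works here; a per-channel qconfig would not, per the note in the Summary.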