From eea058002524a12dedfdd76ce5a4173239c12fd9 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 21 Dec 2020 16:26:09 -0800
Subject: [PATCH 1/4] fx quant: hook up ConvTranspose{n}d

Summary:

Quantization of `ConvTranspose{n}d` is supported in Eager mode. This PR
adds support for FX graph mode.

Note: this currently only works with the `qnnpack` engine, because
per-channel weights are not supported by quantized conv transpose. Until
that is fixed, a future PR should raise an error when someone tries to
quantize a ConvTranspose model with per-channel weight observers.

Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_1d
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_2d
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 test/quantization/test_quantize_fx.py         | 45 +++++++++++++++++++
 .../quantization/fx/quantization_patterns.py  | 32 +++++++++++++
 2 files changed, 77 insertions(+)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index e649a06e160d..49f5a929df02 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -78,6 +78,7 @@
 import operator
 import unittest
 import io
+from typing import Callable
 
 class TestFuseFx(QuantizationTestCase):
     def test_fuse_conv_bn_relu(self):
@@ -2361,6 +2362,50 @@ def test_rnn(self):
                 [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
         self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs,
                             module_types, sample_input)
 
+    def _test_conv_transpose_impl(
+            self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
+        if 'qnnpack' not in torch.backends.quantized.supported_engines:
+            return
+        torch.backends.quantized.engine = 'qnnpack'
+        # Create fp32 versions of FX and Eager models
+        m1 = torch.nn.Sequential(float_cls(1, 1, 1))
+        m2 = torch.nn.Sequential(float_cls(1, 1, 1))
+        m2.load_state_dict(m1.state_dict())
+        m2 = torch.quantization.QuantWrapper(m2)
+        # FX graph
+        q_result1 = self.checkGraphModeFxOp(
+            m1, (data,), QuantType.STATIC,
+            expected_node_occurrence={
+                ns.call_module(q_cls): 1,
+            })
+        # Eager
+        m2.qconfig = get_default_qconfig(torch.backends.quantized.engine)
+        m2.eval()
+        m2p = torch.quantization.prepare(m2)
+        m2p(data)
+        m2q = torch.quantization.convert(m2p)
+        q_result2 = m2q(data)
+        # verify results match
+        self.assertTrue(torch.allclose(q_result1, q_result2))
+
+    def test_conv_transpose_1d(self):
+        self._test_conv_transpose_impl(
+            torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4))
+
+    def test_conv_transpose_2d(self):
+        self._test_conv_transpose_impl(
+            torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))
+
+    # TODO: remove
+    def test_conv_transpose_eager(self):
+        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
+        m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+        print(m.qconfig)
+        mp = torch.quantization.prepare(m)
+        mp(torch.randn(4, 1, 4, 4))
+        mq = torch.quantization.convert(mp)
+        print(mq)
+
 class TestQuantizeFxModels(QuantizationTestCase):
     def _test_model_impl(

diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index a1e601332d4a..2a414880532a 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -277,6 +277,38 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         return quantizer.quantized_graph.create_node(
             'call_function', torch.ops.quantized.conv2d, qconv_args,
             kwargs)
 
+@register_quant_pattern(torch.nn.ConvTranspose1d)
+@register_quant_pattern(torch.nn.ConvTranspose2d)
+class ConvTransposeQuantizeHandler(QuantizeHandler):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
+        super().__init__(quantizer, node)
+        assert node.op == 'call_module', \
+            'ConvTranspose{n}d is not implemented for functionals yet'
+        self.conv_t_node = node
+        self.conv_t = quantizer.modules[self.conv_t_node.target]
+
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
+        if convert_custom_config_dict is None:
+            convert_custom_config_dict = {}
+        additional_static_quant_mapping = \
+            convert_custom_config_dict.get("static", {})
+        # 1. attach activation post process to module
+        self.conv_t.activation_post_process = \
+            quantizer.activation_post_process_map[node.name]
+        qconv_t_cls = get_static_quant_module_class(
+            type(self.conv_t), additional_static_quant_mapping)
+        quantized = qconv_t_cls.from_float(self.conv_t)
+        parent_name, name = _parent_name(self.conv_t_node.target)
+        setattr(quantizer.modules[parent_name], name, quantized)
+        return quantizer.quantized_graph.create_node(
+            'call_module',
+            self.conv_t_node.target,
+            (load_arg(quantized=True)(self.conv_t_node.args[0]),),
+            {})
+        return node
+
 # handle linear, maybe followed by relu
 @register_quant_pattern(torch.nn.Linear)
 @register_quant_pattern(torch.nn.functional.linear)

From 5f917cf6c0a4bac59ad82fd9ad951f60fa90d071 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 21 Dec 2020 16:28:11 -0800
Subject: [PATCH 2/4] nit on "fx quant: hook up ConvTranspose{n}d"

Summary:

Quantization of `ConvTranspose{n}d` is supported in Eager mode. This PR
adds support for FX graph mode.

Note: this currently only works with the `qnnpack` engine, because
per-channel weights are not supported by quantized conv transpose. Until
that is fixed, a future PR should raise an error when someone tries to
quantize a ConvTranspose model with per-channel weight observers.

Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_1d
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_2d
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25674636](https://our.internmc.facebook.com/intern/diff/D25674636)

[ghstack-poisoned]
---
 test/quantization/test_quantize_fx.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 49f5a929df02..f780915730e8 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -2396,16 +2396,6 @@ def test_conv_transpose_2d(self):
         self._test_conv_transpose_impl(
             torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))
 
-    # TODO: remove
-    def test_conv_transpose_eager(self):
-        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
-        m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
-        print(m.qconfig)
-        mp = torch.quantization.prepare(m)
-        mp(torch.randn(4, 1, 4, 4))
-        mq = torch.quantization.convert(mp)
-        print(mq)
-
 class TestQuantizeFxModels(QuantizationTestCase):
     def _test_model_impl(

From 84e5a08f8fc142f983053e786fe75e835b061f51 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 21 Dec 2020 16:49:38 -0800
Subject: [PATCH 3/4] comments on "fx quant: hook up ConvTranspose{n}d"

Summary:

Quantization of `ConvTranspose{n}d` is supported in Eager mode. This PR
adds support for FX graph mode.
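As a usage illustration, here is a minimal sketch of the FX flow this
enables (assuming the `qnnpack` engine is available; the toy module,
shapes, and calibration data below are made up for this example):

```
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

torch.backends.quantized.engine = 'qnnpack'

# toy fp32 model, in eval mode for static quantization
m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)).eval()

# the qnnpack default qconfig observes weights per-tensor, which is what
# quantized conv transpose currently requires
qconfig_dict = {'': get_default_qconfig('qnnpack')}

mp = prepare_fx(m, qconfig_dict)  # insert observers
mp(torch.randn(4, 1, 4, 4))       # calibrate
mq = convert_fx(mp)               # swaps in nnq.ConvTranspose2d
```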
Note: this currently only works with the `qnnpack` engine, because
per-channel weights are not supported by quantized conv transpose. Until
that is fixed, a future PR should raise an error when someone tries to
quantize a ConvTranspose model with per-channel weight observers.

Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_1d
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_2d
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25674636](https://our.internmc.facebook.com/intern/diff/D25674636)

[ghstack-poisoned]
---
 .../quantization/fx/quantization_patterns.py | 42 +++++--------------
 1 file changed, 10 insertions(+), 32 deletions(-)

diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 2a414880532a..c7eac4554c6c 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -277,38 +277,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         return quantizer.quantized_graph.create_node(
             'call_function', torch.ops.quantized.conv2d, qconv_args,
             kwargs)
 
-@register_quant_pattern(torch.nn.ConvTranspose1d)
-@register_quant_pattern(torch.nn.ConvTranspose2d)
-class ConvTransposeQuantizeHandler(QuantizeHandler):
-    def __init__(self, quantizer: QuantizerCls, node: Node):
-        super().__init__(quantizer, node)
-        assert node.op == 'call_module', \
-            'ConvTranspose{n}d is not implemented for functionals yet'
-        self.conv_t_node = node
-        self.conv_t = quantizer.modules[self.conv_t_node.target]
-
-    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
-                debug: bool = False,
-                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
-        if convert_custom_config_dict is None:
-            convert_custom_config_dict = {}
-        additional_static_quant_mapping = \
-            convert_custom_config_dict.get("static", {})
-        # 1. attach activation post process to module
-        self.conv_t.activation_post_process = \
-            quantizer.activation_post_process_map[node.name]
-        qconv_t_cls = get_static_quant_module_class(
-            type(self.conv_t), additional_static_quant_mapping)
-        quantized = qconv_t_cls.from_float(self.conv_t)
-        parent_name, name = _parent_name(self.conv_t_node.target)
-        setattr(quantizer.modules[parent_name], name, quantized)
-        return quantizer.quantized_graph.create_node(
-            'call_module',
-            self.conv_t_node.target,
-            (load_arg(quantized=True)(self.conv_t_node.args[0]),),
-            {})
-        return node
-
 # handle linear, maybe followed by relu
 @register_quant_pattern(torch.nn.Linear)
 @register_quant_pattern(torch.nn.functional.linear)
@@ -627,6 +595,16 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         return quantizer.quantized_graph.create_node(
             "call_function", quantized_op, args, kwargs)
 
+@register_quant_pattern(torch.nn.ConvTranspose1d)
+@register_quant_pattern(torch.nn.ConvTranspose2d)
+# thin wrapper around DefaultNode for a nice error message for functionals, until
+# they are supported
+class ConvTransposeQuantizeHandler(DefaultNode):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
+        super().__init__(quantizer, node)
+        assert node.op == 'call_module', \
+            'ConvTranspose{n}d is not implemented for functionals yet'
+
 # TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
 @register_quant_pattern(torch.nn.functional.elu)
 class ELU(QuantizeHandler):

From 36e0f696df558ac187be1c55a4e5f9a76ffd0dfc Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 21 Dec 2020 16:54:16 -0800
Subject: [PATCH 4/4] Update on "fx quant: hook up ConvTranspose{n}d"

Summary:

Quantization of `ConvTranspose{n}d` is supported in Eager mode. This PR
adds support for FX graph mode.

Note: this currently only works with the `qnnpack` engine, because
per-channel weights are not supported by quantized conv transpose. Until
that is fixed, a future PR should raise an error when someone tries to
quantize a ConvTranspose model with per-channel weight observers.
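For reference, the Eager mode flow that the FX path is checked against
(a sketch distilled from `_test_conv_transpose_impl` in the test diff
above; the toy module and input shape are illustrative):

```
import torch

torch.backends.quantized.engine = 'qnnpack'

# QuantWrapper inserts quant/dequant stubs at the model boundaries
m = torch.quantization.QuantWrapper(
    torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)))
m.qconfig = torch.quantization.get_default_qconfig('qnnpack')
m.eval()

mp = torch.quantization.prepare(m)   # attach observers
mp(torch.randn(4, 1, 4, 4))          # calibrate
mq = torch.quantization.convert(mp)  # swap to quantized ConvTranspose2d
```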
Test Plan:

```
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_1d
python test/test_quantization.py TestQuantizeFxOps.test_conv_transpose_2d
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25674636](https://our.internmc.facebook.com/intern/diff/D25674636)

[ghstack-poisoned]
---
 torch/quantization/fx/quantization_patterns.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index c7eac4554c6c..46fbed74bdc8 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -537,6 +537,8 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
     torch._ops.ops.quantized.instance_norm:
         ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
 }
+@register_quant_pattern(torch.nn.ConvTranspose1d)
+@register_quant_pattern(torch.nn.ConvTranspose2d)
 @register_quant_pattern(torch.nn.ELU)
 @register_quant_pattern(torch.nn.LeakyReLU)
 @register_quant_pattern(torch.nn.Hardswish)
@@ -595,16 +597,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         return quantizer.quantized_graph.create_node(
             "call_function", quantized_op, args, kwargs)
 
-@register_quant_pattern(torch.nn.ConvTranspose1d)
-@register_quant_pattern(torch.nn.ConvTranspose2d)
-# thin wrapper around DefaultNode for a nice error message for functionals, until
-# they are supported
-class ConvTransposeQuantizeHandler(DefaultNode):
-    def __init__(self, quantizer: QuantizerCls, node: Node):
-        super().__init__(quantizer, node)
-        assert node.op == 'call_module', \
-            'ConvTranspose{n}d is not implemented for functionals yet'
-
 # TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
 @register_quant_pattern(torch.nn.functional.elu)
 class ELU(QuantizeHandler):