From c3a7591cef525aa46df8176fba01d15d37d65828 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Tue, 22 Dec 2020 16:47:34 -0800
Subject: [PATCH] fx quant: do not observe bias on F.conv (#49623)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49623

Ensures that conv bias is not observed in an `F.conv{n}d` call.

Test Plan: Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25652856

fbshipit-source-id: 884f87be1948d3e049a557d79bec3c90aec34340
---
 test/quantization/test_quantize_fx.py          | 26 ++++++++--
 torch/quantization/fx/quantize.py              | 47 +++++++++++++++----
 .../testing/_internal/common_quantization.py   | 12 +++--
 3 files changed, 68 insertions(+), 17 deletions(-)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 98283e713747..14d66a9a119c 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1346,12 +1346,12 @@ def forward(self, x):
         self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
 
     @skipIfNoFBGEMM
-    def test_quantized_conv(self):
+    def test_conv_module(self):
         conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
 
-        class Conv(torch.nn.Module):
+        class ConvWrapper(torch.nn.Module):
             def __init__(self, dim):
-                super(Conv, self).__init__()
+                super(ConvWrapper, self).__init__()
                 self.conv = conv_module[dim](3, 3, 3).float()
 
             def forward(self, x):
@@ -1366,9 +1366,27 @@ def forward(self, x):
         }
         for dim, quant_type in options:
            model = self.checkGraphModeFxOp(
-                Conv(dim), self.img_data_dict[dim], quant_type,
+                ConvWrapper(dim), self.img_data_dict[dim], quant_type,
                quantized_nodes[dim])
 
+    @skipIfNoFBGEMM
+    def test_conv2d_functional(self):
+        for bias in [True, False]:
+            conv = torch.nn.Conv2d(1, 1, 1, bias=bias)
+            # There should be 3 observers: one each for the input, the weight
+            # and the output activation. There should be no observer on the bias.
+            prepare_expected_node_occurrence = {
+                ns.call_module(torch.quantization.HistogramObserver): 2,
+                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
+            }
+            expected_node_occurrence = \
+                {ns.call_function(torch.ops.quantized.conv2d): 1}
+            self.checkGraphModeFxOp(
+                conv, (torch.randn(4, 1, 4, 4),), QuantType.STATIC,
+                prepare_expected_node_occurrence=prepare_expected_node_occurrence,
+                expected_node_occurrence=expected_node_occurrence,
+            )
+
     @skipIfNoFBGEMM
     def test_quantized_conv_relu(self):
         """tests for conv1d_relu/conv2d_relu/conv3d_relu"""
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 3d4a92323067..363191488839 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -234,10 +234,38 @@ def insert_observer_for_input_arg_of_observed_node(
 
 # A dictionary for querying the weight index for a given op
 WEIGHT_INDEX_DICT = {
+    torch.nn.functional.conv1d : [1],
     torch.nn.functional.conv2d : [1],
+    torch.nn.functional.conv3d : [1],
     torch.nn.functional.linear : [1],
 }
 
+def node_arg_is_weight(node: Node, arg: Any) -> bool:
+    if isinstance(node, Node) and node.op == 'call_function' and \
+            node.target in WEIGHT_INDEX_DICT:
+        for i, node_arg in enumerate(node.args):
+            if arg is node_arg and i in \
+                    WEIGHT_INDEX_DICT[node.target]:  # type: ignore
+                return True
+    return False
+
+# A dictionary for querying the bias index for a given op
+# TODO(future PR): handle linear
+BIAS_INDEX_DICT = {
+    torch.nn.functional.conv1d : [2],
+    torch.nn.functional.conv2d : [2],
+    torch.nn.functional.conv3d : [2],
+}
+
+def node_arg_is_bias(node: Node, arg: Any) -> bool:
+    if isinstance(node, Node) and node.op == 'call_function' and \
+            node.target in BIAS_INDEX_DICT:
+        for i, node_arg in enumerate(node.args):
+            if arg is node_arg and i in \
+                    BIAS_INDEX_DICT[node.target]:  # type: ignore
+                return True
+    return False
+
 # weight prepacking ops
 WEIGHT_PREPACK_OPS = {
     torch._ops.ops.quantized.linear_prepack,
@@ -956,15 +984,16 @@ def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
 
         def visit(node, matched_pattern, qconfig):
             def visit_arg(arg):
-                is_weight = False
-                if isinstance(node, Node) and node.op == 'call_function' and \
-                        node.target in WEIGHT_INDEX_DICT:
-                    for i, node_arg in enumerate(node.args):
-                        if arg is node_arg and i in \
-                                WEIGHT_INDEX_DICT[node.target]:  # type: ignore
-                            is_weight = True
-                if qconfig is not None and \
-                        (activation_is_statically_quantized(qconfig) or is_weight):
+                is_weight = node_arg_is_weight(node, arg)
+                is_bias = node_arg_is_bias(node, arg)
+                is_activation = not (is_weight or is_bias)
+                should_add_handler = qconfig is not None and (
+                    (is_activation and
+                     activation_is_statically_quantized(qconfig)) or
+                    (is_weight and weight_is_statically_quantized(qconfig))
+                )
+
+                if should_add_handler:
                     act_post_process_ctr = qconfig.weight if is_weight else \
                         qconfig.activation
                     # overwrite the constructor from qconfig
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index e05425eb67a2..eef9381d79d9 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -672,6 +672,13 @@ def checkGraphModeFxOp(self, model, inputs, quant_type,
         if not quant_type == QuantType.DYNAMIC:
             prepared(*inputs)
+        if print_debug_info:
+            print()
+            print('quant type:\n', quant_type)
+            print('original model:\n', model)
+            print()
+            print('prepared model:\n', prepared)
+
         self.checkGraphModuleNodes(
             prepared,
             prepare_expected_node, prepare_expected_node_occurrence,
             prepare_expected_node_list)
@@ -685,10 +692,7 @@ def checkGraphModeFxOp(self, model, inputs, quant_type,
         qgraph_to_check = qgraph_debug if debug else qgraph
         if print_debug_info:
             print()
-            print('quant type:', quant_type)
-            print('original model:', model)
-            print()
-            print('quantized model:', qgraph_to_check)
+            print('quantized model:\n', qgraph_to_check)
             self.printGraphModule(qgraph_to_check)
             print()
         self.checkGraphModuleNodes(
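
Below is a minimal sketch, separate from the patch itself, of how the observer placement it enforces can be reproduced outside the test harness. It assumes the `prepare_fx` API, the qconfig-dict format, and the default fbgemm observers of the PyTorch version this patch targets; the FunctionalConv module and the observer counting are illustrative only, not part of the change.

# Illustrative only: checks that a functional conv gets observers on its
# input, weight and output, but not on its bias, after prepare_fx.
import torch
import torch.nn.functional as F
from torch.quantization import (
    HistogramObserver, PerChannelMinMaxObserver, get_default_qconfig)
from torch.quantization.quantize_fx import prepare_fx

class FunctionalConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(1, 1, 1, 1))
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # bias is passed positionally (arg index 2), which is the index
        # BIAS_INDEX_DICT records for torch.nn.functional.conv2d
        return F.conv2d(x, self.weight, self.bias)

model = FunctionalConv().eval()
prepared = prepare_fx(model, {'': get_default_qconfig('fbgemm')})

# Expect two activation observers (input and output) and one weight observer;
# with this patch, the bias argument no longer picks up an observer.
num_act = sum(isinstance(m, HistogramObserver) for m in prepared.modules())
num_wt = sum(isinstance(m, PerChannelMinMaxObserver) for m in prepared.modules())
print(num_act, num_wt)  # expected: 2 1

Since node_arg_is_bias only inspects positional arguments via node.args, the sketch passes the bias positionally rather than as a keyword argument.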