From 19f972b6964eb59bb1a6c09d4bafe015b86fc45c Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Tue, 22 Dec 2020 16:47:34 -0800
Subject: [PATCH] fx quant: do not observe bias on F.linear (#49628)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49628

Ensures that linear bias is not observed in an `F.linear` call. This should be
a small speedup in PTQ, and will change numerics (in a good way) for QAT if
someone is using `F.linear`.

Note: the implementation is slightly more verbose than the conv case because
bias is a keyword argument of `F.linear` rather than a positional argument.

Test Plan:
```
python test/test_quantization.py TestQuantizeFxOps.test_linear_functional_bias_not_observed
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25653532

fbshipit-source-id: c93501bf6b55cbe4a11cfdad6f79313483133a39
---
 test/quantization/test_quantize_fx.py | 14 ++++++++++++++
 torch/quantization/fx/quantize.py     | 26 ++++++++++++++------------
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 14d66a9a119c..7b7b5ffb83a0 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1387,6 +1387,20 @@ def test_conv2d_functional(self):
             expected_node_occurrence=expected_node_occurrence,
         )
 
+    def test_linear_functional_bias_not_observed(self):
+        data = (torch.rand((1, 4), dtype=torch.float),)
+        for bias in [True, False]:
+            linear = torch.nn.Linear(4, 4, bias=bias)
+            # There should be 3 observers: after input, weight and activation.
+            expected_node_occurrence = {
+                ns.call_module(torch.quantization.HistogramObserver): 2,
+                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
+            }
+            self.checkGraphModeFxOp(
+                linear, data, QuantType.STATIC,
+                prepare_expected_node_occurrence=expected_node_occurrence,
+            )
+
     @skipIfNoFBGEMM
     def test_quantized_conv_relu(self):
         """tests for conv1d_relu/conv2d_relu/conv3d_relu"""
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 363191488839..2cdd7b59b314 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -249,21 +249,23 @@ def node_arg_is_weight(node: Node, arg: Any) -> bool:
             return True
     return False
 
-# A dictionary for querying the weight index for a given op
-# TODO(future PR): handle linear
-BIAS_INDEX_DICT = {
-    torch.nn.functional.conv1d : [2],
-    torch.nn.functional.conv2d : [2],
-    torch.nn.functional.conv3d : [2],
+CONV_OPS_WITH_BIAS = {
+    torch.nn.functional.conv1d,
+    torch.nn.functional.conv2d,
+    torch.nn.functional.conv3d,
 }
+CONV_BIAS_ARG_INDEX = 2
 
 def node_arg_is_bias(node: Node, arg: Any) -> bool:
-    if isinstance(node, Node) and node.op == 'call_function' and \
-            node.target in BIAS_INDEX_DICT:
-        for i, node_arg in enumerate(node.args):
-            if arg is node_arg and i in \
-                    BIAS_INDEX_DICT[node.target]:  # type: ignore
-                return True
+    if isinstance(node, Node) and node.op == 'call_function':
+        if node.target in CONV_OPS_WITH_BIAS:
+            for i, node_arg in enumerate(node.args):
+                if arg is node_arg and i == CONV_BIAS_ARG_INDEX:
+                    return True
+        elif node.target is torch.nn.functional.linear:
+            for kwarg_name, kwarg_value in node.kwargs.items():
+                if kwarg_name == 'bias' and arg is kwarg_value:
+                    return True
     return False
 
 # weight prepacking ops
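
Illustration (not part of the patch): a minimal, self-contained sketch of the
mechanism the new check relies on. When FX traces a call to `F.linear` whose
bias is passed by keyword, the bias node lands in `node.kwargs`, which is what
the new linear branch of `node_arg_is_bias` matches. The module `M` and the
standalone simplified `node_arg_is_bias` below are hypothetical, written only
for this example.

```
import torch
import torch.fx
import torch.nn.functional as F

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        self.bias = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        # bias is passed by keyword, so FX records it in node.kwargs
        return F.linear(x, self.weight, bias=self.bias)

def node_arg_is_bias(node, arg):
    # simplified copy of the linear branch added in
    # torch/quantization/fx/quantize.py
    if node.op == 'call_function' and node.target is F.linear:
        for kwarg_name, kwarg_value in node.kwargs.items():
            if kwarg_name == 'bias' and arg is kwarg_value:
                return True
    return False

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    for arg in list(node.args) + list(node.kwargs.values()):
        if isinstance(arg, torch.fx.Node) and node_arg_is_bias(node, arg):
            # prints: bias feeds the bias kwarg of linear
            print(f'{arg.name} feeds the bias kwarg of {node.name}')
```

The conv functionals take bias positionally (argument index 2), so a single
index check suffices for them; for `F.linear` the patch matches the `bias`
keyword instead, which is why the two branches differ.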