From 118c96fd6433abca9f13dd9e5c47753259021d8b Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 1 Dec 2020 15:15:40 -0800
Subject: [PATCH] Update on "[reland][quant][fix] Add bias once in conv_fused
 (#48593)"

Summary:
Previously, _conv_forward added self.bias to its result, so the bias was
added twice in the QAT ConvBn module. This PR adds a bias argument to
_conv_forward, and the ConvBn module now calls _conv_forward with a zero
bias.

Fixes: https://github.com/pytorch/pytorch/issues/48514

Test Plan:
Imported from OSS

Reviewed By: raghuramank100

Differential Revision: [D25249175](https://our.internmc.facebook.com/intern/diff/D25249175)

[ghstack-poisoned]
---
 test/quantization/test_qat_module.py         | 5 ++++-
 torch/nn/intrinsic/qat/modules/conv_fused.py | 5 ++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/test/quantization/test_qat_module.py b/test/quantization/test_qat_module.py
index 015cc9117797..d71fdaf9d21f 100644
--- a/test/quantization/test_qat_module.py
+++ b/test/quantization/test_qat_module.py
@@ -110,7 +110,10 @@ def _forward(self, input):
         running_std = torch.sqrt(self.running_var + self.eps)
         scale_factor = self.gamma / running_std
         scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
-        zero_bias = torch.zeros_like(self.bias)
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias)
+        else:
+            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
         conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)
 
         if self.training and not self.freeze_bn:
diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py
index 7b2163e9eb20..dc6260c7d4a5 100644
--- a/torch/nn/intrinsic/qat/modules/conv_fused.py
+++ b/torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -94,7 +94,10 @@ def _forward(self, input):
         bias_shape[1] = -1
         scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
         # this does not include the conv bias
-        zero_bias = torch.zeros_like(self.bias)
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias)
+        else:
+            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
         conv = self._conv_forward(input, scaled_weight, zero_bias)
         conv_orig = conv / scale_factor.reshape(bias_shape)
         if self.bias is not None:
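
For context, a minimal standalone sketch (not part of the patch) of why the bias was double-counted and how passing a zero bias fixes it. Shapes, padding, and the `scale_factor` stand-in are arbitrary choices for illustration; only the pattern mirrors ConvBn._forward:

```python
import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True)
x = torch.randn(1, 3, 16, 16)
scale_factor = torch.rand(8) + 0.5  # stands in for gamma / running_std

# Buggy pattern: F.conv2d adds conv.bias inside the conv, and the module
# adds it a second time after undoing the batch-norm scaling.
out = F.conv2d(x, conv.weight * scale_factor.reshape(-1, 1, 1, 1),
               conv.bias, padding=1)
buggy = out / scale_factor.reshape(1, -1, 1, 1) + conv.bias.reshape(1, -1, 1, 1)

# Fixed pattern from this patch: run the conv with a zero bias, so the
# real bias is only added once, after the rescale.
zero_bias = torch.zeros_like(conv.bias)
out = F.conv2d(x, conv.weight * scale_factor.reshape(-1, 1, 1, 1),
               zero_bias, padding=1)
fixed = out / scale_factor.reshape(1, -1, 1, 1) + conv.bias.reshape(1, -1, 1, 1)

# The buggy output carries an extra bias / scale_factor term per channel.
extra = conv.bias.reshape(1, -1, 1, 1) / scale_factor.reshape(1, -1, 1, 1)
assert torch.allclose(buggy - fixed, extra, atol=1e-5)
```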