diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py index dc6260c7d4a5..b1be8e141fce 100644 --- a/torch/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -93,7 +93,8 @@ def _forward(self, input): bias_shape = [1] * len(self.weight.shape) bias_shape[1] = -1 scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape)) - # this does not include the conv bias + # using zero bias here since the bias for the original conv + # will be added later if self.bias: zero_bias = torch.zeros_like(self.bias) else: