diff --git a/test/quantization/test_qat_module.py b/test/quantization/test_qat_module.py index 4144c0744104..32de0ff50f0e 100644 --- a/test/quantization/test_qat_module.py +++ b/test/quantization/test_qat_module.py @@ -110,7 +110,11 @@ def _forward(self, input): running_std = torch.sqrt(self.running_var + self.eps) scale_factor = self.gamma / running_std scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1]) - conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight)) + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias) + else: + zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device) + conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias) if self.training and not self.freeze_bn: # recovering original conv to get original batch_mean and batch_var diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py index 659d284b2afd..12018a34e23f 100644 --- a/torch/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -93,8 +93,13 @@ def _forward(self, input): bias_shape = [1] * len(self.weight.shape) bias_shape[1] = -1 scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape)) - # this does not include the conv bias - conv = self._conv_forward(input, scaled_weight) + # using zero bias here since the bias for original conv + # will be added later + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias) + else: + zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device) + conv = self._conv_forward(input, scaled_weight, zero_bias) conv_orig = conv / scale_factor.reshape(bias_shape) if self.bias is not None: conv_orig = conv_orig + self.bias.reshape(bias_shape) @@ -402,7 +407,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, def forward(self, input): return F.relu( - self._conv_forward(input, self.weight_fake_quant(self.weight))) + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod def from_float(cls, mod): diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index b801d990c4a6..33f2a84aed74 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -246,16 +246,16 @@ def __init__( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _single(0), groups, bias, padding_mode) - def _conv_forward(self, input, weight): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != 'zeros': return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.bias, self.stride, + weight, bias, self.stride, _single(0), self.dilation, self.groups) - return F.conv1d(input, weight, self.bias, self.stride, + return F.conv1d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.weight) + return self._conv_forward(input, self.weight, self.bias) class Conv2d(_ConvNd): @@ -382,16 +382,16 @@ def __init__( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode) - def _conv_forward(self, input, weight): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != 'zeros': return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.bias, self.stride, + weight, bias, self.stride, _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.bias, self.stride, + return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.weight) + return self._conv_forward(input, self.weight, self.bias) class Conv3d(_ConvNd): __doc__ = r"""Applies a 3D convolution over an input signal composed of several input diff --git a/torch/nn/qat/modules/conv.py b/torch/nn/qat/modules/conv.py index a9c5f8547329..4b3814983347 100644 --- a/torch/nn/qat/modules/conv.py +++ b/torch/nn/qat/modules/conv.py @@ -29,7 +29,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, self.weight_fake_quant = qconfig.weight() def forward(self, input): - return self._conv_forward(input, self.weight_fake_quant(self.weight)) + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod def from_float(cls, mod):