diff --git a/test/quantization/test_qat_module.py b/test/quantization/test_qat_module.py
index 4144c0744104..cb0eab60598c 100644
--- a/test/quantization/test_qat_module.py
+++ b/test/quantization/test_qat_module.py
@@ -52,6 +52,7 @@ def __init__(self,
         self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
         self.activation_post_process = self.qconfig.activation()
         self.weight_fake_quant = self.qconfig.weight()
+        self.zero_bias = torch.zeros(out_channels)
         if bias:
             self.bias = nn.Parameter(torch.Tensor(out_channels))
         else:
@@ -110,7 +111,7 @@ def _forward(self, input):
         running_std = torch.sqrt(self.running_var + self.eps)
         scale_factor = self.gamma / running_std
         scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
-        conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight))
+        conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), self.zero_bias)
 
         if self.training and not self.freeze_bn:
             # recovering original conv to get original batch_mean and batch_var
diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py
index 659d284b2afd..6ab402528d71 100644
--- a/torch/nn/intrinsic/qat/modules/conv_fused.py
+++ b/torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -43,6 +43,7 @@ def __init__(self,
         self.freeze_bn = freeze_bn if self.training else True
         self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
         self.weight_fake_quant = self.qconfig.weight()
+        self.zero_bias = torch.zeros(out_channels)
         if bias:
             self.bias = Parameter(torch.Tensor(out_channels))
         else:
@@ -94,7 +95,7 @@ def _forward(self, input):
         bias_shape[1] = -1
         scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
         # this does not include the conv bias
-        conv = self._conv_forward(input, scaled_weight)
+        conv = self._conv_forward(input, scaled_weight, self.zero_bias)
         conv_orig = conv / scale_factor.reshape(bias_shape)
         if self.bias is not None:
             conv_orig = conv_orig + self.bias.reshape(bias_shape)
@@ -402,7 +403,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
 
     def forward(self, input):
         return F.relu(
-            self._conv_forward(input, self.weight_fake_quant(self.weight)))
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
 
     @classmethod
     def from_float(cls, mod):
diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py
index 8f94e6970c3e..f319aa1d733e 100644
--- a/torch/nn/modules/conv.py
+++ b/torch/nn/modules/conv.py
@@ -245,16 +245,16 @@ def __init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
 
-    def _conv_forward(self, input, weight):
+    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
         if self.padding_mode != 'zeros':
             return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
+                            weight, bias, self.stride,
                             _single(0), self.dilation, self.groups)
-        return F.conv1d(input, weight, self.bias, self.stride,
+        return F.conv1d(input, weight, bias, self.stride,
                         self.padding, self.dilation, self.groups)
 
     def forward(self, input: Tensor) -> Tensor:
-        return self._conv_forward(input, self.weight)
+        return self._conv_forward(input, self.weight, self.bias)
 
 
 class Conv2d(_ConvNd):
@@ -381,16 +381,16 @@ def __init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
 
-    def _conv_forward(self, input, weight):
+    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
         if self.padding_mode != 'zeros':
             return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
+                            weight, bias, self.stride,
                             _pair(0), self.dilation, self.groups)
-        return F.conv2d(input, weight, self.bias, self.stride,
+        return F.conv2d(input, weight, bias, self.stride,
                         self.padding, self.dilation, self.groups)
 
     def forward(self, input: Tensor) -> Tensor:
-        return self._conv_forward(input, self.weight)
+        return self._conv_forward(input, self.weight, self.bias)
 
 
 class Conv3d(_ConvNd):
     __doc__ = r"""Applies a 3D convolution over an input signal composed of several input
diff --git a/torch/nn/qat/modules/conv.py b/torch/nn/qat/modules/conv.py
index a9c5f8547329..4b3814983347 100644
--- a/torch/nn/qat/modules/conv.py
+++ b/torch/nn/qat/modules/conv.py
@@ -29,7 +29,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         self.weight_fake_quant = qconfig.weight()
 
     def forward(self, input):
-        return self._conv_forward(input, self.weight_fake_quant(self.weight))
+        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
 
     @classmethod
     def from_float(cls, mod):
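
Note: the sketch below is a minimal, self-contained illustration of the calling convention this diff introduces, not the actual PyTorch modules. `conv_forward` and `fused_qat_forward` are hypothetical stand-ins that mirror the changed code paths: `_conv_forward` now takes the bias explicitly, which lets the fused QAT conv-bn run the convolution with a zero bias and add the real conv bias only after un-scaling, so it is not multiplied by `gamma / running_std`.

    import torch
    import torch.nn.functional as F

    def conv_forward(input, weight, bias):
        # Mirrors the new _conv_forward(input, weight, bias) convention:
        # the bias is an explicit argument instead of always self.bias.
        return F.conv2d(input, weight, bias, stride=1, padding=0)

    def fused_qat_forward(input, weight, bias, gamma, running_var, eps=1e-5):
        running_std = torch.sqrt(running_var + eps)
        scale_factor = gamma / running_std
        scaled_weight = weight * scale_factor.reshape([-1, 1, 1, 1])
        # Run the conv bias-free, as conv_fused.py does with self.zero_bias.
        zero_bias = torch.zeros(weight.shape[0])
        conv = conv_forward(input, scaled_weight, zero_bias)
        # Un-scale, then fold in the real conv bias.
        conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
        if bias is not None:
            conv_orig = conv_orig + bias.reshape([1, -1, 1, 1])
        return conv_orig

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    b = torch.randn(4)
    out = fused_qat_forward(x, w, b, gamma=torch.ones(4), running_var=torch.ones(4))
    print(out.shape)  # torch.Size([1, 4, 6, 6])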