[reland][quant][fix] Add bias once in conv_fused (#48593) (#48661)
Summary:
Pull Request resolved: #48661

Previously, _conv_forward would add self.bias to the result, so the bias was added twice in the QAT ConvBn module.
This PR adds a bias argument to _conv_forward, and the ConvBn module now calls _conv_forward with a zero bias.

fixes: #48514
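A minimal sketch of the double-add, using plain functional ops (tensor shapes and values here are illustrative, not taken from the test suite):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
s = torch.rand(4) + 0.5  # batchnorm scale factor: gamma / running_std

reference = F.conv2d(x, w, b)  # what the float conv computes

# Pre-fix behavior: _conv_forward added self.bias while the weight was
# still scaled; un-scaling and re-adding the bias then counts it twice.
out = F.conv2d(x, w * s.reshape(-1, 1, 1, 1), b)
out = out / s.reshape(1, -1, 1, 1) + b.reshape(1, -1, 1, 1)
print(torch.allclose(out, reference))  # False: an extra b / s term leaks in
```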

Test Plan:
Imported from OSS

Reviewed By: vkuzo

Differential Revision: D25249175

fbshipit-source-id: 4536c7545d3dcd7e8ea254368ffb7cf15118d78c
jerryzh168 authored and facebook-github-bot committed Dec 2, 2020
1 parent 0db7346 commit 52f0af0
Showing 4 changed files with 22 additions and 13 deletions.
6 changes: 5 additions & 1 deletion test/quantization/test_qat_module.py
@@ -110,7 +110,11 @@ def _forward(self, input):
         running_std = torch.sqrt(self.running_var + self.eps)
         scale_factor = self.gamma / running_std
         scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
-        conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight))
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias)
+        else:
+            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
+        conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)
 
         if self.training and not self.freeze_bn:
             # recovering original conv to get original batch_mean and batch_var
11 changes: 8 additions & 3 deletions torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -93,8 +93,13 @@ def _forward(self, input):
         bias_shape = [1] * len(self.weight.shape)
         bias_shape[1] = -1
         scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
-        # this does not include the conv bias
-        conv = self._conv_forward(input, scaled_weight)
+        # using zero bias here since the bias for original conv
+        # will be added later
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias)
+        else:
+            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
+        conv = self._conv_forward(input, scaled_weight, zero_bias)
         conv_orig = conv / scale_factor.reshape(bias_shape)
         if self.bias is not None:
             conv_orig = conv_orig + self.bias.reshape(bias_shape)
@@ -402,7 +407,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
 
     def forward(self, input):
         return F.relu(
-            self._conv_forward(input, self.weight_fake_quant(self.weight)))
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
 
     @classmethod
     def from_float(cls, mod):
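The division by scale_factor above only recovers the original convolution because the bias is kept out of _conv_forward: conv(x, s * W) / s == conv(x, W), whereas (conv(x, s * W) + b) / s == conv(x, W) + b / s. A self-contained check of that identity (arbitrary shapes, not from this PR):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
s = torch.rand(4) + 0.5  # per-channel scale, stands in for gamma / running_std

# Convolve with the scaled weight and a zero bias, then undo the scaling.
scaled = F.conv2d(x, w * s.reshape(-1, 1, 1, 1), torch.zeros(4))
unscaled = scaled / s.reshape(1, -1, 1, 1)
print(torch.allclose(unscaled, F.conv2d(x, w), atol=1e-5))  # True
```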
16 changes: 8 additions & 8 deletions torch/nn/modules/conv.py
@@ -246,16 +246,16 @@ def __init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
 
-    def _conv_forward(self, input, weight):
+    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
         if self.padding_mode != 'zeros':
             return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
+                            weight, bias, self.stride,
                             _single(0), self.dilation, self.groups)
-        return F.conv1d(input, weight, self.bias, self.stride,
+        return F.conv1d(input, weight, bias, self.stride,
                         self.padding, self.dilation, self.groups)
 
     def forward(self, input: Tensor) -> Tensor:
-        return self._conv_forward(input, self.weight)
+        return self._conv_forward(input, self.weight, self.bias)
 
 
 class Conv2d(_ConvNd):
@@ -382,16 +382,16 @@ def __init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
 
-    def _conv_forward(self, input, weight):
+    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
         if self.padding_mode != 'zeros':
             return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
+                            weight, bias, self.stride,
                             _pair(0), self.dilation, self.groups)
-        return F.conv2d(input, weight, self.bias, self.stride,
+        return F.conv2d(input, weight, bias, self.stride,
                         self.padding, self.dilation, self.groups)
 
     def forward(self, input: Tensor) -> Tensor:
-        return self._conv_forward(input, self.weight)
+        return self._conv_forward(input, self.weight, self.bias)
 
 class Conv3d(_ConvNd):
     __doc__ = r"""Applies a 3D convolution over an input signal composed of several input
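With the new signature, forward passes self.bias explicitly, and subclasses such as the QAT modules below can substitute their own bias tensor. A quick sanity check against the public API (the module hyperparameters are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1, padding_mode='reflect')
x = torch.randn(1, 3, 8, 8)

# forward() now routes through _conv_forward with an explicit bias argument,
# so calling it directly with (weight, bias) reproduces the module output.
out = conv._conv_forward(x, conv.weight, conv.bias)
print(torch.allclose(out, conv(x)))  # True
```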
2 changes: 1 addition & 1 deletion torch/nn/qat/modules/conv.py
@@ -29,7 +29,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         self.weight_fake_quant = qconfig.weight()
 
     def forward(self, input):
-        return self._conv_forward(input, self.weight_fake_quant(self.weight))
+        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
 
     @classmethod
     def from_float(cls, mod):
