[quant][fix] Add bias once in conv_fused (#48593)
Summary:
Pull Request resolved: #48593

Previously, _conv_forward added self.bias to the result itself, so the bias was applied twice in the QAT ConvBn module. This PR adds a bias argument to _conv_forward, and the ConvBn module now calls _conv_forward with a zero bias and adds its own bias afterwards.

fixes: #48514
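
To make the double-add concrete, here is a simplified sketch (illustrative only, not code from this commit) of the pre-fix behavior, where _conv_forward bakes self.bias into the convolution and ConvBn's _forward re-adds it after un-scaling:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)
bias = torch.randn(4)
scale_factor = torch.rand(4) + 0.5  # stands in for gamma / running_std

scaled_weight = weight * scale_factor.reshape(-1, 1, 1, 1)
# pre-fix: the conv itself added self.bias ...
conv = F.conv2d(x, scaled_weight, bias)
# ... and ConvBn then re-added the bias after un-scaling the output
conv_orig = conv / scale_factor.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)

ref = F.conv2d(x, weight, bias)
print(torch.allclose(conv_orig, ref))  # False: bias is double-counted (and mis-scaled)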

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D25222215

fbshipit-source-id: 90c0ab79835b6d09622dcfec9de4139881a60746
jerryzh168 authored and facebook-github-bot committed Dec 1, 2020
1 parent 7a59a1b commit d2e4298
Showing 4 changed files with 14 additions and 12 deletions.
3 changes: 2 additions & 1 deletion test/quantization/test_qat_module.py
@@ -52,6 +52,7 @@ def __init__(self,
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
+ self.zero_bias = torch.zeros(out_channels)
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
@@ -110,7 +111,7 @@ def _forward(self, input):
running_std = torch.sqrt(self.running_var + self.eps)
scale_factor = self.gamma / running_std
scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
- conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight))
+ conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), self.zero_bias)

if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
5 changes: 3 additions & 2 deletions torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -43,6 +43,7 @@ def __init__(self,
self.freeze_bn = freeze_bn if self.training else True
self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
self.weight_fake_quant = self.qconfig.weight()
+ self.zero_bias = torch.zeros(out_channels)
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
@@ -94,7 +95,7 @@ def _forward(self, input):
bias_shape[1] = -1
scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
# this does not include the conv bias
- conv = self._conv_forward(input, scaled_weight)
+ conv = self._conv_forward(input, scaled_weight, self.zero_bias)
conv_orig = conv / scale_factor.reshape(bias_shape)
if self.bias is not None:
conv_orig = conv_orig + self.bias.reshape(bias_shape)
@@ -402,7 +403,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,

def forward(self, input):
return F.relu(
- self._conv_forward(input, self.weight_fake_quant(self.weight)))
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))

@classmethod
def from_float(cls, mod):
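The arithmetic behind passing zero_bias here, as a standalone sketch (illustrative, using made-up tensor values, not code from this commit): with the bias excluded from the convolution, dividing by the scale factor exactly recovers the unscaled conv output, and the bias is then added once:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)
bias = torch.randn(4)
scale_factor = torch.rand(4) + 0.5  # stands in for gamma / running_std
zero_bias = torch.zeros(4)

scaled_weight = weight * scale_factor.reshape(-1, 1, 1, 1)
conv = F.conv2d(x, scaled_weight, zero_bias)  # post-fix: no bias inside the conv
conv_orig = conv / scale_factor.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)

ref = F.conv2d(x, weight, bias)
print(torch.allclose(conv_orig, ref, atol=1e-5))  # True: bias added exactly once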
16 changes: 8 additions & 8 deletions torch/nn/modules/conv.py
@@ -245,16 +245,16 @@ def __init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode)

- def _conv_forward(self, input, weight):
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
- weight, self.bias, self.stride,
+ weight, bias, self.stride,
_single(0), self.dilation, self.groups)
- return F.conv1d(input, weight, self.bias, self.stride,
+ return F.conv1d(input, weight, bias, self.stride,
self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
- return self._conv_forward(input, self.weight)
+ return self._conv_forward(input, self.weight, self.bias)


class Conv2d(_ConvNd):
@@ -381,16 +381,16 @@ def __init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode)

- def _conv_forward(self, input, weight):
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
- weight, self.bias, self.stride,
+ weight, bias, self.stride,
_pair(0), self.dilation, self.groups)
- return F.conv2d(input, weight, self.bias, self.stride,
+ return F.conv2d(input, weight, bias, self.stride,
self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
- return self._conv_forward(input, self.weight)
+ return self._conv_forward(input, self.weight, self.bias)

class Conv3d(_ConvNd):
__doc__ = r"""Applies a 3D convolution over an input signal composed of several input
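A short usage sketch of the new signature (assuming a build that includes this change; _conv_forward is an internal helper, shown here only to illustrate the explicit bias argument):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 4, kernel_size=3, bias=True)
x = torch.randn(1, 3, 8, 8)

out_with_bias = conv._conv_forward(x, conv.weight, conv.bias)
out_zero_bias = conv._conv_forward(x, conv.weight, torch.zeros(4))

# the two outputs differ by exactly the per-channel bias
diff = out_with_bias - out_zero_bias
print(torch.allclose(diff, conv.bias.reshape(1, -1, 1, 1).expand_as(diff), atol=1e-6))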
2 changes: 1 addition & 1 deletion torch/nn/qat/modules/conv.py
@@ -29,7 +29,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
self.weight_fake_quant = qconfig.weight()

def forward(self, input):
- return self._conv_forward(input, self.weight_fake_quant(self.weight))
+ return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

@classmethod
def from_float(cls, mod):
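For context, a minimal end-to-end sketch of the QAT module this forward belongs to (assuming the stock torch.quantization helpers; not part of the diff):

import torch
import torch.nn as nn
import torch.nn.qat as nnqat
from torch.quantization import get_default_qat_qconfig

float_conv = nn.Conv2d(3, 4, kernel_size=3, bias=True)
float_conv.qconfig = get_default_qat_qconfig('fbgemm')

qat_conv = nnqat.Conv2d.from_float(float_conv)  # fake-quantizes the weight in forward
out = qat_conv(torch.randn(1, 3, 8, 8))         # bias is added exactly once
print(out.shape)  # torch.Size([1, 4, 6, 6])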
