Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quant][fix] Add bias once in conv_fused #48593

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/quantization/test_qat_module.py
Expand Up @@ -52,6 +52,7 @@ def __init__(self,
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
self.zero_bias = nn.Parameter(torch.Tensor(out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
Expand Down Expand Up @@ -110,7 +111,7 @@ def _forward(self, input):
running_std = torch.sqrt(self.running_var + self.eps)
scale_factor = self.gamma / running_std
scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight))
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), self.zero_bias)

if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
Expand Down
5 changes: 3 additions & 2 deletions torch/nn/intrinsic/qat/modules/conv_fused.py
Expand Up @@ -43,6 +43,7 @@ def __init__(self,
self.freeze_bn = freeze_bn if self.training else True
self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
self.weight_fake_quant = self.qconfig.weight()
self.zero_bias = Parameter(torch.Tensor(out_channels))
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
Expand Down Expand Up @@ -94,7 +95,7 @@ def _forward(self, input):
bias_shape[1] = -1
scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
# this does not include the conv bias
conv = self._conv_forward(input, scaled_weight)
conv = self._conv_forward(input, scaled_weight, self.zero_bias)
conv_orig = conv / scale_factor.reshape(bias_shape)
if self.bias is not None:
conv_orig = conv_orig + self.bias.reshape(bias_shape)
Expand Down Expand Up @@ -402,7 +403,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,

def forward(self, input):
return F.relu(
self._conv_forward(input, self.weight_fake_quant(self.weight)))
self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))

@classmethod
def from_float(cls, mod):
Expand Down
16 changes: 8 additions & 8 deletions torch/nn/modules/conv.py
Expand Up @@ -245,16 +245,16 @@ def __init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode)

def _conv_forward(self, input, weight):
def _conv_forward(self, input, weight, bias):
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
weight, bias, self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, weight, self.bias, self.stride,
return F.conv1d(input, weight, bias, self.stride,
self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight)
return self._conv_forward(input, self.weight, self.bias)


class Conv2d(_ConvNd):
Expand Down Expand Up @@ -381,16 +381,16 @@ def __init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode)

def _conv_forward(self, input, weight):
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
weight, bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride,
return F.conv2d(input, weight, bias, self.stride,
self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight)
return self._conv_forward(input, self.weight, self.bias)

class Conv3d(_ConvNd):
__doc__ = r"""Applies a 3D convolution over an input signal composed of several input
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/qat/modules/conv.py
Expand Up @@ -29,7 +29,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
self.weight_fake_quant = qconfig.weight()

def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight))
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

@classmethod
def from_float(cls, mod):
Expand Down