[quant][graphmode][fx] Add support for qat convbn{relu}1d (#47248)
Summary: Pull Request resolved: #47248

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D24696524

fbshipit-source-id: 684db12be201307acbdc89a44192cf2270491dba
jerryzh168 authored and facebook-github-bot committed Nov 4, 2020
1 parent 3a00245 commit 0cba3e3
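
For context, here is a minimal sketch of how the new 1d fused QAT modules are exercised through FX graph mode quantization. The model, input shape, and qconfig below are illustrative only, and the import paths assume the torch.quantization.quantize_fx API of the PyTorch version this commit targets:

    # Hedged sketch: QAT prepare/convert for a Conv1d + BatchNorm1d + ReLU model.
    import torch
    import torch.nn as nn
    from torch.quantization import get_default_qat_qconfig
    from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx

    class ConvBnReLUModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv1d(3, 3, 3)
            self.bn = nn.BatchNorm1d(3)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    model = ConvBnReLUModel().train()
    qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}
    # With this patch, Conv1d + BatchNorm1d + ReLU should fuse into a QAT ConvBnReLU1d here.
    prepared = prepare_qat_fx(model, qconfig_dict)
    # ... run the QAT training loop on `prepared` ...
    quantized = convert_fx(prepared.eval())  # expected to lower to a quantized ConvReLU1d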
Showing 11 changed files with 229 additions and 81 deletions.
58 changes: 55 additions & 3 deletions test/quantization/test_quantize_fx.py
@@ -103,14 +103,15 @@ def forward(self, x):
# train mode fuse_fx is called in prepare_qat_fx
m = prepare_qat_fx(m, {})
expected_nodes = [
ns.call_module(nni.ConvBn1d),
ns.call_module(nni.ConvBn2d),
ns.call_module(nni.ConvBn3d),
ns.call_module(nni.ConvBnReLU1d),
ns.call_module(nni.ConvBnReLU2d),
ns.call_module(nni.ConvBnReLU3d),
]
# ConvBnRelu1d is not fused
expected_occurrence = {
ns.call_module(nn.ReLU): 1
ns.call_module(nn.ReLU): 0
}
self.checkGraphModuleNodes(
m,
@@ -123,14 +124,16 @@ def forward(self, x):
# fuse_fx is a top level api and only supports eval mode
m = fuse_fx(m)
expected_nodes = [
ns.call_module(nn.Conv1d),
ns.call_module(nn.Conv2d),
ns.call_module(nn.Conv3d),
ns.call_module(nni.ConvReLU1d),
ns.call_module(nni.ConvReLU2d),
ns.call_module(nni.ConvReLU3d),
]
# ConvBnRelu1d is not fused
expected_occurrence = {
ns.call_module(nn.ReLU): 1
ns.call_module(nn.ReLU): 0
}
self.checkGraphModuleNodes(
m,
@@ -301,6 +304,55 @@ def forward(self, x):
ref_qparams = weight_obs.calculate_qparams()
self.assertEqual(qparams, ref_qparams)

def test_conv_bn_relu(self):
convs = {
1: nn.Conv1d,
2: nn.Conv2d,
3: nn.Conv3d,
}
bns = {
1: nn.BatchNorm1d,
2: nn.BatchNorm2d,
3: nn.BatchNorm3d,
}
quantized_convs = {
1: nnq.Conv1d,
2: nnq.Conv2d,
3: nnq.Conv3d,
}
quantized_conv_relus = {
1: nniq.ConvReLU1d,
2: nniq.ConvReLU2d,
3: nniq.ConvReLU3d,
}

class M(torch.nn.Module):
def __init__(self, dim, has_relu):
super().__init__()
self.conv = convs[dim](3, 3, 3)
self.bn = bns[dim](3)
self.relu = nn.ReLU()
self.has_relu = has_relu

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.has_relu:
x = self.relu(x)
return x

options = itertools.product([1, 2], [True, False], self.static_quant_types)
for dim, has_relu, quant_type in options:
expected_node = ns.call_module(
quantized_conv_relus[dim] if has_relu
else quantized_convs[dim])
self.checkGraphModeFxOp(
M(dim, has_relu),
self.img_data_dict[dim],
quant_type,
expected_node=expected_node,
)

@skipIfNoFBGEMM
def test_dynamic_quant_fp16(self):
class Linear(torch.nn.Module):
15 changes: 1 addition & 14 deletions torch/nn/intrinsic/qat/__init__.py
@@ -1,14 +1 @@
from .modules import LinearReLU
from .modules import ConvReLU2d
from .modules import ConvBn2d
from .modules import ConvBnReLU2d
from .modules import update_bn_stats, freeze_bn_stats

__all__ = [
'ConvBn2d',
'ConvBnReLU2d',
'ConvReLU2d',
'LinearReLU',
'update_bn_stats',
'freeze_bn_stats'
]
from .modules import *
6 changes: 4 additions & 2 deletions torch/nn/intrinsic/qat/modules/__init__.py
@@ -1,11 +1,13 @@

from .linear_relu import LinearReLU
from .conv_fused import ConvBn2d, ConvBnReLU2d, ConvReLU2d, update_bn_stats, freeze_bn_stats
from .conv_fused import ConvBn1d, ConvBn2d, ConvBnReLU1d, ConvBnReLU2d, ConvReLU2d, \
update_bn_stats, freeze_bn_stats

__all__ = [
'LinearReLU',
'ConvReLU2d',
'ConvBn1d',
'ConvBn2d',
'ConvBnReLU1d',
'ConvBnReLU2d',
'update_bn_stats',
'freeze_bn_stats'
123 changes: 108 additions & 15 deletions torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -5,9 +5,16 @@
import torch.nn.qat as nnqat
import torch.nn.functional as F
from torch.nn import init
from torch.nn.modules.utils import _pair
from torch.nn.modules.utils import _single, _pair
from torch.nn.parameter import Parameter

_BN_CLASS_MAP = {
1: nn.BatchNorm1d,
2: nn.BatchNorm2d,
3: nn.BatchNorm3d,
}


class _ConvBnNd(nn.modules.conv._ConvNd):

_version = 2
@@ -26,14 +33,15 @@ def __init__(self,
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
qconfig=None,
dim=2):
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, transposed,
output_padding, groups, False, padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.freeze_bn = freeze_bn if self.training else True
self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True)
self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
@@ -80,12 +88,16 @@ def freeze_bn_stats(self):
def _forward(self, input):
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
scale_factor = self.bn.weight / running_std
scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape([-1, 1, 1, 1]))
weight_shape = [1] * len(self.weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(self.weight.shape)
bias_shape[1] = -1
scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
# this does not include the conv bias
conv = self._conv_forward(input, scaled_weight)
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
conv_orig = conv / scale_factor.reshape(bias_shape)
if self.bias is not None:
conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1])
conv_orig = conv_orig + self.bias.reshape(bias_shape)
conv = self.bn(conv_orig)
return conv
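
As an aside, the reshape-based generalization above can be sanity-checked outside the module. The following standalone sketch (not part of the commit; fake quant omitted and BatchNorm put in eval mode so running stats are used) confirms that scaling the weight by gamma / running_std and dividing the conv output back out reproduces a plain conv followed by batch norm in the 1d case:

    # Standalone sanity check for the BN-folded conv path used in _ConvBnNd._forward,
    # specialized to 1d; module sizes and the input shape are made up for the demo.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    conv = nn.Conv1d(3, 3, 3, bias=False)
    bn = nn.BatchNorm1d(3).eval()  # eval mode: use running stats, keep the demo deterministic
    x = torch.randn(2, 3, 8)

    running_std = torch.sqrt(bn.running_var + bn.eps)
    scale_factor = bn.weight / running_std

    # broadcast shapes derived from the weight rank, as in the patched _forward
    weight_shape = [1] * conv.weight.dim()
    weight_shape[0] = -1          # [-1, 1, 1] for Conv1d weights
    bias_shape = [1] * conv.weight.dim()
    bias_shape[1] = -1            # [1, -1, 1] to undo the per-output-channel scaling

    scaled_weight = conv.weight * scale_factor.reshape(weight_shape)
    conv_out = F.conv1d(x, scaled_weight)
    folded = bn(conv_out / scale_factor.reshape(bias_shape))

    reference = bn(conv(x))
    assert torch.allclose(folded, reference, atol=1e-5)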

@@ -190,6 +202,92 @@ def from_float(cls, mod):
qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn

class ConvBn1d(_ConvBnNd, nn.Conv1d):
r"""
A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d`.
Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
to default.
Attributes:
freeze_bn:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE = torch.nn.intrinsic.ConvBn1d

def __init__(self,
# Conv1d args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm1d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
_ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _single(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig, dim=1)

class ConvBnReLU1d(ConvBn1d):
r"""
A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
attached with FakeQuantize modules for weight,
used in quantization aware training.
We combined the interface of :class:`torch.nn.Conv1d` and
:class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE = torch.nn.intrinsic.ConvBnReLU1d

def __init__(self,
# Conv1d args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm1d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias,
padding_mode, eps, momentum,
freeze_bn,
qconfig)

def forward(self, input):
return F.relu(ConvBn1d._forward(self, input))

@classmethod
def from_float(cls, mod):
return super(ConvBnReLU1d, cls).from_float(mod)

class ConvBn2d(_ConvBnNd, nn.Conv2d):
r"""
A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
@@ -199,8 +297,6 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d):
We combined the interface of :class:`torch.nn.Conv2d` and
:class:`torch.nn.BatchNorm2d`.
Implementation details: https://arxiv.org/pdf/1806.08342.pdf section 3.2.2
Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
to default.
@@ -231,7 +327,7 @@ def __init__(self,
dilation = _pair(dilation)
_ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig)
eps, momentum, freeze_bn, qconfig, dim=2)

class ConvBnReLU2d(ConvBn2d):
r"""
@@ -242,8 +338,6 @@ class ConvBnReLU2d(ConvBn2d):
We combined the interface of :class:`torch.nn.Conv2d` and
:class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
Implementation details: https://arxiv.org/pdf/1806.08342.pdf
Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
@@ -281,8 +375,7 @@ def from_float(cls, mod):
return super(ConvBnReLU2d, cls).from_float(mod)

class ConvReLU2d(nnqat.Conv2d):
r"""
A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
FakeQuantize modules for weight for
quantization aware training.
@@ -316,9 +409,9 @@ def from_float(cls, mod):
return super(ConvReLU2d, cls).from_float(mod)

def update_bn_stats(mod):
if type(mod) in set([ConvBnReLU2d, ConvBn2d]):
if type(mod) in set([ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d]):
mod.update_bn_stats()

def freeze_bn_stats(mod):
if type(mod) in set([ConvBnReLU2d, ConvBn2d]):
if type(mod) in set([ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d]):
mod.freeze_bn_stats()
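
In a QAT training loop these helpers are typically broadcast over the prepared model with Module.apply; a brief sketch, reusing the `prepared` model from the earlier example (the freeze point is arbitrary, not something this commit prescribes):

    import torch.nn.intrinsic.qat as nniqat

    # After some number of QAT epochs, freeze BN statistics so the fused
    # Conv+BN modules stop updating their running estimates.
    prepared.apply(nniqat.freeze_bn_stats)

    # The statistics can be re-enabled later if needed.
    prepared.apply(nniqat.update_bn_stats)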
9 changes: 6 additions & 3 deletions torch/nn/modules/conv.py
@@ -245,14 +245,17 @@ def __init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias, padding_mode)

def forward(self, input: Tensor) -> Tensor:
def _conv_forward(self, input, weight):
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
self.weight, self.bias, self.stride,
weight, self.bias, self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, self.weight, self.bias, self.stride,
return F.conv1d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight)


class Conv2d(_ConvNd):
__doc__ = r"""Applies a 2D convolution over an input signal composed of several input