[quant] Add reflection padding to conv
ghstack-source-id: 285186d6c359dae177898d8a2c91b7252e72ba76
Pull Request resolved: #49011
z-a-f committed Dec 21, 2020
1 parent 6568572 commit 5708ee1
Showing 2 changed files with 90 additions and 33 deletions.
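
For a sense of the user-facing change (an illustrative sketch, not part of the diff; the sizes, scale, and zero_point below are arbitrary), quantized conv modules now accept padding_mode='reflect' where previously anything other than 'zeros' raised NotImplementedError:

import torch
import torch.nn.quantized as nnq

# Illustrative usage after this change (hypothetical values throughout).
qconv = nnq.Conv2d(in_channels=3, out_channels=8, kernel_size=3,
                   padding=1, padding_mode='reflect')
x = torch.randn(1, 3, 16, 16)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
out = qconv(qx)  # input is reflection-padded in forward(), then convolved
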
35 changes: 19 additions & 16 deletions test/quantization/test_quantized_module.py
@@ -359,6 +359,7 @@ def _test_conv_api_impl(
            groups=st.integers(1, 4),
            kernel=st.integers(1, 7),
            stride=st.integers(1, 2),
+           pad_mode=st.sampled_from(['zeros', 'reflect']),
            pad=st.integers(0, 2),
            dilation=st.integers(1, 2),
            X_scale=st.floats(1.2, 1.6),
@@ -373,7 +374,7 @@ def _test_conv_api_impl(
     @override_qengines
     def test_conv1d_api(
         self, batch_size, in_channels_per_group, length, out_channels_per_group,
-        groups, kernel, stride, pad, dilation,
+        groups, kernel, stride, pad_mode, pad, dilation,
         X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
         use_bias, use_fused, use_channelwise
     ):
@@ -391,16 +392,16 @@ def test_conv1d_api(
             module_name = "QuantizedConvReLU1d"
             qconv_module = nnq_fused.ConvReLU1d(
                 in_channels, out_channels, kernel, stride, pad,
-                dilation, groups, use_bias, padding_mode="zeros")
+                dilation, groups, use_bias, padding_mode=pad_mode)
         else:
             module_name = "QuantizedConv1d"
             qconv_module = nnq.Conv1d(
                 in_channels, out_channels, kernel, stride, pad,
-                dilation, groups, use_bias, padding_mode="zeros")
+                dilation, groups, use_bias, padding_mode=pad_mode)
 
         conv_module = nn.Conv1d(
             in_channels, out_channels, kernel, stride, pad,
-            dilation, groups, use_bias, padding_mode="zeros")
+            dilation, groups, use_bias, padding_mode=pad_mode)
         if use_fused:
             relu_module = nn.ReLU()
             conv_module = nni.ConvReLU1d(conv_module, relu_module)
@@ -425,6 +426,7 @@ def test_conv1d_api(
            stride_w=st.integers(1, 2),
            pad_h=st.integers(0, 2),
            pad_w=st.integers(0, 2),
+           pad_mode=st.sampled_from(['zeros', 'reflect']),
            dilation=st.integers(1, 2),
            X_scale=st.floats(1.2, 1.6),
            X_zero_point=st.integers(0, 4),
@@ -438,9 +440,9 @@ def test_conv1d_api(
     @override_qengines
     def test_conv2d_api(
         self, batch_size, in_channels_per_group, H, W, out_channels_per_group,
-        groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation,
-        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
-        use_bias, use_fused, use_channelwise
+        groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, pad_mode,
+        dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
+        Y_zero_point, use_bias, use_fused, use_channelwise
     ):
         # Tests the correctness of the conv2d module.
         in_channels = in_channels_per_group * groups
@@ -454,16 +456,16 @@ def test_conv2d_api(
             module_name = "QuantizedConvReLU2d"
             qconv_module = nnq_fused.ConvReLU2d(
                 in_channels, out_channels, kernel_size, stride, padding,
-                dilation, groups, use_bias, padding_mode="zeros")
+                dilation, groups, use_bias, padding_mode=pad_mode)
         else:
             module_name = "QuantizedConv2d"
             qconv_module = nnq.Conv2d(
                 in_channels, out_channels, kernel_size, stride, padding,
-                dilation, groups, use_bias, padding_mode="zeros")
+                dilation, groups, use_bias, padding_mode=pad_mode)
 
         conv_module = nn.Conv2d(
             in_channels, out_channels, kernel_size, stride, padding,
-            dilation, groups, use_bias, padding_mode="zeros")
+            dilation, groups, use_bias, padding_mode=pad_mode)
         if use_fused:
             relu_module = nn.ReLU()
             conv_module = nni.ConvReLU2d(conv_module, relu_module)
@@ -493,6 +495,7 @@ def test_conv2d_api(
            pad_d=st.integers(0, 1),
            pad_h=st.integers(0, 1),
            pad_w=st.integers(0, 1),
+           pad_mode=st.sampled_from(['zeros', 'reflect']),
            dilation=st.integers(1, 2),
            X_scale=st.floats(1.2, 1.6),
            X_zero_point=st.integers(0, 4),
@@ -506,9 +509,9 @@ def test_conv2d_api(
     def test_conv3d_api(
         self, batch_size, in_channels_per_group, D, H, W,
         out_channels_per_group, groups, kernel_d, kernel_h, kernel_w,
-        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale,
-        X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
-        use_channelwise, use_fused,
+        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, pad_mode, dilation,
+        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
+        use_bias, use_channelwise, use_fused,
     ):
         # Tests the correctness of the conv3d module.
         in_channels = in_channels_per_group * groups
@@ -523,16 +526,16 @@ def test_conv3d_api(
             module_name = "QuantizedConvReLU3d"
             qconv_module = nnq_fused.ConvReLU3d(
                 in_channels, out_channels, kernel_size, stride, padding,
-                dilation, groups, use_bias, padding_mode="zeros")
+                dilation, groups, use_bias, padding_mode=pad_mode)
         else:
             module_name = "QuantizedConv3d"
             qconv_module = nnq.Conv3d(
                 in_channels, out_channels, kernel_size, stride, padding,
-                dilation, groups, use_bias, padding_mode="zeros")
+                dilation, groups, use_bias, padding_mode=pad_mode)
 
         conv_module = nn.Conv3d(
             in_channels, out_channels, kernel_size, stride, padding,
-            dilation, groups, use_bias, padding_mode="zeros")
+            dilation, groups, use_bias, padding_mode=pad_mode)
         if use_fused:
             relu_module = nn.ReLU()
             conv_module = nni.ConvReLU3d(conv_module, relu_module)
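
The tests above thread the sampled pad_mode through the float reference module and the quantized module alike, so 'zeros' and 'reflect' are exercised by the same correctness check. A condensed sketch of that property (simplified, with made-up scales and shapes; the real tests go through _test_conv_api_impl with hypothesis-drawn parameters):

import torch
import torch.nn as nn
import torch.nn.quantized as nnq

pad_mode = 'reflect'  # the strategy samples from ['zeros', 'reflect']
float_conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, padding_mode=pad_mode)
qconv = nnq.Conv2d(4, 4, kernel_size=3, padding=1, padding_mode=pad_mode)
qconv.set_weight_bias(
    torch.quantize_per_tensor(float_conv.weight.detach(), 0.02, 0, torch.qint8),
    float_conv.bias.detach())
qconv.scale, qconv.zero_point = 0.05, 64  # assumed output quantization params
x = torch.randn(2, 4, 8, 8)
qx = torch.quantize_per_tensor(x, 0.05, 64, torch.quint8)
ref = float_conv(qx.dequantize())
out = qconv(qx).dequantize()  # should track ref up to quantization error
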
88 changes: 71 additions & 17 deletions torch/nn/quantized/modules/conv.py
@@ -5,26 +5,46 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.qat as nniqat
 
 from torch._ops import ops
+from torch.nn.common_types import _size_1_t
 from torch.nn.modules.utils import _single, _pair, _triple
 from torch.nn.quantized.modules.utils import _pair_from_first
 from torch.nn.quantized.modules.utils import _quantize_weight
 from torch.nn.utils import fuse_conv_bn_weights
 
+_SUPPORTED_PADDING = {
+    'zeros',
+    'reflect'
+}
+
+
+def _reverse_repeat_padding(padding: List[int]) -> List[int]:
+    _reversed_padding_repeated_twice: List[int] = []
+    N = len(padding)
+    for idx in range(N):
+        for _ in range(2):
+            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
+    return _reversed_padding_repeated_twice
+
 class _ConvNd(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation,
-                 transposed, output_padding,
-                 groups, bias,
-                 padding_mode='zeros'):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t,
+                 padding: _size_1_t,
+                 dilation: _size_1_t,
+                 transposed: bool,
+                 output_padding: _size_1_t,
+                 groups: int,
+                 bias: bool,
+                 padding_mode: str = 'zeros'):
         super(_ConvNd, self).__init__()
-        if padding_mode != 'zeros':
-            raise NotImplementedError(
-                "Currently only zero-padding is supported by quantized conv")
         if in_channels % groups != 0:
             raise ValueError('in_channels must be divisible by groups')
         if out_channels % groups != 0:
@@ -38,6 +58,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
         self.transposed = transposed
         self.output_padding = output_padding
         self.groups = groups
+        if padding_mode not in _SUPPORTED_PADDING:
+            raise ValueError("'padding_mode' {} is not supported by quantized convolution".format(padding_mode))
         self.padding_mode = padding_mode
         # Initialize as NCHW. set_weight will internally transpose to NHWC.
         if self.transposed:
@@ -225,9 +247,16 @@ class Conv1d(_ConvNd):
     _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
     _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1, bias=True,
-                 padding_mode='zeros'):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t = 1,
+                 padding: _size_1_t = 0,
+                 dilation: _size_1_t = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 padding_mode: str = 'zeros'):
         kernel_size = _pair_from_first(kernel_size)
         stride = _pair_from_first(stride)
         padding = _pair_from_first(padding)
@@ -242,8 +271,13 @@ def _get_name(self):
 
     def set_weight_bias(self, w, b):
         # type: (torch.Tensor, Optional[torch.Tensor]) -> None
-        self._packed_params = torch.ops.quantized.conv1d_prepack(
-            w, b, self.stride, self.padding, self.dilation, self.groups)
+        if self.padding_mode == 'zeros':
+            self._packed_params = torch.ops.quantized.conv1d_prepack(
+                w, b, self.stride, self.padding, self.dilation, self.groups)
+        else:
+            self._packed_params = torch.ops.quantized.conv1d_prepack(
+                w, b, self.stride, _pair_from_first(0), self.dilation,
+                self.groups)
 
     def _weight_bias(self):
         w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
@@ -260,6 +294,10 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 3:
             raise ValueError("Input shape must be `(N, C, L)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
 
     @classmethod
@@ -329,8 +367,12 @@ def _get_name(self):
 
     def set_weight_bias(self, w, b):
         # type: (torch.Tensor, Optional[torch.Tensor]) -> None
-        self._packed_params = torch.ops.quantized.conv2d_prepack(
-            w, b, self.stride, self.padding, self.dilation, self.groups)
+        if self.padding_mode == 'zeros':
+            self._packed_params = torch.ops.quantized.conv2d_prepack(
+                w, b, self.stride, self.padding, self.dilation, self.groups)
+        else:
+            self._packed_params = torch.ops.quantized.conv2d_prepack(
+                w, b, self.stride, _pair(0), self.dilation, self.groups)
 
     def _weight_bias(self):
         return self._packed_params.unpack()
@@ -346,6 +388,10 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 4:
             raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return ops.quantized.conv2d(
             input, self._packed_params, self.scale, self.zero_point)
 
@@ -414,8 +460,12 @@ def _get_name(self):
 
     def set_weight_bias(self, w, b):
         # type: (torch.Tensor, Optional[torch.Tensor]) -> None
-        self._packed_params = torch.ops.quantized.conv3d_prepack(
-            w, b, self.stride, self.padding, self.dilation, self.groups)
+        if self.padding_mode == 'zeros':
+            self._packed_params = torch.ops.quantized.conv3d_prepack(
+                w, b, self.stride, self.padding, self.dilation, self.groups)
+        else:
+            self._packed_params = torch.ops.quantized.conv3d_prepack(
+                w, b, self.stride, _triple(0), self.dilation, self.groups)
 
     def _weight_bias(self):
         return self._packed_params.unpack()
@@ -431,6 +481,10 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 5:
             raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return ops.quantized.conv3d(
             input, self._packed_params, self.scale, self.zero_point)
 
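One detail worth noting in conv.py: the module stores padding dimension-first (e.g. (padH, padW)), while F.pad expects the last dimension first with one value per side, i.e. (left, right, top, bottom). That is exactly the reversal-and-duplication _reverse_repeat_padding performs, and it is also why the non-'zeros' branches prepack the weights with zero padding: the padding is applied explicitly to the input in forward(), and the underlying quantized kernel then runs unpadded. A self-contained sketch of the helper's behavior (the helper is copied from the diff; the example values are arbitrary):

from typing import List

import torch
import torch.nn.functional as F

def _reverse_repeat_padding(padding: List[int]) -> List[int]:
    # (padH, padW) -> [padW, padW, padH, padH], matching F.pad's
    # last-dimension-first, per-side layout.
    _reversed_padding_repeated_twice: List[int] = []
    N = len(padding)
    for idx in range(N):
        for _ in range(2):
            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
    return _reversed_padding_repeated_twice

print(_reverse_repeat_padding([2, 1]))  # [1, 1, 2, 2]

x = torch.arange(16.).reshape(1, 1, 4, 4)
padded = F.pad(x, _reverse_repeat_padding([2, 1]), mode='reflect')
print(padded.shape)  # torch.Size([1, 1, 8, 6]): H grows by 2+2, W by 1+1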
