From 5708ee1695bb8d95193681a830eb4dc27b7a2796 Mon Sep 17 00:00:00 2001 From: Zafar Date: Sun, 20 Dec 2020 17:44:08 -0800 Subject: [PATCH] [quant] Add reflection padding to conv ghstack-source-id: 285186d6c359dae177898d8a2c91b7252e72ba76 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49011 --- test/quantization/test_quantized_module.py | 35 +++++---- torch/nn/quantized/modules/conv.py | 89 +++++++++++++++++----- 2 files changed, 91 insertions(+), 33 deletions(-) diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index 70cca9ab1eee..c31f4493bb49 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -359,6 +359,7 @@ def _test_conv_api_impl( groups=st.integers(1, 4), kernel=st.integers(1, 7), stride=st.integers(1, 2), + pad_mode=st.sampled_from(['zeros', 'reflect']), pad=st.integers(0, 2), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), @@ -373,7 +374,7 @@ def _test_conv_api_impl( @override_qengines def test_conv1d_api( self, batch_size, in_channels_per_group, length, out_channels_per_group, - groups, kernel, stride, pad, dilation, + groups, kernel, stride, pad_mode, pad, dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise ): @@ -391,16 +392,16 @@ def test_conv1d_api( module_name = "QuantizedConvReLU1d" qconv_module = nnq_fused.ConvReLU1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv1d" qconv_module = nnq.Conv1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) 
if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU1d(conv_module, relu_module) @@ -425,6 +426,7 @@ def test_conv1d_api( stride_w=st.integers(1, 2), pad_h=st.integers(0, 2), pad_w=st.integers(0, 2), + pad_mode=st.sampled_from(['zeros', 'reflect']), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), X_zero_point=st.integers(0, 4), @@ -438,9 +440,9 @@ def test_conv1d_api( @override_qengines def test_conv2d_api( self, batch_size, in_channels_per_group, H, W, out_channels_per_group, - groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation, - X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, - use_bias, use_fused, use_channelwise + groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, pad_mode, + dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, + Y_zero_point, use_bias, use_fused, use_channelwise ): # Tests the correctness of the conv2d module. in_channels = in_channels_per_group * groups @@ -454,16 +456,16 @@ def test_conv2d_api( module_name = "QuantizedConvReLU2d" qconv_module = nnq_fused.ConvReLU2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv2d" qconv_module = nnq.Conv2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU2d(conv_module, relu_module) @@ -493,6 +495,7 @@ def test_conv2d_api( pad_d=st.integers(0, 1), pad_h=st.integers(0, 1), pad_w=st.integers(0, 1), + pad_mode=st.sampled_from(['zeros', 'reflect']), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), 
X_zero_point=st.integers(0, 4), @@ -506,9 +509,9 @@ def test_conv2d_api( def test_conv3d_api( self, batch_size, in_channels_per_group, D, H, W, out_channels_per_group, groups, kernel_d, kernel_h, kernel_w, - stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale, - X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, - use_channelwise, use_fused, + stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, pad_mode, dilation, + X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, + use_bias, use_channelwise, use_fused, ): # Tests the correctness of the conv3d module. in_channels = in_channels_per_group * groups @@ -523,16 +526,16 @@ def test_conv3d_api( module_name = "QuantizedConvReLU3d" qconv_module = nnq_fused.ConvReLU3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv3d" qconv_module = nnq.Conv3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU3d(conv_module, relu_module) diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index a9ba3293630d..5c80c3374881 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -5,26 +5,46 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.nn.intrinsic as nni import torch.nn.intrinsic.qat as nniqat from torch._ops import ops +from torch.nn.common_types import _size_1_t from torch.nn.modules.utils import _single, _pair, _triple from torch.nn.quantized.modules.utils import 
_pair_from_first from torch.nn.quantized.modules.utils import _quantize_weight from torch.nn.utils import fuse_conv_bn_weights +_SUPPORTED_PADDING = { + 'zeros', + 'reflect' +} + + +def _reverse_repeat_padding(padding: List[int]) -> List[int]: + _reversed_padding_repeated_twice: List[int] = [] + N = len(padding) + for idx in range(N): + for _ in range(2): + _reversed_padding_repeated_twice.append(padding[N - idx - 1]) + return _reversed_padding_repeated_twice + class _ConvNd(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t, + padding: _size_1_t, + dilation: _size_1_t, + transposed: bool, + output_padding: _size_1_t, + groups: int, + bias: bool, + padding_mode: str = 'zeros'): - def __init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, - transposed, output_padding, - groups, bias, - padding_mode='zeros'): super(_ConvNd, self).__init__() - if padding_mode != 'zeros': - raise NotImplementedError( - "Currently only zero-padding is supported by quantized conv") if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: @@ -38,6 +58,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, self.transposed = transposed self.output_padding = output_padding self.groups = groups + if padding_mode not in _SUPPORTED_PADDING: + raise ValueError("'padding_mode' {} is not supported by quantized convolution".format(padding_mode)) self.padding_mode = padding_mode # Initialize as NCHW. set_weight will internally transpose to NHWC. 
if self.transposed: @@ -225,9 +247,16 @@ class Conv1d(_ConvNd): _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d _NNI_CONV_RELU_MODULE = nni.ConvReLU1d - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True, - padding_mode='zeros'): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros'): kernel_size = _pair_from_first(kernel_size) stride = _pair_from_first(stride) padding = _pair_from_first(padding) @@ -242,8 +271,13 @@ def _get_name(self): def set_weight_bias(self, w, b): # type: (torch.Tensor, Optional[torch.Tensor]) -> None - self._packed_params = torch.ops.quantized.conv1d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, _pair_from_first(0), self.dilation, + self.groups) def _weight_bias(self): w, b = torch.ops.quantized.conv1d_unpack(self._packed_params) @@ -260,6 +294,10 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 3: raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])  # stored as (p, p) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point) @classmethod @@ -329,8 +367,12 @@ def _get_name(self): def set_weight_bias(self, w, b): # type: (torch.Tensor, Optional[torch.Tensor]) -> None - self._packed_params = torch.ops.quantized.conv2d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if 
self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups) def _weight_bias(self): return self._packed_params.unpack() @@ -346,6 +388,11 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + # F.pad consumes pad widths last-dimension-first, hence the reversed list. + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv2d( input, self._packed_params, self.scale, self.zero_point) @@ -414,8 +461,12 @@ def _get_name(self): def set_weight_bias(self, w, b): # type: (torch.Tensor, Optional[torch.Tensor]) -> None - self._packed_params = torch.ops.quantized.conv3d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, _triple(0), self.dilation, self.groups) def _weight_bias(self): return self._packed_params.unpack() @@ -431,6 +482,10 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 5: raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv3d( input, self._packed_params, self.scale, self.zero_point)