From 35fa5022774f12d6379374cc546263da8597dce1 Mon Sep 17 00:00:00 2001 From: Zafar Date: Tue, 8 Dec 2020 02:43:58 -0800 Subject: [PATCH] [quant] Add reflection padding to conv ghstack-source-id: 15e3346a609a05791fdd935a4cf5d745396d9cda Pull Request resolved: https://github.com/pytorch/pytorch/pull/49011 --- test/quantization/test_quantized_module.py | 35 +++++---- torch/nn/quantized/modules/conv.py | 90 ++++++++++++++++++---- 2 files changed, 92 insertions(+), 33 deletions(-) diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index 70cca9ab1eee..c31f4493bb49 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -359,6 +359,7 @@ def _test_conv_api_impl( groups=st.integers(1, 4), kernel=st.integers(1, 7), stride=st.integers(1, 2), + pad_mode=st.sampled_from(['zeros', 'reflect']), pad=st.integers(0, 2), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), @@ -373,7 +374,7 @@ def _test_conv_api_impl( @override_qengines def test_conv1d_api( self, batch_size, in_channels_per_group, length, out_channels_per_group, - groups, kernel, stride, pad, dilation, + groups, kernel, stride, pad_mode, pad, dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise ): @@ -391,16 +392,16 @@ def test_conv1d_api( module_name = "QuantizedConvReLU1d" qconv_module = nnq_fused.ConvReLU1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv1d" qconv_module = nnq.Conv1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU1d(conv_module, relu_module) @@ -425,6 +426,7 @@ def test_conv1d_api( stride_w=st.integers(1, 2), pad_h=st.integers(0, 2), pad_w=st.integers(0, 2), + pad_mode=st.sampled_from(['zeros', 'reflect']), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), X_zero_point=st.integers(0, 4), @@ -438,9 +440,9 @@ def test_conv1d_api( @override_qengines def test_conv2d_api( self, batch_size, in_channels_per_group, H, W, out_channels_per_group, - groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation, - X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, - use_bias, use_fused, use_channelwise + groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, pad_mode, + dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, + Y_zero_point, use_bias, use_fused, use_channelwise ): # Tests the correctness of the conv2d module. in_channels = in_channels_per_group * groups @@ -454,16 +456,16 @@ def test_conv2d_api( module_name = "QuantizedConvReLU2d" qconv_module = nnq_fused.ConvReLU2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv2d" qconv_module = nnq.Conv2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU2d(conv_module, relu_module) @@ -493,6 +495,7 @@ def test_conv2d_api( pad_d=st.integers(0, 1), pad_h=st.integers(0, 1), pad_w=st.integers(0, 1), + pad_mode=st.sampled_from(['zeros', 'reflect']), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), X_zero_point=st.integers(0, 4), @@ -506,9 +509,9 @@ def test_conv2d_api( def test_conv3d_api( self, batch_size, in_channels_per_group, D, H, W, out_channels_per_group, groups, kernel_d, kernel_h, kernel_w, - stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale, - X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, - use_channelwise, use_fused, + stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, pad_mode, dilation, + X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, + use_bias, use_channelwise, use_fused, ): # Tests the correctness of the conv3d module. in_channels = in_channels_per_group * groups @@ -523,16 +526,16 @@ def test_conv3d_api( module_name = "QuantizedConvReLU3d" qconv_module = nnq_fused.ConvReLU3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv3d" qconv_module = nnq.Conv3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU3d(conv_module, relu_module) diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index a9ba3293630d..72b5d4dfd86d 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -5,26 +5,47 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.nn.intrinsic as nni import torch.nn.intrinsic.qat as nniqat from torch._ops import ops +from torch.nn.common_types import _size_1_t from torch.nn.modules.utils import _single, _pair, _triple from torch.nn.quantized.modules.utils import _pair_from_first from torch.nn.quantized.modules.utils import _quantize_weight from torch.nn.utils import fuse_conv_bn_weights +from typing import Optional, Tuple + +_SUPPORTED_PADDING = { + 'zeros', + 'reflect' +} + +def _reverse_repeat_padding(padding: List[int]) -> List[int]: + _reversed_padding_repeated_twice: List[int] = [] + N = len(padding) + for idx in range(N): # pad in padding: + for _ in range(2): + _reversed_padding_repeated_twice.append(padding[N-idx-1]) + return _reversed_padding_repeated_twice + class _ConvNd(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t, + padding: _size_1_t, + dilation: _size_1_t, + transposed: bool, + output_padding: _size_1_t, + groups: int, + bias: bool, + padding_mode: str='zeros'): - def __init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, - transposed, output_padding, - groups, bias, - padding_mode='zeros'): super(_ConvNd, self).__init__() - if padding_mode != 'zeros': - raise NotImplementedError( - "Currently only zero-padding is supported by quantized conv") if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: @@ -38,6 +59,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, self.transposed = transposed self.output_padding = output_padding self.groups = groups + if padding_mode not in _SUPPORTED_PADDING: + raise ValueError("'padding_mode' {} is not supported by quantized convolution".format(padding_mode)) self.padding_mode = padding_mode # Initialize as NCHW. set_weight will internally transpose to NHWC. if self.transposed: @@ -225,9 +248,16 @@ class Conv1d(_ConvNd): _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d _NNI_CONV_RELU_MODULE = nni.ConvReLU1d - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True, - padding_mode='zeros'): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups:int = 1, + bias: bool = True, + padding_mode: str ='zeros'): kernel_size = _pair_from_first(kernel_size) stride = _pair_from_first(stride) padding = _pair_from_first(padding) @@ -242,8 +272,13 @@ def _get_name(self): def set_weight_bias(self, w, b): # type: (torch.Tensor, Optional[torch.Tensor]) -> None - self._packed_params = torch.ops.quantized.conv1d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode is 'zeros': + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, _pair_from_first(0), self.dilation, + self.groups) def _weight_bias(self): w, b = torch.ops.quantized.conv1d_unpack(self._packed_params) @@ -260,6 +295,10 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 3: raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point) @classmethod @@ -329,8 +368,12 @@ def _get_name(self): def set_weight_bias(self, w, b): # type: (torch.Tensor, Optional[torch.Tensor]) -> None - self._packed_params = torch.ops.quantized.conv2d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups) def _weight_bias(self): return self._packed_params.unpack() @@ -346,6 +389,11 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + print(self.padding, _reversed_padding_repeated_twice) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv2d( input, self._packed_params, self.scale, self.zero_point) @@ -414,8 +462,12 @@ def _get_name(self): def set_weight_bias(self, w, b): # type: (torch.Tensor, Optional[torch.Tensor]) -> None - self._packed_params = torch.ops.quantized.conv3d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode is 'zeros': + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, _triple(0), self.dilation, self.groups) def _weight_bias(self): return self._packed_params.unpack() @@ -431,6 +483,10 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 5: raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv3d( input, self._packed_params, self.scale, self.zero_point)