From 7f02945655ff20617854c1c677a2d1ca66c7d1c8 Mon Sep 17 00:00:00 2001 From: Zafar Date: Wed, 3 Feb 2021 14:06:09 -0800 Subject: [PATCH] [quant] Add reflection padding to conv ghstack-source-id: 9bd5b9350d062c97b15f77a315ffdea94899cc3d Pull Request resolved: https://github.com/pytorch/pytorch/pull/49011 --- test/quantization/test_quantized_module.py | 35 ++++++----- torch/nn/quantized/modules/conv.py | 73 ++++++++++++++++++---- 2 files changed, 79 insertions(+), 29 deletions(-) diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index 28867e8260b6..7588ceec1442 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -363,6 +363,7 @@ def _test_conv_api_impl( groups=st.integers(1, 4), kernel=st.integers(1, 7), stride=st.integers(1, 2), + pad_mode=st.sampled_from(['zeros', 'reflect']), pad=st.integers(0, 2), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), @@ -377,7 +378,7 @@ def _test_conv_api_impl( @override_qengines def test_conv1d_api( self, batch_size, in_channels_per_group, length, out_channels_per_group, - groups, kernel, stride, pad, dilation, + groups, kernel, stride, pad_mode, pad, dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise ): @@ -395,16 +396,16 @@ def test_conv1d_api( module_name = "QuantizedConvReLU1d" qconv_module = nnq_fused.ConvReLU1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv1d" qconv_module = nnq.Conv1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv1d( in_channels, out_channels, kernel, stride, pad, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU1d(conv_module, relu_module) @@ -429,6 +430,7 @@ def test_conv1d_api( stride_w=st.integers(1, 2), pad_h=st.integers(0, 2), pad_w=st.integers(0, 2), + pad_mode=st.sampled_from(['zeros', 'reflect']), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), X_zero_point=st.integers(0, 4), @@ -442,9 +444,9 @@ def test_conv1d_api( @override_qengines def test_conv2d_api( self, batch_size, in_channels_per_group, H, W, out_channels_per_group, - groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation, - X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, - use_bias, use_fused, use_channelwise + groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, pad_mode, + dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, + Y_zero_point, use_bias, use_fused, use_channelwise ): # Tests the correctness of the conv2d module. in_channels = in_channels_per_group * groups @@ -458,16 +460,16 @@ def test_conv2d_api( module_name = "QuantizedConvReLU2d" qconv_module = nnq_fused.ConvReLU2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv2d" qconv_module = nnq.Conv2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU2d(conv_module, relu_module) @@ -497,6 +499,7 @@ def test_conv2d_api( pad_d=st.integers(0, 1), pad_h=st.integers(0, 1), pad_w=st.integers(0, 1), + pad_mode=st.sampled_from(['zeros', 'reflect']), dilation=st.integers(1, 2), X_scale=st.floats(1.2, 1.6), X_zero_point=st.integers(0, 4), @@ -510,9 +513,9 @@ def test_conv2d_api( def test_conv3d_api( self, batch_size, in_channels_per_group, D, H, W, out_channels_per_group, groups, kernel_d, kernel_h, kernel_w, - stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale, - X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, - use_channelwise, use_fused, + stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, pad_mode, dilation, + X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, + use_bias, use_channelwise, use_fused, ): # Tests the correctness of the conv3d module. in_channels = in_channels_per_group * groups @@ -527,16 +530,16 @@ def test_conv3d_api( module_name = "QuantizedConvReLU3d" qconv_module = nnq_fused.ConvReLU3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) else: module_name = "QuantizedConv3d" qconv_module = nnq.Conv3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) conv_module = nn.Conv3d( in_channels, out_channels, kernel_size, stride, padding, - dilation, groups, use_bias, padding_mode="zeros") + dilation, groups, use_bias, padding_mode=pad_mode) if use_fused: relu_module = nn.ReLU() conv_module = nni.ConvReLU3d(conv_module, relu_module) diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index b3bc78ff6941..5a4b30a13ce4 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -5,17 +5,32 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.nn.intrinsic as nni import torch.nn.intrinsic.qat as nniqat from torch._ops import ops +from torch.nn.common_types import _size_1_t from torch.nn.modules.utils import _single, _pair, _triple from torch.nn.quantized.modules.utils import _pair_from_first from torch.nn.quantized.modules.utils import _quantize_weight from torch.nn.utils import fuse_conv_bn_weights -class _ConvNd(nn.Module): +_SUPPORTED_PADDING = { + 'zeros', + 'reflect' +} + + +def _reverse_repeat_padding(padding: List[int]) -> List[int]: + _reversed_padding_repeated_twice: List[int] = [] + N = len(padding) + for idx in range(N): + for _ in range(2): + _reversed_padding_repeated_twice.append(padding[N - idx - 1]) + return _reversed_padding_repeated_twice +class _ConvNd(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): @@ -28,9 +43,7 @@ def _init(self, in_channels, out_channels, kernel_size, stride, groups, bias, padding_mode='zeros'): super(_ConvNd, self).__init__() - if padding_mode != 'zeros': - raise NotImplementedError( - "Currently only zero-padding is supported by quantized conv") + if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: @@ -44,6 +57,8 @@ def _init(self, in_channels, out_channels, kernel_size, stride, self.transposed = transposed self.output_padding = output_padding self.groups = groups + if padding_mode not in _SUPPORTED_PADDING: + raise ValueError("'padding_mode' {} is not supported by quantized convolution".format(padding_mode)) self.padding_mode = padding_mode # Initialize as NCHW. set_weight will internally transpose to NHWC. if self.transposed: @@ -241,9 +256,16 @@ class Conv1d(_ConvNd): _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d _NNI_CONV_RELU_MODULE = nni.ConvReLU1d - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True, - padding_mode='zeros'): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros'): kernel_size = _pair_from_first(kernel_size) stride = _pair_from_first(stride) padding = _pair_from_first(padding) @@ -259,8 +281,13 @@ def _get_name(self): return 'QuantizedConv1d' def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: - self._packed_params = torch.ops.quantized.conv1d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, _pair_from_first(0), self.dilation, + self.groups) def _weight_bias(self): w, b = torch.ops.quantized.conv1d_unpack(self._packed_params) @@ -277,6 +304,10 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 3: raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point) @classmethod @@ -347,8 +378,12 @@ def _get_name(self): return 'QuantizedConv2d' def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: - self._packed_params = torch.ops.quantized.conv2d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups) def _weight_bias(self): return self._packed_params.unpack() @@ -364,6 +399,10 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv2d( input, self._packed_params, self.scale, self.zero_point) @@ -433,8 +472,12 @@ def _get_name(self): return 'QuantizedConv3d' def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: - self._packed_params = torch.ops.quantized.conv3d_prepack( - w, b, self.stride, self.padding, self.dilation, self.groups) + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, _triple(0), self.dilation, self.groups) def _weight_bias(self): return self._packed_params.unpack() @@ -450,6 +493,10 @@ def forward(self, input): # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 5: raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) return ops.quantized.conv3d( input, self._packed_params, self.scale, self.zero_point)