diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py
index 7588ceec1442..07f6bcbfbbbe 100644
--- a/test/quantization/test_quantized_module.py
+++ b/test/quantization/test_quantized_module.py
@@ -230,9 +230,9 @@ def test_quant_dequant_api(self):
     def _test_conv_api_impl(
         self, module_name, qconv_module, conv_module, batch_size,
         in_channels_per_group, input_feature_map_size, out_channels_per_group,
-        groups, kernel_size, stride, padding, dilation, X_scale, X_zero_point,
-        W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
-        use_channelwise,
+        groups, kernel_size, stride, padding, padding_mode, dilation,
+        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
+        use_bias, use_fused, use_channelwise,
     ):
         for i in range(len(kernel_size)):
             assume(input_feature_map_size[i] + 2 * padding[i]
@@ -304,7 +304,7 @@ def _test_conv_api_impl(
             self.assertEqual(model_dict[key], loaded_dict[key])
         loaded_qconv_module = type(qconv_module)(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
-            groups, use_bias, padding_mode="zeros")
+            groups, use_bias, padding_mode=padding_mode)
         loaded_qconv_module.load_state_dict(loaded_dict)
         self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
@@ -414,7 +414,7 @@ def test_conv1d_api(
         self._test_conv_api_impl(
             module_name, qconv_module, conv_module, batch_size,
             in_channels_per_group, input_feature_map_size,
-            out_channels_per_group, groups, kernel_size, stride, pad,
+            out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
             dilation, X_scale, X_zero_point, W_scale, W_zero_point,
             Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)
@@ -479,8 +479,8 @@ def test_conv2d_api(
             module_name, qconv_module, conv_module, batch_size,
             in_channels_per_group, input_feature_map_size,
             out_channels_per_group, groups, kernel_size, stride, padding,
-            dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
-            Y_zero_point, use_bias, use_fused, use_channelwise)
+            pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
+            Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)

     @skipIfNoFBGEMM
     @given(batch_size=st.integers(1, 3),
@@ -499,7 +499,7 @@ def test_conv2d_api(
            pad_d=st.integers(0, 1),
            pad_h=st.integers(0, 1),
            pad_w=st.integers(0, 1),
-           pad_mode=st.sampled_from(['zeros', 'reflect']),
+           pad_mode=st.just('zeros'),  # Conv3d doesn't support reflection
            dilation=st.integers(1, 2),
            X_scale=st.floats(1.2, 1.6),
            X_zero_point=st.integers(0, 4),
@@ -549,8 +549,9 @@ def test_conv3d_api(
             module_name, qconv_module, conv_module, batch_size,
             in_channels_per_group, input_feature_map_size,
             out_channels_per_group, groups, kernel_size, stride, padding,
-            dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
-            Y_zero_point, use_bias, use_fused, use_channelwise)
+            pad_mode, dilation, X_scale, X_zero_point, W_scale,
+            W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
+            use_channelwise)

     def test_pool_api(self):
         """Tests the correctness of the pool module.
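Note: the padding_mode=padding_mode change above matters because the old test rebuilt the loaded module with padding_mode="zeros" unconditionally, so a reflect-padded module was never exercised through load_state_dict. A minimal sketch of the round-trip the test now covers (values and shapes are arbitrary; assumes the FBGEMM backend is available):

    import torch
    import torch.nn.quantized as nnq

    # Build a reflect-padded quantized conv, serialize it, then rebuild the
    # module with the *same* padding_mode before loading the state dict.
    qconv = nnq.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='reflect')
    state = qconv.state_dict()

    loaded = nnq.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='reflect')
    loaded.load_state_dict(state)
    assert loaded.padding_mode == qconv.padding_mode == 'reflect'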
diff --git a/torch/nn/intrinsic/quantized/modules/conv_relu.py b/torch/nn/intrinsic/quantized/modules/conv_relu.py
index 8dd931ff05a8..125c9da3f405 100644
--- a/torch/nn/intrinsic/quantized/modules/conv_relu.py
+++ b/torch/nn/intrinsic/quantized/modules/conv_relu.py
@@ -2,10 +2,13 @@
 import torch
 import torch.nn.intrinsic
 import torch.nn.intrinsic.qat
+import torch.nn.functional as F
 import torch.nn.quantized as nnq

 from torch.nn.utils import fuse_conv_bn_weights

+_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
+
 class ConvReLU1d(nnq.Conv1d):
     r"""
     A ConvReLU1d module is a fused module of Conv1d and ReLU
@@ -31,6 +34,11 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 3:
             raise ValueError("Input shape must be `(N, C, L)`!")
+        if self.padding_mode != 'zeros':
+            # Padding in Conv1d is stored as (p, p), need to get (p,)
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return torch.ops.quantized.conv1d_relu(
             input, self._packed_params, self.scale, self.zero_point)
@@ -70,6 +78,10 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 4:
             raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return torch.ops.quantized.conv2d_relu(
             input, self._packed_params, self.scale, self.zero_point)
@@ -99,6 +111,7 @@ class ConvReLU3d(nnq.Conv3d):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
         super(ConvReLU3d, self).__init__(
             in_channels, out_channels, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=groups, bias=bias,
@@ -109,6 +122,10 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 5:
             raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return torch.ops.quantized.conv3d_relu(
             input, self._packed_params, self.scale, self.zero_point)
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index 5a4b30a13ce4..ba700711a482 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -286,7 +286,7 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
                 w, b, self.stride, self.padding, self.dilation, self.groups)
         else:
             self._packed_params = torch.ops.quantized.conv1d_prepack(
-                w, b, self.stride, _pair_from_first(0), self.dilation,
+                w, b, self.stride, _pair(0), self.dilation,
                 self.groups)

     def _weight_bias(self):
@@ -305,7 +305,8 @@ def forward(self, input):
         if len(input.shape) != 3:
             raise ValueError("Input shape must be `(N, C, L)`!")
         if self.padding_mode != 'zeros':
-            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            # Padding in Conv1d is stored as (p, p), need to get (p,)
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
             input = F.pad(input, _reversed_padding_repeated_twice,
                           mode=self.padding_mode)
         return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
@@ -458,6 +459,7 @@ class Conv3d(_ConvNd):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
         kernel_size = _triple(kernel_size)
         stride = _triple(stride)
         padding = _triple(padding)
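For reference, a sketch of what _reverse_repeat_padding computes and why Conv1d slices self.padding[:1]: F.pad expects padding pairs for the last dimension first, so the helper reverses the per-dimension padding tuple and repeats each entry twice. The function below is a local re-implementation for illustration, not the library helper itself:

    import torch
    import torch.nn.functional as F

    def reverse_repeat_padding(padding):
        # (p_h, p_w) -> [p_w, p_w, p_h, p_h]: F.pad wants the last
        # dimension's pair first, each value repeated for both sides.
        out = []
        for p in reversed(padding):
            out.extend([p, p])
        return out

    # Conv2d stores padding as (p_h, p_w); Conv1d stores (p, p), hence the
    # self.padding[:1] slice in the patched forward methods above.
    x = torch.randn(1, 3, 8, 8)
    y = F.pad(x, reverse_repeat_padding((1, 2)), mode='reflect')
    print(y.shape)  # torch.Size([1, 3, 10, 12])

With the input pre-padded this way, the packed conv itself always runs with zero padding (note the _pair(0) passed to conv1d_prepack in the non-'zeros' branch), since the quantized conv kernels only implement zero padding.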