Update on "[quant] Add reflection padding to conv"
Differential Revision: [D25394384](https://our.internmc.facebook.com/intern/diff/D25394384)

[ghstack-poisoned]
z-a-f committed Feb 4, 2021
1 parent 0ce292f commit b47850f
Showing 3 changed files with 32 additions and 12 deletions.
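For context, a minimal end-to-end sketch of what this change enables: a float `Conv2d` with `padding_mode='reflect'` now converts through eager-mode post-training quantization. The layer sizes and qconfig below are illustrative assumptions, not part of this diff.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    torch.quantization.QuantStub(),
    nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='reflect'),
    torch.quantization.DeQuantStub(),
).eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
model(torch.randn(1, 3, 32, 32))        # calibrate observers
torch.quantization.convert(model, inplace=True)
out = model(torch.randn(1, 3, 32, 32))  # quantized conv with reflect padding
```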
21 changes: 11 additions & 10 deletions test/quantization/test_quantized_module.py
@@ -230,9 +230,9 @@ def test_quant_dequant_api(self):
     def _test_conv_api_impl(
         self, module_name, qconv_module, conv_module, batch_size,
         in_channels_per_group, input_feature_map_size, out_channels_per_group,
-        groups, kernel_size, stride, padding, dilation, X_scale, X_zero_point,
-        W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
-        use_channelwise,
+        groups, kernel_size, stride, padding, padding_mode, dilation,
+        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
+        use_bias, use_fused, use_channelwise,
     ):
         for i in range(len(kernel_size)):
             assume(input_feature_map_size[i] + 2 * padding[i]
@@ -304,7 +304,7 @@ def _test_conv_api_impl(
             self.assertEqual(model_dict[key], loaded_dict[key])
         loaded_qconv_module = type(qconv_module)(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
-            groups, use_bias, padding_mode="zeros")
+            groups, use_bias, padding_mode=padding_mode)
         loaded_qconv_module.load_state_dict(loaded_dict)

         self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
@@ -414,7 +414,7 @@ def test_conv1d_api(
         self._test_conv_api_impl(
             module_name, qconv_module, conv_module, batch_size,
             in_channels_per_group, input_feature_map_size,
-            out_channels_per_group, groups, kernel_size, stride, pad,
+            out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
             dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
             Y_zero_point, use_bias, use_fused, use_channelwise)

@@ -479,8 +479,8 @@ def test_conv2d_api(
             module_name, qconv_module, conv_module, batch_size,
             in_channels_per_group, input_feature_map_size,
             out_channels_per_group, groups, kernel_size, stride, padding,
-            dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
-            Y_zero_point, use_bias, use_fused, use_channelwise)
+            pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
+            Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)

     @skipIfNoFBGEMM
     @given(batch_size=st.integers(1, 3),
@@ -499,7 +499,7 @@ def test_conv2d_api(
            pad_d=st.integers(0, 1),
            pad_h=st.integers(0, 1),
            pad_w=st.integers(0, 1),
-           pad_mode=st.sampled_from(['zeros', 'reflect']),
+           pad_mode=st.just('zeros'),  # Conv3d doesn't support reflection
            dilation=st.integers(1, 2),
            X_scale=st.floats(1.2, 1.6),
            X_zero_point=st.integers(0, 4),
@@ -549,8 +549,9 @@ def test_conv3d_api(
             module_name, qconv_module, conv_module, batch_size,
             in_channels_per_group, input_feature_map_size,
             out_channels_per_group, groups, kernel_size, stride, padding,
-            dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
-            Y_zero_point, use_bias, use_fused, use_channelwise)
+            pad_mode, dilation, X_scale, X_zero_point, W_scale,
+            W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
+            use_channelwise)

     def test_pool_api(self):
         """Tests the correctness of the pool module.
17 changes: 17 additions & 0 deletions torch/nn/intrinsic/quantized/modules/conv_relu.py
@@ -2,10 +2,13 @@
 import torch
 import torch.nn.intrinsic
 import torch.nn.intrinsic.qat
+import torch.nn.functional as F
 import torch.nn.quantized as nnq

 from torch.nn.utils import fuse_conv_bn_weights

+_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
+
 class ConvReLU1d(nnq.Conv1d):
     r"""
     A ConvReLU1d module is a fused module of Conv1d and ReLU
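The aliased helper converts a conv-style padding tuple into the layout `F.pad` expects: last dimension first, with each side listed explicitly. A sketch of its behavior (assumed to mirror `torch.nn.quantized.modules.conv._reverse_repeat_padding`):

```python
from typing import List

def _reverse_repeat_padding(padding: List[int]) -> List[int]:
    # e.g. (pH, pW) -> (pW, pW, pH, pH), the order F.pad consumes
    out: List[int] = []
    for p in reversed(padding):
        out += [p, p]
    return out

assert _reverse_repeat_padding([1, 2]) == [2, 2, 1, 1]
```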
@@ -31,6 +34,11 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 3:
             raise ValueError("Input shape must be `(N, C, L)`!")
+        if self.padding_mode != 'zeros':
+            # Padding in Conv1d is stored as (p, p), need to get (p,)
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return torch.ops.quantized.conv1d_relu(
             input, self._packed_params, self.scale, self.zero_point)

@@ -70,6 +78,10 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 4:
             raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return torch.ops.quantized.conv2d_relu(
             input, self._packed_params, self.scale, self.zero_point)

@@ -99,6 +111,7 @@ class ConvReLU3d(nnq.Conv3d):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
         super(ConvReLU3d, self).__init__(
             in_channels, out_channels, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=groups, bias=bias,
@@ -109,6 +122,10 @@ def forward(self, input):
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 5:
             raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
         return torch.ops.quantized.conv3d_relu(
             input, self._packed_params, self.scale, self.zero_point)

6 changes: 4 additions & 2 deletions torch/nn/quantized/modules/conv.py
@@ -286,7 +286,7 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
                 w, b, self.stride, self.padding, self.dilation, self.groups)
         else:
             self._packed_params = torch.ops.quantized.conv1d_prepack(
-                w, b, self.stride, _pair_from_first(0), self.dilation,
+                w, b, self.stride, _pair(0), self.dilation,
                 self.groups)

     def _weight_bias(self):
@@ -305,7 +305,8 @@ def forward(self, input):
         if len(input.shape) != 3:
             raise ValueError("Input shape must be `(N, C, L)`!")
         if self.padding_mode != 'zeros':
-            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            # Padding in Conv1d is stored as (p, p), need to get (p,)
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
             input = F.pad(input, _reversed_padding_repeated_twice,
                           mode=self.padding_mode)
         return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
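The `[:1]` slice matters: `_pair` stores Conv1d's padding as `(p, p)`, and reverse-repeating the full pair would produce four pad values, which `F.pad` would read as a request to pad the last two dimensions (L and C) rather than L alone. Slicing back to `(p,)` yields the two-value form that pads only L. A quick illustration of the resulting reflect padding:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1., 2., 3., 4.]]])   # (N, C, L)
print(F.pad(x, (1, 1), mode='reflect'))  # two pad values: pads L only
# tensor([[[2., 1., 2., 3., 4., 3.]]])
```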
@@ -458,6 +459,7 @@ class Conv3d(_ConvNd):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
         kernel_size = _triple(kernel_size)
         stride = _triple(stride)
         padding = _triple(padding)
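The guard is an eager assert rather than a silent fallback: reflection padding for 5-D inputs isn't available in `F.pad` at this point, so the quantized Conv3d (and ConvReLU3d above) rejects `padding_mode='reflect'` at construction time. A small sketch of the resulting behavior, assuming direct construction of the quantized module:

```python
import torch.nn.quantized as nnq

try:
    nnq.Conv3d(3, 8, kernel_size=3, padding=1, padding_mode='reflect')
except AssertionError as e:
    print(e)  # Conv3d does not support reflection padding
```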
