diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index e9acd9673ddb0..a3e681e8838c2 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -572,6 +573,110 @@ at::Tensor conv3d( false, {{0, 0, 0}}, groups); } + +static Tensor convolution_same( + const Tensor &input, const Tensor &weight, const Tensor &bias, + IntArrayRef stride, IntArrayRef dilation, int64_t groups) { + + auto k = weight.dim(); + auto dim = k - 2; + TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); + auto weight_sizes = weight.sizes(); + auto input_sizes = input.sizes(); + TORCH_CHECK(k == input.dim(), + "Expected ", k, "-dimensional input for ", + k, "-dimensional weight", weight_sizes, ", but got ", + input.dim(), "-dimensional input of size ", + input.sizes(), " instead"); + TORCH_CHECK(stride.size() == dim || stride.size() == 1, + "stride cannot broadcast to ", dim, " dimensions"); + TORCH_CHECK(dilation.size() == dim || dilation.size() == 1, + "dilation cannot broadcast to ", dim, " dimensions"); + for (int64_t i = 0; i < stride.size(); ++i) { + TORCH_CHECK(stride[i] == 1, "padding='same' is not supported for strided convolutions"); + } + + // Calculate the correct padding + DimVector padding_l, padding_r; + bool symmetric_padding = true; + for (int64_t i = 0; i < dim; ++i) { + auto s = stride.size() == 1 ? stride[0] : stride[i]; + auto d = dilation.size() == 1 ? dilation[0] : dilation[i]; + auto pad = pooling_same_mode_padding_lr( + input_sizes[i + 2], weight_sizes[i + 2], s, d); + padding_l.push_back(pad.first); + padding_r.push_back(pad.second); + if (pad.first != pad.second) { + symmetric_padding = false; + } + } + + if (symmetric_padding) { + // All backends handle symmetric padding natively + DimVector output_padding(static_cast(dim)); + return native::convolution(input, weight, bias, stride, padding_l, dilation, + false, output_padding, groups); + } + + TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may" + " require a zero-padded copy of the input be created"); + SmallVector pad_nd(static_cast(2 * dim)); + for (int i = 0; i < dim; ++i) { + // Apply padding by the difference, leaving only a symmetric padding + auto delta_pad = padding_r[i] - padding_l[i]; + auto pad_idx = 2 * (dim - 1 - i); // F.pad goes from last dim to first + if (delta_pad > 0) { + pad_nd[pad_idx + 1] = delta_pad; + } else { + pad_nd[pad_idx] = delta_pad; + padding_l[i] = padding_r[i]; + } + } + auto padded_input = at::constant_pad_nd(input, pad_nd, 0); + DimVector output_padding(static_cast(dim)); + return at::convolution(padded_input, weight, bias, stride, padding_l, + dilation, false, output_padding, groups); +} + +Tensor _convolution_mode( + const Tensor& input, const Tensor& weight, const Tensor& bias, + IntArrayRef stride, std::string padding, IntArrayRef dilation, + int64_t groups) { + if (padding == "same") { + return at::native::convolution_same( + input, weight, bias, stride, dilation, groups); + } else if (padding == "valid") { + const int64_t padding_[] = {0}; + return at::native::convolution( + input, weight, bias, stride, padding_, dilation, false, padding_, groups); + } + TORCH_CHECK(false, "Invalid padding string: '", padding, "'"); +} + +at::Tensor conv1d( + const Tensor& input, const Tensor& weight, const c10::optional& bias, + IntArrayRef stride, std::string padding, IntArrayRef dilation, + int64_t groups) { + 
return at::_convolution_mode( + input, weight, bias, stride, std::move(padding), dilation, groups); +} + +at::Tensor conv2d( + const Tensor& input, const Tensor& weight, const c10::optional& bias, + IntArrayRef stride, std::string padding, IntArrayRef dilation, + int64_t groups) { + return at::_convolution_mode( + input, weight, bias, stride, std::move(padding), dilation, groups); +} + +at::Tensor conv3d( + const Tensor& input, const Tensor& weight, const c10::optional& bias, + IntArrayRef stride, std::string padding, IntArrayRef dilation, + int64_t groups) { + return at::_convolution_mode( + input, weight, bias, stride, std::move(padding), dilation, groups); +} + at::Tensor conv_transpose1d( const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) { diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 3e06bb8fa4d88..686d8e7f3b666 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -46,6 +46,24 @@ static inline T pooling_output_shape( inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode); } +inline std::pair pooling_same_mode_padding_lr( + int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) { + // NOTE: with strides, the output shape is ceil(inputSize/stride) + auto total_padding = dilation * (kernelSize - 1); + + // Prefer symmetric padding if possible + if (stride > 2 && (total_padding % 2 == 1)) { + // The floor in the output size calculation gives us a little wiggle room + auto wiggle_room = inputSize % stride - 1; + if (wiggle_room > 0) { + --total_padding; + } + } + + auto left = total_padding / 2; + return {left, total_padding - left}; +} + // AveragePool2d/DilatedMaxPool2d (forward) static inline void diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4bafd4ca43972..9bd07087f4dc3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1025,6 +1025,9 @@ - func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor use_c10_dispatcher: hacky_wrapper_for_legacy_signatures +- func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[] dilation, int groups) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + - func: _convolution_nogroup(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding) -> Tensor use_c10_dispatcher: hacky_wrapper_for_legacy_signatures @@ -1040,6 +1043,15 @@ - func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor use_c10_dispatcher: hacky_wrapper_for_legacy_signatures +- func: conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, str padding="valid", int[1] dilation=1, int groups=1) -> Tensor + cpp_no_default_args: ['bias', 'stride', 'padding'] + +- func: conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, str padding="valid", int[2] dilation=1, int groups=1) -> Tensor + cpp_no_default_args: ['bias', 'stride', 'padding'] + +- func: conv3d.padding(Tensor input, Tensor weight, Tensor? 
bias=None, int[3] stride=1, str padding="valid", int[3] dilation=1, int groups=1) -> Tensor + cpp_no_default_args: ['bias', 'stride', 'padding'] + - func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor dispatch: DefaultBackend: conv_tbc diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index 09b673d588fd4..da2cecf9ea178 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -25,6 +25,27 @@ std::tuple std_mean(const Tensor& self, int dim) { return at::std_mean(self, IntArrayRef{dim}); } +at::Tensor conv1d( + const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef stride, + std::initializer_list padding_, IntArrayRef dilation, int64_t groups) { + auto padding = IntArrayRef(padding_); + return at::conv1d(input, weight, bias, stride, padding, dilation, groups); +} + +at::Tensor conv2d( + const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef stride, + std::initializer_list padding_, IntArrayRef dilation, int64_t groups) { + auto padding = IntArrayRef(padding_); + return at::conv2d(input, weight, bias, stride, padding, dilation, groups); +} + +at::Tensor conv3d( + const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef stride, + std::initializer_list padding_, IntArrayRef dilation, int64_t groups) { + auto padding = IntArrayRef(padding_); + return at::conv3d(input, weight, bias, stride, padding, dilation, groups); +} + ${function_definitions} } diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h index 0c82fdf11afde..5d8f263eac453 100644 --- a/aten/src/ATen/templates/Functions.h +++ b/aten/src/ATen/templates/Functions.h @@ -51,6 +51,19 @@ TORCH_API std::tuple var_mean(const Tensor& self, int dim); TORCH_API Tensor std(const Tensor& self, int dim); TORCH_API std::tuple std_mean(const Tensor& self, int dim); + +// Special C++ only overloads for convnd functions (See gh-45667) +// These are needed because {1, 2} is ambiguous between string and IntArrayRef overloads +TORCH_API at::Tensor conv1d( + const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef stride, + std::initializer_list padding, IntArrayRef dilation = 1, int64_t groups = 1); +TORCH_API at::Tensor conv2d( + const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef stride, + std::initializer_list padding, IntArrayRef dilation = 1, int64_t groups = 1); +TORCH_API at::Tensor conv3d( + const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef stride, + std::initializer_list padding, IntArrayRef dilation = 1, int64_t groups = 1); + namespace { inline std::vector zero_sizes(const TensorOptions& options) { if (options.has_memory_format()) { diff --git a/c10/util/overloaded.h b/c10/util/overloaded.h new file mode 100644 index 0000000000000..fb6bf19ec233b --- /dev/null +++ b/c10/util/overloaded.h @@ -0,0 +1,30 @@ +#pragma once + +namespace c10 { +namespace detail { + +template +struct overloaded_t {}; + +template +struct overloaded_t:T0 { + using T0::operator(); + overloaded_t(T0 t0):T0(std::move(t0)) {} +}; +template +struct overloaded_t:T0, overloaded_t { + using T0::operator(); + using overloaded_t::operator(); + overloaded_t(T0 t0, Ts... ts): + T0(std::move(t0)), + overloaded_t(std::move(ts)...) + {} +}; + +} // namespace detail + +// Construct an overloaded callable combining multiple callables, e.g. 
lambdas +template +detail::overloaded_t overloaded(Ts...ts){ return {std::move(ts)...}; } + +} // namespace c10 diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index c9027cb77ffc7..fc2fd9d5563fa 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -55,6 +55,15 @@ TEST_F(ModulesTest, Conv1d) { ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3); } +TEST_F(ModulesTest, Conv1dSameStrided) { + auto options = Conv1dOptions(3, 2, 3); + options.stride(1).padding(torch::kSame); + Conv1d model_valid(options); + ASSERT_THROWS_WITH( + [&]{ Conv1d model_invalid(options.stride(2)); }(), + "padding='same' is not supported for strided convolutions"); +} + TEST_F(ModulesTest, Conv2dEven) { Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false)); model->weight.set_data(torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3})); @@ -95,6 +104,18 @@ TEST_F(ModulesTest, Conv2dUneven) { ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2); } +TEST_F(ModulesTest, Conv2dSameStrided) { + auto options = Conv2dOptions(3, 2, {3, 4}); + options.stride(1).padding(torch::kSame); + Conv2d model_valid(options); + ASSERT_THROWS_WITH( + [&]{ Conv2d model_invalid(options.stride(2)); }(), + "padding='same' is not supported for strided convolutions"); + ASSERT_THROWS_WITH( + [&]{ Conv2d model_invalid(options.stride({1, 2})); }(), + "padding='same' is not supported for strided convolutions"); +} + TEST_F(ModulesTest, Conv3d) { Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false)); model->weight.set_data(torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3})); @@ -131,6 +152,18 @@ TEST_F(ModulesTest, Conv3d) { ASSERT_TRUE(model->weight.grad().numel() == 3 * 2 * 3 * 3 * 3); } +TEST_F(ModulesTest, Conv3dSameStrided) { + auto options = Conv3dOptions(3, 2, {3, 4, 5}); + options.stride(1).padding(torch::kSame); + Conv3d model_valid(options); + ASSERT_THROWS_WITH( + [&]{ Conv3d model_invalid(options.stride(2)); }(), + "padding='same' is not supported for strided convolutions"); + ASSERT_THROWS_WITH( + [&]{ Conv3d model_invalid(options.stride({1, 2, 1})); }(), + "padding='same' is not supported for strided convolutions"); +} + TEST_F(ModulesTest, ConvTranspose1d) { ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)); model->weight.set_data(torch::arange(18.).view({2, 3, 3})); diff --git a/test/test_nn.py b/test/test_nn.py index 4c9fd662ee684..dfeeac1f54193 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -888,6 +888,108 @@ def test_invalid_conv3d(self): with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): module(input) + def test_Conv1d_module_same_padding(self): + # Compare module against functional: without strides/dilation, asymmetric padding + x = torch.rand(1, 1, 20) + module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, + padding='same') + expect = F.conv1d(x, module.weight, module.bias, padding='same') + self.assertEqual(expect, module(x)) + + # Test dilation, symmetric padding + module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, + padding='same', dilation=2) + expect = F.conv1d(x, module.weight, module.bias, padding='same', dilation=2) + self.assertEqual(expect, module(x)) + + # Test non-zero padding_mode, requiring explicit padding + module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, + padding='same', padding_mode='replicate') + x_padded = F.pad(x, [4, 5], mode='replicate') + expect = F.conv1d(x_padded, module.weight, module.bias, padding='valid') + 
self.assertEqual(expect, module(x)) + self.assertEqual(x.size(), expect.size()) + + # Test construction with invalid padding string raises + with self.assertRaisesRegex(ValueError, 'Invalid padding string'): + module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') + + # Test construction with same padding and strides raises + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) + + def test_Conv2d_module_same_padding(self): + # Compare module against functional: + # without strides/dilation, both symmetric and asymmetric padding + x = torch.rand(1, 1, 9, 20) + module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(5, 10), + padding='same') + expect = F.conv2d(x, module.weight, module.bias, padding='same') + self.assertEqual(expect, module(x)) + + # with dilation, symmetric padding + module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), + padding='same', dilation=(1, 2)) + expect = F.conv2d(x, module.weight, module.bias, padding='same', dilation=(1, 2)) + self.assertEqual(expect, module(x)) + + # Test non-zero padding_mode, requiring explicit padding + module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), + padding='same', padding_mode='reflect') + x_padded = F.pad(x, [1, 2, 1, 1], mode='reflect') + expect = F.conv2d(x_padded, module.weight, module.bias, padding='valid') + self.assertEqual(expect, module(x)) + self.assertEqual(x.size(), expect.size()) + + # Test construction with invalid padding string raises + with self.assertRaisesRegex(ValueError, 'Invalid padding string'): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') + + # Test construction with same padding and strides raises + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 3)) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(4, 1)) + + def test_Conv3d_module_same_padding(self): + # Compare module against functional: + x = torch.rand(1, 1, 4, 4, 4) + # without dilation, both symmetric and asymmetric padding + module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), + padding='same') + expect = F.conv3d(x, module.weight, module.bias, padding='same') + self.assertEqual(expect, module(x)) + + # with dilation, both symmetric and asymmetric padding + module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), + padding='same', dilation=(3, 2, 1)) + expect = F.conv3d(x, module.weight, module.bias, padding='same', dilation=(3, 2, 1)) + self.assertEqual(expect, module(x)) + + # Test non-zero padding_mode, requiring explicit padding + module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), + padding='same', padding_mode='circular') + x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode='circular') + expect = F.conv3d(x_padded, module.weight, module.bias, padding='valid') + self.assertEqual(expect, module(x)) + self.assertEqual(x.size(), expect.size()) + + # Test construction with invalid padding string raises + with self.assertRaisesRegex(ValueError, 'Invalid padding string'): + module = nn.Conv3d(in_channels=3, out_channels=33, 
kernel_size=10, padding='foo') + + # Test construction with same padding and strides raises + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 1, 3)) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 4, 1)) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(5, 1, 1)) + def _test_alpha_dropout(self, cls, input): mean = input.mean() std = input.std() @@ -11407,6 +11509,235 @@ def test_affine_3d_rotateRandom(self, device): self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) + def test_conv1d_same_padding(self, device): + # Test padding='same' outputs the correct shape + test_args = [ + # in_size + range(50, 55), + # kernel_size + [1, 2, 3, 8], + # dilation + range(1, 4), + # stride + [1], + ] + for in_size, k_size, dilation, stride in itertools.product(*test_args): + x = torch.rand(1, 1, in_size, device=device) + y = torch.rand(1, 1, k_size, device=device) + z = F.conv1d(x, y, padding='same', dilation=dilation, stride=stride) + self.assertEqual(z.size(2), int(math.ceil(in_size / stride))) + + # Compare F.conv1d padding='same' output against manual padding + # Without strides/dilation + x = torch.rand(1, 1, 12, device=device) + y = torch.rand(1, 1, 3, device=device) + expect = F.conv1d(x, y, padding=1) + actual = F.conv1d(x, y, padding='same') + self.assertEqual(expect, actual) + + # With dilation + x = torch.rand(1, 1, 12, device=device) + y = torch.rand(1, 1, 4, device=device) + expect = F.conv1d(x, y, padding=3, dilation=2) + actual = F.conv1d(x, y, padding='same', dilation=2) + self.assertEqual(expect, actual) + + # Dilation with asymmetric padding + expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:] + actual = F.conv1d(x, y, padding='same', dilation=3) + self.assertEqual(expect, actual) + + + def test_conv2d_same_padding(self, device): + # Compare F.conv2d padding='same' output against manual padding + # Without strides/dilation + x = torch.rand(1, 1, 10, 11, device=device) + y = torch.rand(1, 1, 4, 5, device=device) + expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :] + actual = F.conv2d(x, y, padding='same') + self.assertEqual(expect, actual) + + # With dilation + y = torch.rand(1, 1, 3, 4, device=device) + expect = F.conv2d(x, y, padding=(2, 3), dilation=2) + actual = F.conv2d(x, y, padding='same', dilation=2) + self.assertEqual(expect, actual) + + # Dilation with asymmetric padding + y = torch.rand(1, 1, 4, 4, device=device) + expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:] + actual = F.conv2d(x, y, padding='same', dilation=3) + self.assertEqual(expect, actual) + + def test_conv3d_same_padding(self, device): + # Compare F.conv3d padding='same' output against manual padding + # Without strides/dilation + x = torch.rand(1, 1, 10, 11, 12, device=device) + y = torch.rand(1, 1, 1, 2, 5, device=device) + expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :] + actual = F.conv3d(x, y, padding='same') + self.assertEqual(expect, actual) + + # With dilation + expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) + actual = F.conv3d(x, y, padding='same', dilation=2) + self.assertEqual(expect, 
actual) + + # Dilation with asymmetric padding + y = torch.rand(1, 1, 4, 4, 4, device=device) + expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:] + actual = F.conv3d(x, y, padding='same', dilation=3) + self.assertEqual(expect, actual) + + def test_conv1d_valid_padding(self, device): + # Test F.conv1d padding='valid' is the same as no padding + x = torch.rand(1, 1, 10, device=device) + y = torch.rand(1, 1, 4, device=device) + expect = F.conv1d(x, y) + actual = F.conv1d(x, y, padding='valid') + self.assertEqual(expect, actual) + + def test_conv2d_valid_padding(self, device): + # Test F.conv2d padding='valid' is the same as no padding + x = torch.rand(1, 1, 1, 10, device=device) + y = torch.rand(1, 1, 1, 4, device=device) + expect = F.conv2d(x, y) + actual = F.conv2d(x, y, padding='valid') + self.assertEqual(expect, actual) + + def test_conv3d_valid_padding(self, device): + # Test F.conv3d padding='valid' is the same as no padding + x = torch.rand(1, 1, 1, 1, 10, device=device) + y = torch.rand(1, 1, 1, 1, 4, device=device) + expect = F.conv3d(x, y) + actual = F.conv3d(x, y, padding='valid') + self.assertEqual(expect, actual) + + def test_conv1d_same_padding_backward(self, device): + # Test F.conv1d gradients work with padding='same' + x = torch.rand(1, 1, 12, device=device, requires_grad=True) + y = torch.rand(1, 1, 4, device=device, requires_grad=True) + + # Symmetric padding + z = F.conv1d(x, y, padding=3, dilation=2) + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv1d(x, y, padding='same', dilation=2) + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + x.grad, y.grad = None, None + + # Asymmetric padding + z = F.conv1d(x, y, padding=2)[..., 1:] + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv1d(x, y, padding='same') + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + + def test_conv2d_same_padding_backward(self, device): + # Test F.conv2d gradients work with padding='same' + x = torch.rand(1, 1, 10, 11, device=device, requires_grad=True) + y = torch.rand(1, 1, 4, 5, device=device, requires_grad=True) + + # Symmetric padding + z = F.conv2d(x, y, padding=(3, 4), dilation=2) + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv2d(x, y, padding='same', dilation=2) + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + x.grad, y.grad = None, None + + # Asymmetric padding + y = torch.rand(1, 1, 4, 4, device=device, requires_grad=True) + z = F.conv2d(x, y, padding=2)[..., 1:, 1:] + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv2d(x, y, padding='same') + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + + def test_conv3d_same_padding_backward(self, device): + # Test F.conv3d gradients work with padding='same' + x = torch.rand(1, 1, 1, 11, 12, device=device, requires_grad=True) + y = torch.rand(1, 1, 1, 2, 5, device=device, requires_grad=True) + + # Symmetric padding + z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv3d(x, y, padding='same', dilation=2) + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + x.grad, y.grad = None, None + 
+ # Asymmetric padding + y = torch.rand(1, 1, 1, 4, 4, device=device, requires_grad=True) + z = F.conv3d(x, y, padding=2)[..., 1:, 1:] + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv3d(x, y, padding='same') + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + + def test_conv1d_valid_padding_backward(self, device): + # Test F.conv1d gradients work with padding='valid' + x = torch.rand(1, 1, 10, device=device, requires_grad=True) + y = torch.rand(1, 1, 4, device=device, requires_grad=True) + F.conv1d(x, y, padding=0).sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + F.conv1d(x, y, padding='valid').sum().backward() + gx_actual, gy_actual = x.grad, y.grad + self.assertEqual(gx_expect, gx_actual) + self.assertEqual(gy_expect, gy_actual) + + def test_conv2d_valid_padding_backward(self, device): + # Test F.conv2d gradients work with padding='valid' + x = torch.rand(1, 1, 1, 10, device=device, requires_grad=True) + y = torch.rand(1, 1, 1, 4, device=device, requires_grad=True) + F.conv2d(x, y, padding=0).sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + F.conv2d(x, y, padding='valid').sum().backward() + gx_actual, gy_actual = x.grad, y.grad + self.assertEqual(gx_expect, gx_actual) + self.assertEqual(gy_expect, gy_actual) + + def test_conv3d_valid_padding_backward(self, device): + # Test F.conv3d gradients work with padding='valid' + x = torch.rand(1, 1, 1, 1, 10, device=device, requires_grad=True) + y = torch.rand(1, 1, 1, 1, 4, device=device, requires_grad=True) + F.conv3d(x, y, padding=0).sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + F.conv3d(x, y, padding='valid').sum().backward() + gx_actual, gy_actual = x.grad, y.grad + self.assertEqual(gx_expect, gx_actual) + self.assertEqual(gy_expect, gy_actual) + def test_Dropout(self, device): input = torch.empty(1000) self._test_dropout(nn.Dropout, device, input) diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h index 098d52bd16f46..7e662fc83b442 100644 --- a/torch/csrc/api/include/torch/enum.h +++ b/torch/csrc/api/include/torch/enum.h @@ -128,6 +128,8 @@ TORCH_ENUM_DECLARE(RNN_TANH) TORCH_ENUM_DECLARE(RNN_RELU) TORCH_ENUM_DECLARE(LSTM) TORCH_ENUM_DECLARE(GRU) +TORCH_ENUM_DECLARE(Valid) +TORCH_ENUM_DECLARE(Same) namespace torch { namespace enumtype { @@ -169,6 +171,8 @@ struct _compute_enum_name { TORCH_ENUM_PRETTY_PRINT(RNN_RELU) TORCH_ENUM_PRETTY_PRINT(LSTM) TORCH_ENUM_PRETTY_PRINT(GRU) + TORCH_ENUM_PRETTY_PRINT(Valid) + TORCH_ENUM_PRETTY_PRINT(Same) }; template diff --git a/torch/csrc/api/include/torch/nn/functional/conv.h b/torch/csrc/api/include/torch/nn/functional/conv.h index b9d6e02c27f8b..60637a8daac54 100644 --- a/torch/csrc/api/include/torch/nn/functional/conv.h +++ b/torch/csrc/api/include/torch/nn/functional/conv.h @@ -9,22 +9,39 @@ namespace functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { + +inline std::string padding_unwrap(enumtype::kValid) { + return "valid"; +} + +inline std::string padding_unwrap(enumtype::kSame) { + return "same"; +} + +template +IntArrayRef padding_unwrap(const ExpandingArray& array) { + return array; +} + + inline Tensor conv1d( const Tensor& input, const Tensor& weight, const Tensor& bias, ExpandingArray<1> stride, - ExpandingArray<1> padding, + const Conv1dFuncOptions::padding_t& padding, ExpandingArray<1> dilation, int64_t groups) { 
- return torch::conv1d( - input, - weight, - bias, - stride, - padding, - dilation, - groups); + return c10::visit([&](const auto & pad) { + return torch::conv1d( + input, + weight, + bias, + stride, + padding_unwrap(pad), + dilation, + groups); + }, padding); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -61,17 +78,19 @@ inline Tensor conv2d( const Tensor& weight, const Tensor& bias, ExpandingArray<2> stride, - ExpandingArray<2> padding, + const Conv2dFuncOptions::padding_t& padding, ExpandingArray<2> dilation, int64_t groups) { - return torch::conv2d( - input, - weight, - bias, - stride, - padding, - dilation, - groups); + return c10::visit([&](const auto & pad) { + return torch::conv2d( + input, + weight, + bias, + stride, + padding_unwrap(pad), + dilation, + groups); + }, padding); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -108,17 +127,19 @@ inline Tensor conv3d( const Tensor& weight, const Tensor& bias, ExpandingArray<3> stride, - ExpandingArray<3> padding, + const Conv3dFuncOptions::padding_t& padding, ExpandingArray<3> dilation, int64_t groups) { - return torch::conv3d( - input, - weight, - bias, - stride, - padding, - dilation, - groups); + return c10::visit([&](const auto & pad) { + return torch::conv3d( + input, + weight, + bias, + stride, + padding_unwrap(pad), + dilation, + groups); + }, padding); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index 75a2f1eb96bd9..929fcef9b4e6e 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -33,7 +35,33 @@ class ConvNdImpl : public torch::nn::Cloneable { options.out_channels() % options.groups() == 0, "out_channels must be divisible by groups"); - _reversed_padding_repeated_twice = torch::nn::modules::utils::_reverse_repeat_vector(options.padding(), 2); + c10::visit(c10::overloaded( + [&](enumtype::kValid) { + _reversed_padding_repeated_twice.resize(2 * D); + std::fill_n(_reversed_padding_repeated_twice.begin(), 2 * D, 0); + }, + [&](enumtype::kSame) { + for (int64_t i = 0; i < D; ++i) { + const auto stride = (*options.stride())[i]; + TORCH_CHECK(stride == 1, "padding='same' is not supported for strided convolutions"); + } + + _reversed_padding_repeated_twice.resize(2 * D); + for (int64_t i = 0; i < D; ++i) { + const auto dilation = (*options.dilation())[i]; + const auto kernel_size = (*options.kernel_size())[i]; + const auto total_padding = dilation * (kernel_size - 1); + auto left_pad = total_padding / 2; + auto right_pad = total_padding - left_pad; + _reversed_padding_repeated_twice[2 * i] = left_pad; + _reversed_padding_repeated_twice[2 * i + 1] = right_pad; + } + }, + [&](const ExpandingArray & pad) { + _reversed_padding_repeated_twice = + torch::nn::modules::utils::_reverse_repeat_vector(pad, 2); + }), + options.padding()); if (options.transposed()) { std::vector weight_sizes = { @@ -80,9 +108,19 @@ class ConvNdImpl : public torch::nn::Cloneable { << ", " << options.out_channels() << ", kernel_size=" << options.kernel_size() << ", stride=" << options.stride(); - if (*options.padding() != *ExpandingArray(0)) { - stream << ", padding=" << options.padding(); - } + c10::visit(c10::overloaded( + [&](enumtype::kValid) { + stream << ", padding='valid'"; + }, + [&](enumtype::kSame) { + stream << ", padding='same'"; + }, + [&](const ExpandingArray & 
pad) { + if (*pad != *ExpandingArray(0)) { + stream << ", padding=" << pad; + } + }), + options.padding()); if (*options.dilation() != *ExpandingArray(1)) { stream << ", dilation=" << options.dilation(); } @@ -220,6 +258,11 @@ template class ConvTransposeNdImpl : public ConvNdImpl { public: using torch::nn::ConvNdImpl::ConvNdImpl; + explicit ConvTransposeNdImpl(detail::ConvNdOptions options_) + : ConvNdImpl(options_) { + TORCH_INTERNAL_ASSERT(c10::holds_alternative>(this->options.padding()), + "ConvTranspose padding cannot be a string"); + } /// Pretty prints the `ConvTranspose{1,2,3}d` module into the given `stream`. void pretty_print(std::ostream& stream) const override { @@ -228,8 +271,9 @@ class ConvTransposeNdImpl : public ConvNdImpl { << ", " << this->options.out_channels() << ", kernel_size=" << this->options.kernel_size() << ", stride=" << this->options.stride(); - if (*this->options.padding() != *ExpandingArray(0)) { - stream << ", padding=" << this->options.padding(); + const auto & pad = padding(); + if (*pad != *ExpandingArray(0)) { + stream << ", padding=" << pad; } if (*this->options.dilation() != *ExpandingArray(1)) { stream << ", dilation=" << this->options.dilation(); @@ -250,6 +294,10 @@ class ConvTransposeNdImpl : public ConvNdImpl { } protected: + const ExpandingArray& padding() const { + return c10::get>(this->options.padding()); + } + std::vector _output_padding( const Tensor& input, const c10::optional& output_size, const ExpandingArray& stride, const ExpandingArray& padding, diff --git a/torch/csrc/api/include/torch/nn/options/conv.h b/torch/csrc/api/include/torch/nn/options/conv.h index 20e45c5fb1af8..86cdcb7ee71f6 100644 --- a/torch/csrc/api/include/torch/nn/options/conv.h +++ b/torch/csrc/api/include/torch/nn/options/conv.h @@ -18,9 +18,16 @@ typedef c10::variant< enumtype::kCircular > conv_padding_mode_t; +template +using conv_padding_t = c10::variant< + ExpandingArray, + enumtype::kValid, + enumtype::kSame>; + /// Options for a `D`-dimensional convolution or convolution transpose module. template struct ConvNdOptions { + using padding_t = conv_padding_t; ConvNdOptions( int64_t in_channels, int64_t out_channels, @@ -53,7 +60,12 @@ struct ConvNdOptions { /// For a `D`-dim convolution, must be a single number or a list of `D` /// numbers. /// This parameter __can__ be changed after construction. - TORCH_ARG(ExpandingArray, padding) = 0; + TORCH_ARG(padding_t, padding) = 0; + +public: + decltype(auto) padding(std::initializer_list il) { + return padding(IntArrayRef{il}); + } /// The kernel dilation. /// For a `D`-dim convolution, must be a single number or a list of `D` @@ -92,6 +104,7 @@ struct ConvNdOptions { template struct ConvOptions { using padding_mode_t = detail::conv_padding_mode_t; + using padding_t = detail::conv_padding_t; ConvOptions( int64_t in_channels, @@ -125,7 +138,12 @@ struct ConvOptions { /// For a `D`-dim convolution, must be a single number or a list of `D` /// numbers. /// This parameter __can__ be changed after construction. - TORCH_ARG(ExpandingArray, padding) = 0; + TORCH_ARG(padding_t, padding) = 0; + +public: + decltype(auto) padding(std::initializer_list il) { + return padding(IntArrayRef{il}); + } /// The kernel dilation. /// For a `D`-dim convolution, must be a single number or a list of `D` @@ -176,6 +194,8 @@ namespace functional { /// Options for a `D`-dimensional convolution functional. template struct ConvFuncOptions { + using padding_t = torch::nn::detail::conv_padding_t; + /// optional bias of shape `(out_channels)`. 
Default: ``None`` TORCH_ARG(torch::Tensor, bias) = Tensor(); @@ -187,7 +207,12 @@ struct ConvFuncOptions { /// Implicit paddings on both sides of the input. /// For a `D`-dim convolution, must be a single number or a list of `D` /// numbers. - TORCH_ARG(ExpandingArray, padding) = 0; + TORCH_ARG(padding_t, padding) = 0; + +public: + decltype(auto) padding(std::initializer_list il) { + return padding(IntArrayRef{il}); + } /// The spacing between kernel elements. /// For a `D`-dim convolution, must be a single number or a list of `D` diff --git a/torch/csrc/api/src/enum.cpp b/torch/csrc/api/src/enum.cpp index 336f09ecd5e74..c1a434b35b85d 100644 --- a/torch/csrc/api/src/enum.cpp +++ b/torch/csrc/api/src/enum.cpp @@ -35,3 +35,5 @@ TORCH_ENUM_DEFINE(RNN_TANH) TORCH_ENUM_DEFINE(RNN_RELU) TORCH_ENUM_DEFINE(LSTM) TORCH_ENUM_DEFINE(GRU) +TORCH_ENUM_DEFINE(Valid) +TORCH_ENUM_DEFINE(Same) diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp index 1dcee6025d248..95b3b5e99d244 100644 --- a/torch/csrc/api/src/nn/modules/conv.cpp +++ b/torch/csrc/api/src/nn/modules/conv.cpp @@ -218,11 +218,12 @@ Tensor ConvTranspose1dImpl::forward( TORCH_CHECK(false, "Only `zeros` padding mode is supported for ConvTranspose1d"); } + const auto & pad = padding(); std::vector output_padding = _output_padding( - input, output_size, options.stride(), options.padding(), options.kernel_size()); + input, output_size, options.stride(), pad, options.kernel_size()); return F::detail::conv_transpose1d( - input, weight, bias, options.stride(), options.padding(), + input, weight, bias, options.stride(), pad, output_padding, options.groups(), options.dilation()); } @@ -247,11 +248,12 @@ Tensor ConvTranspose2dImpl::forward( TORCH_CHECK(false, "Only `zeros` padding mode is supported for ConvTranspose2d"); } + const auto & pad = padding(); std::vector output_padding = _output_padding( - input, output_size, options.stride(), options.padding(), options.kernel_size()); + input, output_size, options.stride(), pad, options.kernel_size()); return F::detail::conv_transpose2d( - input, weight, bias, options.stride(), options.padding(), + input, weight, bias, options.stride(), pad, output_padding, options.groups(), options.dilation()); } @@ -276,11 +278,12 @@ Tensor ConvTranspose3dImpl::forward( TORCH_CHECK(false, "Only `zeros` padding mode is supported for ConvTranspose3d"); } + const auto & pad = padding(); std::vector output_padding = _output_padding( - input, output_size, options.stride(), options.padding(), options.kernel_size()); + input, output_size, options.stride(), pad, options.kernel_size()); return F::detail::conv_transpose3d( - input, weight, bias, options.stride(), options.padding(), + input, weight, bias, options.stride(), pad, output_padding, options.groups(), options.dilation()); } diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index 4533c4c576210..83b348630388b 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -70,7 +70,7 @@ void nnc_aten_conv2d( int64_t groups = extra_args[6]; try { - r = at::native::conv2d( + r = at::conv2d( x, w, b, @@ -82,7 +82,7 @@ void nnc_aten_conv2d( } } else { try { - r = at::native::conv2d(x, w); + r = at::conv2d(x, w); } catch (...) 
{ } } diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 16d39274f04a0..92a5dcbbe06dc 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -45,8 +45,16 @@ bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None`` stride: the stride of the convolving kernel. Can be a single number or a one-element tuple `(sW,)`. Default: 1 - padding: implicit paddings on both sides of the input. Can be a + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, single number or a one-element tuple `(padW,)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally, lowering performance. dilation: the spacing between kernel elements. Can be a single number or a one-element tuple `(dW,)`. Default: 1 groups: split input into groups, :math:`\text{in\_channels}` should be divisible by @@ -78,14 +86,24 @@ **reproducibility_notes, **tf32_notes ) + r""" + Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None`` stride: the stride of the convolving kernel. Can be a single number or a tuple `(sH, sW)`. Default: 1 - padding: implicit paddings on both sides of the input. Can be a + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, single number or a tuple `(padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally, lowering performance. + dilation: the spacing between kernel elements. Can be a single number or a tuple `(dH, dW)`. Default: 1 groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the @@ -125,8 +143,17 @@ bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None stride: the stride of the convolving kernel. Can be a single number or a tuple `(sT, sH, sW)`. Default: 1 - padding: implicit paddings on both sides of the input. Can be a + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, single number or a tuple `(padT, padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally, lowering performance. + dilation: the spacing between kernel elements. Can be a single number or a tuple `(dT, dH, dW)`. 
Default: 1 groups: split input into groups, :math:`\text{in\_channels}` should be divisible by diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 748df94828a83..16bc4a492fa0f 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -13,7 +13,7 @@ from torch._torch_docs import reproducibility_notes from ..common_types import _size_1_t, _size_2_t, _size_3_t -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, Union convolution_notes = \ {"groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. @@ -51,10 +51,11 @@ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) - ... _in_channels: int + _reversed_padding_repeated_twice: List[int] out_channels: int kernel_size: Tuple[int, ...] stride: Tuple[int, ...] - padding: Tuple[int, ...] + padding: Union[str, Tuple[int, ...]] dilation: Tuple[int, ...] transposed: bool output_padding: Tuple[int, ...] @@ -80,6 +81,15 @@ def __init__(self, raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') + valid_padding_strings = {'same', 'valid'} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + "Invalid padding string {!r}, should be one of {}".format( + padding, valid_padding_strings)) + if padding == 'same' and any(s != 1 for s in stride): + raise ValueError("padding='same' is not supported for strided convolutions") + valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} if padding_mode not in valid_padding_modes: raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format( @@ -98,7 +108,19 @@ def __init__(self, # `F.pad` if needed (e.g., for non-zero padding types that are # implemented as two ops: padding + conv). `F.pad` accepts paddings in # reverse order than the dimension. - self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) + if padding == 'same': + for d, k, i in zip(dilation, kernel_size, + range(len(kernel_size) - 1, -1, -1)): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) + if transposed: self.weight = Parameter(torch.Tensor( in_channels, out_channels // groups, *kernel_size)) @@ -164,8 +186,9 @@ class Conv1d(_ConvNd): * :attr:`stride` controls the stride for the cross-correlation, a single number or a one-element tuple. - * :attr:`padding` controls the amount of implicit padding on both sides - for :attr:`padding` number of points. + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ @@ -178,12 +201,17 @@ class Conv1d(_ConvNd): Note: {cudnn_reproducibility_note} + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. 
+ Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 - padding (int or tuple, optional): Zero-padding added to both sides of + padding (int, tuple or str, optional): Padding added to both sides of the input. Default: 0 padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` @@ -235,7 +263,7 @@ def __init__( out_channels: int, kernel_size: _size_1_t, stride: _size_1_t = 1, - padding: _size_1_t = 0, + padding: Union[str, _size_1_t] = 0, dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, @@ -245,7 +273,7 @@ def __init__( # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] kernel_size_ = _single(kernel_size) stride_ = _single(stride) - padding_ = _single(padding) + padding_ = padding if isinstance(padding, str) else _single(padding) dilation_ = _single(dilation) super(Conv1d, self).__init__( in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, @@ -287,8 +315,9 @@ class Conv2d(_ConvNd): * :attr:`stride` controls the stride for the cross-correlation, a single number or a tuple. - * :attr:`padding` controls the amount of implicit padding on both - sides for :attr:`padding` number of points for each dimension. + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ @@ -308,12 +337,17 @@ class Conv2d(_ConvNd): Note: {cudnn_reproducibility_note} + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 - padding (int or tuple, optional): Zero-padding added to both sides of + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: 0 padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` @@ -373,7 +407,7 @@ def __init__( out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, - padding: _size_2_t = 0, + padding: Union[str, _size_2_t] = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, @@ -381,7 +415,7 @@ def __init__( ): kernel_size_ = _pair(kernel_size) stride_ = _pair(stride) - padding_ = _pair(padding) + padding_ = padding if isinstance(padding, str) else _pair(padding) dilation_ = _pair(dilation) super(Conv2d, self).__init__( in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, @@ -416,8 +450,9 @@ class Conv3d(_ConvNd): * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit padding on both - sides for :attr:`padding` number of points for each dimension. + * :attr:`padding` controls the amount of padding applied to the input. 
It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. @@ -436,12 +471,18 @@ class Conv3d(_ConvNd): Note: {cudnn_reproducibility_note} + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 - padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + padding (int, tuple or str, optional): Padding added to all six sides of + the input. Default: 0 padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 @@ -498,7 +539,7 @@ def __init__( out_channels: int, kernel_size: _size_3_t, stride: _size_3_t = 1, - padding: _size_3_t = 0, + padding: Union[str, _size_3_t] = 0, dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, @@ -506,7 +547,7 @@ def __init__( ): kernel_size_ = _triple(kernel_size) stride_ = _triple(stride) - padding_ = _triple(padding) + padding_ = padding if isinstance(padding, str) else _triple(padding) dilation_ = _triple(dilation) super(Conv3d, self).__init__( in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, @@ -701,6 +742,7 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten if self.padding_mode != 'zeros': raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d') + assert isinstance(self.padding, tuple) # One cannot replace List by Tuple or Sequence in "_output_padding" because # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. output_padding = self._output_padding( @@ -845,6 +887,7 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten if self.padding_mode != 'zeros': raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + assert isinstance(self.padding, tuple) # One cannot replace List by Tuple or Sequence in "_output_padding" because # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. output_padding = self._output_padding( @@ -986,6 +1029,7 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten if self.padding_mode != 'zeros': raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d') + assert isinstance(self.padding, tuple) # One cannot replace List by Tuple or Sequence in "_output_padding" because # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. 
output_padding = self._output_padding( diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 0b6616b905b3f..d6cef76dacc68 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -1780,6 +1780,42 @@ def fractional_max_pool3d_test(test_case): with_tf32=True, tf32_precision=0.005, ), + dict( + fullname='Conv1d_pad_valid', + constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv1d_pad_same', + constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv1d_pad_same2', + constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv1d_pad_same_dilated', + constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)', + input_size=(2, 4, 10), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), dict( fullname='ConvTranspose1d', constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)), @@ -1914,6 +1950,33 @@ def fractional_max_pool3d_test(test_case): check_with_long_tensor=True, with_tf32=True, ), + dict( + fullname='Conv2d_pad_valid', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv2d_pad_same', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv2d_pad_same_dilated', + constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)', + input_size=(2, 2, 6, 5), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), dict( module_name='ConvTranspose2d', constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)), @@ -2373,6 +2436,33 @@ def fractional_max_pool3d_test(test_case): with_tf32=True, tf32_precision=0.05 ), + dict( + fullname='Conv3d_pad_valid', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv3d_pad_same', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), + dict( + fullname='Conv3d_pad_same_dilated', + constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2), + cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 
4}).padding(torch::kSame).dilation(2)', + input_size=(2, 3, 6, 5, 4), + cudnn=True, + with_tf32=True, + tf32_precision=0.005, + ), dict( module_name='ConvTranspose3d', constructor_args=(2, 3, (2, 3, 2)),
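For reference only (not part of the patch): below is a minimal Python sketch of the padding arithmetic introduced above, mirroring pooling_same_mode_padding_lr and the asymmetric fallback in convolution_same for the stride-1 case that padding='same' supports. The helper name same_padding_lr is illustrative, and the snippet assumes a build that already contains these changes.

import torch
import torch.nn.functional as F

def same_padding_lr(kernel_size, dilation=1):
    # Total implicit padding needed so that, with stride 1, the output length
    # equals the input length. (The stride > 2 "wiggle room" branch in Pool.h
    # never applies here, since padding='same' rejects strided convolutions.)
    total = dilation * (kernel_size - 1)
    left = total // 2
    return left, total - left

x = torch.rand(1, 1, 12)

# Symmetric case: odd kernel length gives even total padding, so a plain
# integer padding reproduces padding='same'.
w = torch.rand(1, 1, 3)
left, right = same_padding_lr(3)   # (1, 1)
assert torch.allclose(F.conv1d(x, w, padding=left),
                      F.conv1d(x, w, padding='same'))

# Asymmetric case: even kernel length gives odd total padding. convolution_same
# then pads the input explicitly (the copy the TORCH_WARN_ONCE mentions) before
# running a symmetrically padded convolution.
w = torch.rand(1, 1, 4)
left, right = same_padding_lr(4)   # (1, 2)
x_padded = F.pad(x, [left, right])
assert torch.allclose(F.conv1d(x_padded, w),
                      F.conv1d(x, w, padding='same'))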