Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Enables No-batch for *Pad1d Modules #61060

Closed
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 0 additions & 5 deletions test/test_nn.py
Expand Up @@ -13060,11 +13060,6 @@ def test_ReplicationPad_empty(self, device, dtype):
(torch.nn.ReplicationPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
self._test_module_empty_input(mod, inp, check_size=False)

with self.assertRaisesRegex(NotImplementedError, 'Only 3D'):
mod = torch.nn.ReplicationPad1d(2)
inp = torch.randn(3, 10, device=device, dtype=dtype)
mod(inp)

with self.assertRaisesRegex(RuntimeError, 'Expected 2D or 3D'):
mod = torch.nn.ReplicationPad1d(2)
inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/api/include/torch/nn/functional/padding.h
Expand Up @@ -44,8 +44,7 @@ inline Tensor pad(const Tensor& input,
"Padding mode \"",
torch::enumtype::get_enum_name(mode),
"\" doesn't take in value argument");
if (input.dim() == 3) {
TORCH_CHECK(pad.size() == 2, "3D tensors expect 2 values for padding");
if (pad.size() == 2 and (input.dim() == 2 or input.dim() == 3)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a TORCH_CHECK call below in the final else branch that needs a slightly tweaked error message now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomasjpfan Whoops, I missed this before, but it looks like some python syntax snuck in here:

  • and -> &&
  • or -> ||

if (c10::get_if<enumtype::kReflect>(&mode)) {
return torch::reflection_pad1d(input, pad);
} else if (c10::get_if<enumtype::kReplicate>(&mode)) {
Expand Down
3 changes: 1 addition & 2 deletions torch/nn/functional.py
Expand Up @@ -4154,8 +4154,7 @@ def _pad(input: Tensor, pad: List[int], mode: str = "constant", value: float = 0
return _VF.constant_pad_nd(input, pad, value)
else:
assert value == 0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode)
if input.dim() == 3:
assert len(pad) == 2, "3D tensors expect 2 values for padding"
if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3):
jbschlosser marked this conversation as resolved.
Show resolved Hide resolved
if mode == "reflect":
return torch._C._nn.reflection_pad1d(input, pad)
elif mode == "replicate":
Expand Down
12 changes: 6 additions & 6 deletions torch/nn/modules/padding.py
Expand Up @@ -39,8 +39,8 @@ class ConstantPad1d(_ConstantPadNd):
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

Shape:
- Input: :math:`(N, C, W_{in})`
- Output: :math:`(N, C, W_{out})` where
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where

:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

Expand Down Expand Up @@ -189,8 +189,8 @@ class ReflectionPad1d(_ReflectionPadNd):
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

Shape:
- Input: :math:`(N, C, W_{in})`
- Output: :math:`(N, C, W_{out})` where
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where

:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

Expand Down Expand Up @@ -350,8 +350,8 @@ class ReplicationPad1d(_ReplicationPadNd):
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

Shape:
- Input: :math:`(N, C, W_{in})`
- Output: :math:`(N, C, W_{out})` where
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where

:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

Expand Down
34 changes: 34 additions & 0 deletions torch/testing/_internal/common_nn.py
Expand Up @@ -1247,6 +1247,16 @@ def fractional_max_pool3d_test(test_case):
cpp_var_map={'random_samples': random_samples},
fullname='FractionalMaxPool3d_asymsize')

def single_batch_reference_fn(input, parameters, module):
    """Reference function for modules supporting no batch dimensions.

    Wraps the unbatched ``input`` in a size-1 batch dimension, runs
    ``module`` on it, and squeezes the batch dimension back out so the
    result is directly comparable with the module's no-batch output.

    ``parameters`` is unused here; it is accepted so this function matches
    the ``reference_fn(input, parameters, module)`` signature expected by
    the module-test harness.
    """
    # unsqueeze(0) adds a leading batch dimension of size 1.
    single_batch_input = input.unsqueeze(0)
    # Freeze the RNG so stochastic modules (e.g. dropout) produce the same
    # values here as in the corresponding no-batch forward pass.
    with freeze_rng_state():
        # squeeze(0) drops the batch dim to match the no-batch input shape.
        return module(single_batch_input).squeeze(0)
jbschlosser marked this conversation as resolved.
Show resolved Hide resolved


new_module_tests = [
poissonnllloss_no_reduce_test(),
Expand Down Expand Up @@ -2192,6 +2202,14 @@ def fractional_max_pool3d_test(test_case):
cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
input_size=(2, 3, 8),
),
dict(
module_name='ReflectionPad1d',
constructor_args=((1, 2),),
cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
input_size=(3, 8),
reference_fn=single_batch_reference_fn,
desc='batch',
),
dict(
module_name='ReflectionPad1d',
constructor_args=((1, 2),),
Expand Down Expand Up @@ -2234,6 +2252,14 @@ def fractional_max_pool3d_test(test_case):
cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
input_size=(2, 3, 4),
),
dict(
module_name='ReplicationPad1d',
constructor_args=((1, 2),),
cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
input_size=(3, 4),
reference_fn=single_batch_reference_fn,
desc='batch',
),
dict(
module_name='ReplicationPad1d',
constructor_args=((1, 2),),
Expand Down Expand Up @@ -2283,6 +2309,14 @@ def fractional_max_pool3d_test(test_case):
cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
input_size=(2, 3, 4),
),
dict(
module_name='ConstantPad1d',
constructor_args=((1, 2), 2.),
cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
input_size=(3, 4),
reference_fn=single_batch_reference_fn,
desc='batch',
),
dict(
module_name='ConstantPad1d',
constructor_args=((1, 2), 2.),
Expand Down