[opinfo] nn.functional.pad #62814

Closed
18 changes: 0 additions & 18 deletions test/test_nn.py
@@ -13034,24 +13034,6 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps):
@onlyOnCPUAndCUDA
@dtypes(torch.float64, torch.complex128)
def test_pad(self, device, dtype):
inputs = torch.randn(1, 3, 4, 4, device=device, dtype=dtype, requires_grad=True)
_assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,),
nondet_tol=GRADCHECK_NONDET_TOL)
_assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,),
nondet_tol=GRADCHECK_NONDET_TOL)
_assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,),
nondet_tol=GRADCHECK_NONDET_TOL)
self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,),
nondet_tol=GRADCHECK_NONDET_TOL))
self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,),
nondet_tol=GRADCHECK_NONDET_TOL))
self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='circular'), (inputs,),
nondet_tol=GRADCHECK_NONDET_TOL))

inputs = torch.randn(1, 2, 3, 4, 4, device=device, dtype=dtype, requires_grad=True)
self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,),
nondet_tol=GRADCHECK_NONDET_TOL))

# Assert assertion errors are raised for invalid circular padding values
inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
# Should raise error when trying to wrap around more than once
113 changes: 113 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -2642,6 +2642,73 @@ def generator():
return list(generator())


def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
assert mode in ('constant', 'reflect', 'replicate', 'circular')
if dtype == torch.bool and mode == 'circular':
# test_dtypes fails on ASAN for the case below
# runtime error: load of value 190, which is not a valid value for type 'bool'
# Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
Contributor:

Is this a bug in PyTorch?

Collaborator Author:

I think so. Was planning to open an issue to investigate later. Thanks for reminding.
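
For reference, a minimal sketch of the kind of input/pad combination this refers to (shape, pad, and mode taken from the bool/circular samples listed further down; illustrative only, not necessarily a standalone reproduction of the ASAN report):

import torch
import torch.nn.functional as F

# bool input with circular padding, matching one of the circular-mode samples
x = torch.randint(0, 2, (2, 3, 3)).bool()
out = F.pad(x, (1, 2), mode='circular')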

shapes = ((1, 3), (2, 3, 3), (1, 3, 3), (3, 3, 3, 3), (3, 3, 5, 5))
pads = ((1, 2),)
negative_pad_case = ()
else:
# Supports 2-D, 3-D, 4-D, 5-D tensors
shapes = ((1, 3), (0, 3, 3), (1, 3, 3), (0, 3, 3, 3), (3, 3, 5, 5), (1, 3, 3, 3, 3)) # type: ignore[assignment]
pads = ((1, 2), (0, 1), (0, 2, 0, 1), (1, 1, 1, 1, 1, 1)) # type: ignore[assignment]
negative_pad_case = (
# (shape, pad)
((1, 3, 4, 4), (-1, 1, -2, 1)), # type: ignore[assignment]
)

make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

def generator():
if mode == 'constant':
# Default args
yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
Comment on lines +2757 to +2759

Contributor:

So for constant pad we only have a single test case? Maybe constant pad should be a separate sample_inputs_nn_pad_constant function.

Collaborator Author:

Actually it's more than that.

Ref: #62814 (comment)
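
For context, a rough sketch of what the reviewer's suggestion could look like (the helper name comes from the comment above; the wrapper form is purely illustrative and not part of this PR):

def sample_inputs_nn_pad_constant(op_info, device, dtype, requires_grad, **kwargs):
    # Hypothetical sketch: a dedicated entry point for constant mode, so the
    # constant samples would not go through the mode-specific filtering below.
    return sample_inputs_nn_pad(op_info, device, dtype, requires_grad,
                                mode='constant', **kwargs)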


for shape, pad in chain(product(shapes, pads), negative_pad_case):
# Not all combinations of shapes and pads are valid
# Below are the checks to skip invalid combinations

# Function requires len(pad)/2 <= len(shape)
if not (len(pad) // 2 <= len(shape)):
continue

if mode in ['reflect', 'replicate', 'circular']:
input_dim = len(shape)
# Valid pad length for given input dim
if len(pad) == 2 and not (input_dim in (2, 3)):
continue
if len(pad) == 4 and not (input_dim in (3, 4)):
continue
if len(pad) == 6 and not (input_dim in (4, 5)):
continue

# Expected XD or YD (batch mode) tensor with possibly 0 batch size
# and other non-zero dimensions for input
if len(pad) == 2 and input_dim == 2 and shape[0] == 0:
continue
if len(pad) == 4 and input_dim == 3 and shape[0] == 0:
continue
if len(pad) == 6 and input_dim == 4 and shape[0] == 0:
continue

if mode == 'circular':
if not (len(pad) == 2 * (input_dim - 2)):
continue

if mode in ['reflect', 'replicate', 'circular']:
yield SampleInput(make_inp(shape), args=(pad, mode))
else:
for pad_value in (1., 2.):
yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
Contributor:

The way this is set up makes it difficult to reason about exactly what the samples look like and how many there are. If there aren't too many samples, would it be better to just list them all? E.g. have a table of (shape, pad):

cases = [
    ((1, 3), (1, 2)),  # (shape, pad)
    ...
]

Collaborator Author:

I agree. It is hard to reason about. Will break it out as explicit cases. Thanks!

Below are the samples generated (these will be turned into cases):

constant

torch.Size([1, 3, 3]) ((2, 2),)
torch.Size([1, 3]) ((1, 2), 'constant', 1.0)
torch.Size([1, 3]) ((1, 2), 'constant', 2.0)
torch.Size([1, 3]) ((0, 1), 'constant', 1.0)
torch.Size([1, 3]) ((0, 1), 'constant', 2.0)
torch.Size([1, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([1, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([0, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([0, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([0, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([0, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([0, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([0, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([0, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([0, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([1, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([1, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([1, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([1, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([1, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([1, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((1, 2), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((1, 2), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((0, 1), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((0, 1), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'constant', 1.0)
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'constant', 2.0)

reflect

torch.Size([1, 3]) ((1, 2), 'reflect')
torch.Size([1, 3]) ((0, 1), 'reflect')
torch.Size([0, 3, 3]) ((1, 2), 'reflect')
torch.Size([0, 3, 3]) ((0, 1), 'reflect')
torch.Size([1, 3, 3]) ((1, 2), 'reflect')
torch.Size([1, 3, 3]) ((0, 1), 'reflect')
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'reflect')
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'reflect')
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'reflect')
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'reflect')
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'reflect')
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'reflect')

replicate

torch.Size([1, 3]) ((1, 2), 'replicate')
torch.Size([1, 3]) ((0, 1), 'replicate')
torch.Size([0, 3, 3]) ((1, 2), 'replicate')
torch.Size([0, 3, 3]) ((0, 1), 'replicate')
torch.Size([1, 3, 3]) ((1, 2), 'replicate')
torch.Size([1, 3, 3]) ((0, 1), 'replicate')
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'replicate')
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'replicate')
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'replicate')
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'replicate')
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'replicate')
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'replicate')

circular

torch.Size([0, 3, 3]) ((1, 2), 'circular')
torch.Size([0, 3, 3]) ((0, 1), 'circular')
torch.Size([1, 3, 3]) ((1, 2), 'circular')
torch.Size([1, 3, 3]) ((0, 1), 'circular')
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'circular')
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'circular')
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'circular')
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'circular')

circular bool

torch.Size([2, 3, 3]) ((1, 2), 'circular')
torch.Size([1, 3, 3]) ((1, 2), 'circular')
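
For illustration, a minimal sketch of how the listing above could be turned into explicit cases (the grouping, the generator_constant name, and the subset of entries shown are hypothetical; the sketch reuses make_inp and SampleInput from the function above):

# Hypothetical sketch: explicit (shape, pad) cases instead of the filtered
# shape/pad product. Entries are taken from the 'constant' listing above.
cases_constant = [
    ((1, 3), (1, 2)),
    ((0, 3, 3), (0, 2, 0, 1)),
    ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
    ((1, 3, 4, 4), (-1, 1, -2, 1)),  # negative pad case
]

def generator_constant():
    # Default-args sample, kept separate from the (shape, pad, value) cases
    yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
    for shape, pad in cases_constant:
        for pad_value in (1., 2.):
            yield SampleInput(make_inp(shape), args=(pad, 'constant', pad_value))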


samples = list(generator())
Contributor:

How many samples does this actually generate?

assert len(samples) > 0
return samples


# TODO: reconcile with torch.linalg.det and torch.linalg.slogdet
# Creates matrices with a positive nonzero determinant
def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs):
@@ -6573,6 +6640,52 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_nn_activation_relu,
supports_out=False),
OpInfo('nn.functional.pad',
variant_test_name='constant',
aten_name='constant_pad_nd',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'),
supports_out=False),
OpInfo('nn.functional.pad',
variant_test_name='reflect',
dtypes=floating_and_complex_types(),
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
skips=(
# op name not found in JIT graph
# There are multiple aten ops, namely reflection_pad_{1,2,3}d
# so we can't use aten_name argument in opinfo
# RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
supports_out=False),
OpInfo('nn.functional.pad',
variant_test_name='replicate',
dtypes=floating_and_complex_types(),
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
skips=(
# op name not found in JIT graph
# There are multiple aten ops, namely replication_pad_{1,2,3}d
# so we can't use aten_name argument in opinfo
# RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
supports_out=False),
OpInfo('nn.functional.pad',
variant_test_name='circular',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'),
supports_forward_ad=True,
check_batched_grad=False,
skips=(
# Doesn't have a corresponding aten operator.
# RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
),
supports_out=False),
OpInfo('nn.functional.hardswish',
aten_name="hardswish",
supports_autograd=True,