[opinfo] nn.functional.pad (#62814)
Summary:
Reference: pytorch/functorch#78

Pull Request resolved: #62814

Reviewed By: VitalyFedyunin

Differential Revision: D30307492

Pulled By: zou3519

fbshipit-source-id: 4f6062eb4a3c91ed1795df1f82846afa0abafcdc
kshitij12345 authored and alanwaketan committed Aug 17, 2021
1 parent 8a30c40 commit 85c3576
Showing 2 changed files with 130 additions and 18 deletions.
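For context, nn.functional.pad's four padding modes behave as follows. This is an illustrative sketch only, not part of the commit; shapes and values are arbitrary:

    import torch
    import torch.nn.functional as F

    x = torch.arange(9.).reshape(1, 1, 3, 3)

    # pad is given as (left, right, top, bottom) over the last two dimensions
    F.pad(x, (1, 1, 1, 1), mode='constant', value=0)  # zero border -> 1x1x5x5
    F.pad(x, (1, 1, 1, 1), mode='replicate')          # repeat the edge values
    F.pad(x, (1, 1, 1, 1), mode='reflect')            # mirror, excluding the edge
    F.pad(x, (1, 1, 1, 1), mode='circular')           # wrap around periodically
    F.pad(x, (-1, 1, -2, 1))                          # negative entries crop that side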
18 changes: 0 additions & 18 deletions test/test_nn.py
@@ -13066,24 +13066,6 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps):
     @onlyOnCPUAndCUDA
     @dtypes(torch.float64, torch.complex128)
     def test_pad(self, device, dtype):
-        inputs = torch.randn(1, 3, 4, 4, device=device, dtype=dtype, requires_grad=True)
-        _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,),
-                                     nondet_tol=GRADCHECK_NONDET_TOL)
-        _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,),
-                                     nondet_tol=GRADCHECK_NONDET_TOL)
-        _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,),
-                                     nondet_tol=GRADCHECK_NONDET_TOL)
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='circular'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-
-        inputs = torch.randn(1, 2, 3, 4, 4, device=device, dtype=dtype, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,),
-                                  nondet_tol=GRADCHECK_NONDET_TOL))
-
         # Assert assertion errors are raised for invalid circular padding values
         inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
         # Should raise error when trying to wrap around more than once
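Coverage for the deleted gradcheck/gradgradcheck assertions moves into the generic OpInfo-driven tests in test_ops.py. Roughly, each deleted assertion reduces to a call like the sketch below (simplified; the literal 1e-12 stands in for GRADCHECK_NONDET_TOL, which tolerates nondeterminism in some padding backward kernels):

    import torch
    import torch.nn.functional as F
    from torch.autograd import gradcheck, gradgradcheck

    inputs = torch.randn(1, 3, 4, 4, dtype=torch.float64, requires_grad=True)
    fn = lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate')

    # nondet_tol allows small nondeterministic differences between reruns
    assert gradcheck(fn, (inputs,), nondet_tol=1e-12)
    assert gradgradcheck(fn, (inputs,), nondet_tol=1e-12)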
130 changes: 130 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -2765,6 +2765,90 @@ def generator():
     return list(generator())


+def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
+    assert mode in ('constant', 'reflect', 'replicate', 'circular')
+    if mode in ['reflect', 'replicate']:
+        cases: tuple = (  # ignore
+            ((1, 3), (1, 2)),
+            ((1, 3), (0, 1)),
+            ((0, 3, 3), (1, 2)),
+            ((0, 3, 3), (0, 1)),
+            ((1, 3, 3), (1, 2)),
+            ((1, 3, 3), (0, 1)),
+            ((1, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3, 3), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 4, 4), (-1, 1, -2, 1)),
+        )
+    elif mode == 'constant':
+        cases = (
+            ((1, 3), (1, 2)),
+            ((1, 3), (0, 1)),
+            ((1, 3), (0, 2, 0, 1)),
+            ((0, 3, 3), (1, 2)),
+            ((0, 3, 3), (0, 1)),
+            ((0, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3), (1, 2)),
+            ((1, 3, 3), (0, 1)),
+            ((1, 3, 3), (0, 2, 0, 1)),
+            ((1, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((0, 3, 3, 3), (1, 2)),
+            ((0, 3, 3, 3), (0, 1)),
+            ((0, 3, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((3, 3, 5, 5), (1, 2)),
+            ((3, 3, 5, 5), (0, 1)),
+            ((3, 3, 5, 5), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3, 3, 3), (1, 2)),
+            ((1, 3, 3, 3, 3), (0, 1)),
+            ((1, 3, 3, 3, 3), (0, 2, 0, 1)),
+            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 4, 4), (-1, 1, -2, 1)),
+        )
+    else:  # mode == 'circular'
+        if dtype == torch.bool:
+            # test_dtypes fails on ASAN for this case with the error:
+            # runtime error: load of value 190, which is not a valid value for type 'bool'
+            # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
+            # Reference Issue: https://github.com/pytorch/pytorch/issues/63034
+            cases = (
+                ((2, 3, 3), (1, 2)),
+                ((1, 3, 3), (1, 2)),
+            )
+        else:
+            cases = (
+                ((0, 3, 3), (1, 2)),
+                ((0, 3, 3), (0, 1)),
+                ((1, 3, 3), (1, 2)),
+                ((1, 3, 3), (0, 1)),
+                ((0, 3, 3, 3), (0, 2, 0, 1)),
+                ((3, 3, 5, 5), (0, 2, 0, 1)),
+                ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+                ((1, 3, 4, 4), (-1, 1, -2, 1)),
+            )
+
+    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def generator():
+        if mode == 'constant':
+            # Default args
+            yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
+
+        if mode in ['reflect', 'replicate', 'circular']:
+            for shape, pad in cases:
+                yield SampleInput(make_inp(shape), args=(pad, mode))
+        else:  # mode == 'constant'
+            for pad_value in (1., 2.):
+                for shape, pad in cases:
+                    yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
+
+    return list(generator())
+
+
 # TODO: reconcile with torch.linalg.det and torch.linalg.slogdet
 # Creates matrices with a positive nonzero determinant
 def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs):
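For readers new to the OpInfo machinery: a sample_inputs_func only produces SampleInput objects; the generic tests in test_ops.py unpack and run them. Below is a stripped-down sketch of that consumption, with a hypothetical helper name (check_pad_samples is not part of the commit):

    import torch
    import torch.nn.functional as F

    def check_pad_samples(device='cpu', dtype=torch.float64):
        # sample_inputs_nn_pad is the function added above; op_info is unused here
        samples = sample_inputs_nn_pad(None, device, dtype,
                                       requires_grad=True, mode='replicate')
        for sample in samples:
            # A SampleInput bundles the first tensor argument with the remaining args/kwargs
            out = F.pad(sample.input, *sample.args, **sample.kwargs)
            out.sum().backward()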
@@ -6801,6 +6885,52 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
            ),
            supports_out=False,),
+    OpInfo('nn.functional.pad',
+           variant_test_name='constant',
+           aten_name='constant_pad_nd',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'),
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='reflect',
+           dtypes=floating_and_complex_types(),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
+           skips=(
+               # op name not found in JIT graph
+               # There are multiple aten ops, namely reflection_pad_{1,2,3}d
+               # so we can't use aten_name argument in opinfo
+               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='replicate',
+           dtypes=floating_and_complex_types(),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
+           skips=(
+               # op name not found in JIT graph
+               # There are multiple aten ops, namely replication_pad_{1,2,3}d
+               # so we can't use aten_name argument in opinfo
+               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='circular',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'),
+           supports_forward_ad=True,
+           check_batched_grad=False,
+           skips=(
+               # Doesn't have a corresponding aten operator.
+               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           supports_out=False),
     OpInfo('nn.functional.hardswish',
            aten_name="hardswish",
            supports_autograd=True,
