[opinfo] nn.functional.pad #62814
Changes from 7 commits
@@ -2642,6 +2642,73 @@ def generator():
    return list(generator())


def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
    assert mode in ('constant', 'reflect', 'replicate', 'circular')
    if dtype == torch.bool and mode == 'circular':
        # test_dtypes fails on ASAN for the case below
        # runtime error: load of value 190, which is not a valid value for type 'bool'
        # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
        shapes = ((1, 3), (2, 3, 3), (1, 3, 3), (3, 3, 3, 3), (3, 3, 5, 5))
        pads = ((1, 2),)
        negative_pad_case = ()
    else:
        # Supports 2-D, 3-D, 4-D, 5-D tensors
        shapes = ((1, 3), (0, 3, 3), (1, 3, 3), (0, 3, 3, 3), (3, 3, 5, 5), (1, 3, 3, 3, 3))  # type: ignore[assignment]
        pads = ((1, 2), (0, 1), (0, 2, 0, 1), (1, 1, 1, 1, 1, 1))  # type: ignore[assignment]
        negative_pad_case = (
            # (shape, pad)
            ((1, 3, 4, 4), (-1, 1, -2, 1)),  # type: ignore[assignment]
        )

    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def generator():
        if mode == 'constant':
            # Default args
            yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
Review comment on lines +2757 to +2759:
> So for constant pad we only have a single test case? Maybe constant pad should be a separate variant?

Reply:
> Actually it's more than that. Ref: #62814 (comment)
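For reference, a minimal standalone sketch of what that single default-args sample exercises; the tensor here is illustrative and not from the PR:

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 3, 3)
# mode defaults to 'constant' and value to 0; pad=(2, 2) pads only the last dim
y = F.pad(x, (2, 2))
print(y.shape)   # torch.Size([1, 3, 7])
print(y[0, 0])   # tensor([0., 0., 1., 1., 1., 0., 0.])
```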
        for shape, pad in chain(product(shapes, pads), negative_pad_case):
            # Not all combinations of shapes and pads are valid;
            # the checks below skip invalid combinations

            # Function requires len(pad) / 2 <= len(shape)
            if not (len(pad) // 2 <= len(shape)):
                continue

            if mode in ['reflect', 'replicate', 'circular']:
                input_dim = len(shape)
                # Valid pad length for given input dim
                if len(pad) == 2 and input_dim not in (2, 3):
                    continue
                if len(pad) == 4 and input_dim not in (3, 4):
                    continue
                if len(pad) == 6 and input_dim not in (4, 5):
                    continue

                # Expected XD or YD (batch mode) tensor with possibly 0 batch size
                # and other non-zero dimensions for input
                if len(pad) == 2 and input_dim == 2 and shape[0] == 0:
                    continue
                if len(pad) == 4 and input_dim == 3 and shape[0] == 0:
                    continue
                if len(pad) == 6 and input_dim == 4 and shape[0] == 0:
                    continue

                if mode == 'circular':
                    if not (len(pad) == 2 * (input_dim - 2)):
                        continue

            if mode in ['reflect', 'replicate', 'circular']:
                yield SampleInput(make_inp(shape), args=(pad, mode))
            else:
                for pad_value in (1., 2.):
                    yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
Review comment:
> The way this is set up makes it difficult to reason about exactly what the samples look like and how many there are. If there aren't too many samples, would it be better to just list them all? E.g. have a table of (shape, pad).

Reply:
> I agree. It is hard to reason about. Will break it into cases. Thanks! Below are the samples generated (will turn them into cases):
> (per-mode sample listings for constant, reflect, replicate, circular, and circular-bool omitted)
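As a rough sketch (not from the PR), one way to print those per-mode listings, assuming it runs inside common_methods_invocations.py where this sampler, make_tensor, and SampleInput are in scope:

```python
import torch

# Enumerate what sample_inputs_nn_pad yields for each mode on CPU float32.
# op_info is unused by the sampler, so None is passed for illustration.
for mode in ('constant', 'reflect', 'replicate', 'circular'):
    samples = sample_inputs_nn_pad(None, 'cpu', torch.float32, False, mode)
    print(f"{mode}: {len(samples)} samples")
    for s in samples:
        print('  shape:', tuple(s.input.shape), 'args:', s.args)
```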
    samples = list(generator())

Review comment:
> How many samples does this actually generate?

    assert len(samples) > 0
    return samples
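For intuition about the filters above, a small hedged example of how F.pad consumes these (shape, pad) pairs; the shapes and pads are drawn from the tuples in the sampler, and the printed shapes assume current F.pad semantics:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 4, 4)
# A 4-element pad on a 4-D input pads the last two dims ('reflect' allows dims 3 and 4)
print(F.pad(x, (0, 2, 0, 1), mode='reflect').shape)   # torch.Size([1, 3, 5, 6])
# Negative entries crop instead of pad, cf. negative_pad_case above (default 'constant' mode)
print(F.pad(x, (-1, 1, -2, 1)).shape)                 # torch.Size([1, 3, 3, 4])
```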
# TODO: reconcile with torch.linalg.det and torch.linalg.slogdet
# Creates matrices with a positive nonzero determinant
def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs):

@@ -6573,6 +6640,52 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
           dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_nn_activation_relu,
           supports_out=False),
    OpInfo('nn.functional.pad',
           variant_test_name='constant',
           aten_name='constant_pad_nd',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
           sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'),
           supports_out=False),
    OpInfo('nn.functional.pad',
           variant_test_name='reflect',
           dtypes=floating_and_complex_types(),
           dtypesIfCUDA=floating_and_complex_types_and(torch.half),
           sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
           skips=(
               # op name not found in JIT graph
               # There are multiple aten ops, namely reflection_pad_{1,2,3}d,
               # so we can't use the aten_name argument in OpInfo
               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           supports_out=False),
    OpInfo('nn.functional.pad',
           variant_test_name='replicate',
           dtypes=floating_and_complex_types(),
           dtypesIfCUDA=floating_and_complex_types_and(torch.half),
           sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
           skips=(
               # op name not found in JIT graph
               # There are multiple aten ops, namely replication_pad_{1,2,3}d,
               # so we can't use the aten_name argument in OpInfo
               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           supports_out=False),
    OpInfo('nn.functional.pad',
           variant_test_name='circular',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
           sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'),
           supports_forward_ad=True,
           check_batched_grad=False,
           skips=(
               # Doesn't have a corresponding aten operator.
               # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
               SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
           ),
           supports_out=False),
    OpInfo('nn.functional.hardswish',
           aten_name="hardswish",
           supports_autograd=True,
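To make the skip comments above concrete, a small sketch (toy input of my choosing; the dispatch facts are restated from the skip comments and the constant variant's aten_name): only 'constant' lowers to a single op, while 'reflect' and 'replicate' pick a kernel by input rank and 'circular' has no dedicated aten operator.

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 2, 4, 4)
# 'constant' always lowers to one op (aten::constant_pad_nd), so aten_name works.
# 'reflect' / 'replicate' select reflection_pad{1,2,3}d / replication_pad{1,2,3}d
# from the input rank, so no single aten_name covers those variants.
for mode in ('constant', 'reflect', 'replicate', 'circular'):
    print(mode, F.pad(x, (1, 1, 1, 1), mode=mode).shape)  # each: torch.Size([1, 2, 6, 6])
```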
Review comment:
> Is this a bug in PyTorch?

Reply:
> I think so. Was planning to open an issue to investigate later. Thanks for reminding.