[opinfo] nn.functional.pad #62814
Conversation
💊 CI failures summary and remediations: as of commit 987a094 (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages: linux-bionic-py3.8-gcc9-coverage / test (default, 2, 2, linux.2xlarge) (1/2) — Step: "Store PyTorch Test Reports"; ASAN Failure Debug Mode.
CI failures look unrelated. (will rebase post review)
# test_dtypes fails on ASAN for the case below
# runtime error: load of value 190, which is not a valid value for type 'bool'
# Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
Is this a bug in PyTorch?
I think so. I was planning to open an issue to investigate later. Thanks for the reminder.
for pad_value in (1., 2.):
    yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))

samples = list(generator())
How many samples does this actually generate?
Some comments and questions
def generator():
    if mode == 'constant':
        # Default args
        yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))

    for shape, pad in chain(product(shapes, pads), negative_pad_case):
        # Not all combinations of shapes and pads are valid.
        # Below are the checks to skip invalid combinations.

        # Function requires len(pad) / 2 <= len(shape)
        if not (len(pad) // 2 <= len(shape)):
            continue

        if mode in ['reflect', 'replicate', 'circular']:
            input_dim = len(shape)
            # Valid pad length for given input dim
            if len(pad) == 2 and not (input_dim in (2, 3)):
                continue
            if len(pad) == 4 and not (input_dim in (3, 4)):
                continue
            if len(pad) == 6 and not (input_dim in (4, 5)):
                continue

            # Expected XD or YD (batch mode) tensor with possibly 0 batch size
            # and other non-zero dimensions for input
            if len(pad) == 2 and input_dim == 2 and shape[0] == 0:
                continue
            if len(pad) == 4 and input_dim == 3 and shape[0] == 0:
                continue
            if len(pad) == 6 and input_dim == 4 and shape[0] == 0:
                continue

            if mode == 'circular':
                if not (len(pad) == 2 * (input_dim - 2)):
                    continue

        if mode in ['reflect', 'replicate', 'circular']:
            yield SampleInput(make_inp(shape), args=(pad, mode))
        else:
            for pad_value in (1., 2.):
                yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
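As a side note, the validity checks above can be pulled out into a standalone predicate, which makes the rules easier to reason about (and test) in isolation. This is a hypothetical pure-Python sketch; the name is_valid_pad and the free-standing form are illustrative, but the rules mirror the generator above:

```python
def is_valid_pad(shape, pad, mode):
    """Return True if (shape, pad) is a valid combination for the given mode.

    Mirrors the skip conditions in the sample-input generator above.
    """
    input_dim = len(shape)
    # F.pad requires len(pad) / 2 <= number of input dimensions.
    if len(pad) // 2 > input_dim:
        return False
    if mode in ('reflect', 'replicate', 'circular'):
        # Valid pad length for a given input dim.
        if len(pad) == 2 and input_dim not in (2, 3):
            return False
        if len(pad) == 4 and input_dim not in (3, 4):
            return False
        if len(pad) == 6 and input_dim not in (4, 5):
            return False
        # In non-batch form, the leading dimension must be non-zero.
        if len(pad) == 2 and input_dim == 2 and shape[0] == 0:
            return False
        if len(pad) == 4 and input_dim == 3 and shape[0] == 0:
            return False
        if len(pad) == 6 and input_dim == 4 and shape[0] == 0:
            return False
        # Circular mode pads every non-batch, non-channel dimension.
        if mode == 'circular' and len(pad) != 2 * (input_dim - 2):
            return False
    return True

print(is_valid_pad((1, 3, 3), (1, 2), 'reflect'))    # True
print(is_valid_pad((1, 3), (1, 1, 1, 1), 'reflect'))  # False: pad len 4 needs a 3D/4D input
```

A predicate like this could also be reused by a test that asserts the generated case tables contain only valid combinations.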
The way this is set up makes it difficult to reason about exactly what the samples look like and how many there are. If there aren't too many samples, would it be better to just list them all? E.g. have a table of (shape, pad):
cases = [
    ((1, 3), (1, 2)),  # (shape, pad)
    ...
]
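One advantage of an explicit table is that the sample count becomes trivially inspectable. A minimal sketch of the idea (the helper name generate_samples and the entries here are illustrative, not the PR's actual list):

```python
# Illustrative (shape, pad) table; entries are examples only.
cases = [
    ((1, 3), (1, 2)),
    ((1, 3), (0, 1)),
    ((0, 3, 3), (1, 2)),
]

def generate_samples(cases, mode='reflect'):
    # Hypothetical stand-in for SampleInput construction:
    # each table row maps to exactly one (shape, args) sample.
    return [(shape, (pad, mode)) for shape, pad in cases]

samples = generate_samples(cases)
print(len(samples))  # one sample per table row
```

With the generator/filter approach, answering "how many samples?" requires mentally executing the filters; with a table, it is just len(cases).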
I agree, it is hard to reason about. I will break it into explicit cases. Thanks!
Below are the samples currently generated (I will turn them into cases):
constant
torch.Size([1, 3, 3]) ((2, 2),)
torch.Size([1, 3]) ((1, 2), 'constant', 1.0)
torch.Size([1, 3]) ((1, 2), 'constant', 2.0)
torch.Size([1, 3]) ((0, 1), 'constant', 1.0)
torch.Size([1, 3]) ((0, 1), 'constant', 2.0)
torch.Size([1, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([1, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([0, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([0, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([0, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([0, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([0, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([0, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([0, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([0, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([1, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([1, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([1, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([1, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([1, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([1, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([0, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([0, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((1, 2), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((1, 2), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((0, 1), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((0, 1), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 2), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 2), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 1), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 1), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((0, 2, 0, 1), 'constant', 2.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 1.0)
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'constant', 2.0)
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'constant', 1.0)
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'constant', 2.0)
reflect
torch.Size([1, 3]) ((1, 2), 'reflect')
torch.Size([1, 3]) ((0, 1), 'reflect')
torch.Size([0, 3, 3]) ((1, 2), 'reflect')
torch.Size([0, 3, 3]) ((0, 1), 'reflect')
torch.Size([1, 3, 3]) ((1, 2), 'reflect')
torch.Size([1, 3, 3]) ((0, 1), 'reflect')
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'reflect')
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'reflect')
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'reflect')
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'reflect')
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'reflect')
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'reflect')
replicate
torch.Size([1, 3]) ((1, 2), 'replicate')
torch.Size([1, 3]) ((0, 1), 'replicate')
torch.Size([0, 3, 3]) ((1, 2), 'replicate')
torch.Size([0, 3, 3]) ((0, 1), 'replicate')
torch.Size([1, 3, 3]) ((1, 2), 'replicate')
torch.Size([1, 3, 3]) ((0, 1), 'replicate')
torch.Size([1, 3, 3]) ((0, 2, 0, 1), 'replicate')
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'replicate')
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'replicate')
torch.Size([3, 3, 5, 5]) ((1, 1, 1, 1, 1, 1), 'replicate')
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'replicate')
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'replicate')
circular
torch.Size([0, 3, 3]) ((1, 2), 'circular')
torch.Size([0, 3, 3]) ((0, 1), 'circular')
torch.Size([1, 3, 3]) ((1, 2), 'circular')
torch.Size([1, 3, 3]) ((0, 1), 'circular')
torch.Size([0, 3, 3, 3]) ((0, 2, 0, 1), 'circular')
torch.Size([3, 3, 5, 5]) ((0, 2, 0, 1), 'circular')
torch.Size([1, 3, 3, 3, 3]) ((1, 1, 1, 1, 1, 1), 'circular')
torch.Size([1, 3, 4, 4]) ((-1, 1, -2, 1), 'circular')
circular bool
torch.Size([2, 3, 3]) ((1, 2), 'circular')
torch.Size([1, 3, 3]) ((1, 2), 'circular')
if mode == 'constant':
    # Default args
    yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
So for constant pad we only have a single test case? Maybe constant pad should be a separate sample_inputs_nn_pad_constant function.
Actually it's more than that.
Ref: #62814 (comment)
@zou3519 I have mostly addressed the review. This PR is ready for another round. Thanks!
((1, 3), (1, 2)),
((1, 3), (0, 1)),
((0, 3, 3), (1, 2)),
((0, 3, 3), (0, 1)),
((1, 3, 3), (1, 2)),
((1, 3, 3), (0, 1)),
((1, 3, 3), (0, 2, 0, 1)),
((0, 3, 3, 3), (0, 2, 0, 1)),
((3, 3, 5, 5), (0, 2, 0, 1)),
((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
((1, 3, 4, 4), (-1, 1, -2, 1)),
Do we really need this many cases? For example, we probably only need a single case for batch-size 0 and we probably don't need so many cases to test that pad 0 works
((0, 3, 3, 3), (0, 2, 0, 1)),
((3, 3, 5, 5), (0, 2, 0, 1)),
((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
((1, 3), (1, 2)),
((1, 3), (0, 1)),
((1, 3), (0, 2, 0, 1)),
((0, 3, 3), (1, 2)),
((0, 3, 3), (0, 1)),
((0, 3, 3), (0, 2, 0, 1)),
((0, 3, 3), (1, 1, 1, 1, 1, 1)),
((1, 3, 3), (1, 2)),
((1, 3, 3), (0, 1)),
((1, 3, 3), (0, 2, 0, 1)),
((1, 3, 3), (1, 1, 1, 1, 1, 1)),
((0, 3, 3, 3), (1, 2)),
((0, 3, 3, 3), (0, 1)),
((0, 3, 3, 3), (0, 2, 0, 1)),
((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
((3, 3, 5, 5), (1, 2)),
((3, 3, 5, 5), (0, 1)),
((3, 3, 5, 5), (0, 2, 0, 1)),
((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
((1, 3, 3, 3, 3), (1, 2)),
((1, 3, 3, 3, 3), (0, 1)),
((1, 3, 3, 3, 3), (0, 2, 0, 1)),
((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
((1, 3, 4, 4), (-1, 1, -2, 1)),
Ditto here: do we need this many cases?
    for shape, pad in cases:
        yield SampleInput(make_inp(shape), args=(pad, mode))
else:  # mode == 'constant'
    for pad_value in (1., 2.):
It might be good to just put this value into the case table above and test 1 or 2 as the value for each case. Otherwise we duplicate all of the above tests, and there are indeed a lot of tests.
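That suggestion could look something like the following sketch (hypothetical; the entries and the helper name constant_sample_args are illustrative). Each row carries its own pad value, so the table is not duplicated once per value:

```python
# Hypothetical: fold the constant-mode pad value into the case table so each
# (shape, pad) pair is tested with a single value instead of both 1. and 2.
constant_cases = [
    # (shape, pad, pad_value) -- illustrative entries, not the PR's full list
    ((1, 3), (1, 2), 1.),
    ((0, 3, 3), (0, 1), 2.),
    ((1, 3, 3), (0, 2, 0, 1), 1.),
]

def constant_sample_args(cases):
    # One sample per row, rather than len(cases) * len(pad_values).
    return [(shape, (pad, 'constant', value)) for shape, pad, value in cases]

args = constant_sample_args(constant_cases)
print(len(args))  # one per row
```

This halves the constant-mode sample count while still exercising both 1.0 and 2.0 somewhere in the table.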
I'm worried that we're testing too much here, but maybe the complexity of this operation requires that we test this much -- what do you think?
Yeah, I agree it is a lot of tests, but I thought the coverage would be helpful. If this feels like too many cases, I'll be happy to trim them.
Gentle ping :) @zou3519
Let's keep them for now, we can always trim them down later.
thank you!
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Reference: pytorch/functorch#78 Pull Request resolved: #62814 Reviewed By: VitalyFedyunin Differential Revision: D30307492 Pulled By: zou3519 fbshipit-source-id: 4f6062eb4a3c91ed1795df1f82846afa0abafcdc
Reference: pytorch/functorch#78