Port dilated_max_pool2d() to ATen #20691
Conversation
check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
check_dim_size(gradOutput, ndim, ndim-1, outputWidth);

if (cuda) {
I... would believe you if you told me this is what the code did before, but this special case is kind of shocking XD
It is shocking, but it's indeed what the old code did. :) The CUDA code here always uses 4 dimensions in the kernel for output, then resizes back after the kernel:

if (input->dim() == 3)
  THCTensor_(resize3d)(state, output, nInputPlane, nOutputRows, nOutputCols);

The C code uses 3 or 4 dimensions for the kernel.
Perhaps the shape check should be after all resizes and just before the kernel, but the old code checked right at the top of the function.
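The overall shape of the old CUDA path was roughly this (a sketch; apart from the quoted resize3d call, the names are assumed rather than copied from the diff):

// The kernel always operates on a 4-D output; a 3-D input is resized
// back to 3-D only after the kernel has run.
THCTensor_(resize4d)(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
// ... launch the max-pooling kernel on the 4-D output ...
if (input->dim() == 3)
  THCTensor_(resize3d)(state, output, nInputPlane, nOutputRows, nOutputCols);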
Yeah, I think that's the correct fix. No need to block this PR on fixing it, would be a nice readability improvement though.
That is much appreciated, thank you!
"padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH); | ||
|
||
if (outputWidth < 1 || outputHeight < 1) { | ||
AT_ERROR("Given input size: (", |
nit: TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, msg)
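Spelled out, the suggestion would look roughly like this (a sketch; the full error message is truncated in the excerpt above, so the message text here is illustrative only):

TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
            // TORCH_CHECK raises on a false condition, replacing the
            // explicit if + AT_ERROR pair.
            "Given input size: (", nInputPlane, "x", inputHeight, "x", inputWidth,
            "). Calculated output size is too small");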
IntArrayRef dilation,
bool ceil_mode)
{
  // XXX JIT: Pooling.cpp allows stride.empty().
This sounds like another misannotated function.
The annotation Optional[BroadcastingList2[int]] looks correct, but there's some explicit handling of stride here:

pytorch/torch/nn/functional.py, line 482 in 70ecddf:

if stride is None:

I wonder if that could just be:

if stride is None:
    stride = kernel_size

The drawback is that one could crash the cpp code by modifying Python code, but that is also the case for the annotation.

(Edit) "crash" here means triggering an exception, unless it is changed to a real assert() in the future.
The other special case is here:

pytorch/test/cpp/api/integration.cpp, line 303 in 70ecddf:

x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();

It ends up calling max_pool2d_with_indices with all default parameters. I didn't look very hard, but I couldn't find the overload of torch::max_pool2d that works with just two parameters in the cpp files.
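My guess is that the two-argument call compiles not through a separate overload but through defaulted parameters on the generated function. A sketch of the declaration I would expect (signature assumed, not verified against the generated headers):

Tensor max_pool2d(const Tensor& self, IntArrayRef kernel_size,
                  IntArrayRef stride = {},   // an empty stride reaches ATen as stride.empty()
                  IntArrayRef padding = 0,
                  IntArrayRef dilation = 1,
                  bool ceil_mode = false);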
The drawback is that one could crash cpp code by modifying Python code, but that is also the case for the annotation.
I think, for now, we should consider these annotations as part of the TCB. I do wonder a little if they can't be checked for consistency with the main annotations, but... well... someone would have to figure that out :>
I wonder if that could just be:
Then we'd have to explain why it was permissible for all the other sites to also default things to empty list. Maybe they're all wrong.
@ailzhang, you helped us resolve the annotation last time in #20306. Do you know what's going on with these Optional[BroadcastingList2] things?
Hi @ailzhang, I think the stride.empty() case is also triggered by the regular Python tests, if that is more convenient. IntegrationTest.MNIST triggers all of stride.empty() && padding.size() == 1 && dilation.size() == 1.
And (for the sake of completeness) this PR already contains a workaround, so an assert() would need to be added to reproduce it.
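For reference, the workaround treats an empty stride as "default to the kernel size", roughly like this (a sketch; the variable names are the conventional ones from pooling code, not copied from the diff):

// Fall back to the kernel size when stride was not specified.
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
// (The real code would also have to handle the one-element broadcasting case.)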
@ezyang I have a commit to fix this (and other similar ones). Once this PR is merged into master, I will rebase and send it out. :D
template <typename dest_t, typename src_t>
static inline dest_t
safe_downcast(src_t v)
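The excerpt cuts off before the body; something along these lines would match the name (a sketch, assuming the helper range-checks before narrowing):

#include <limits>

template <typename dest_t, typename src_t>
static inline dest_t safe_downcast(src_t v) {
  // Refuse values that do not fit into the destination type instead of
  // silently truncating them.
  TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v &&
                  v <= std::numeric_limits<dest_t>::max(),
              "integer out of range");
  return static_cast<dest_t>(v);
}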
Nice!
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
You forgot to delete
Delete file
Done, I'll leave the "allow" button on next time, it's just a habit of mine from other projects. :)
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please.
Summary: Pull Request resolved: pytorch/pytorch#20691
Differential Revision: D15435960
Pulled By: ezyang
fbshipit-source-id: 548b7cc42e52ad2c641ec7d9cf78028d9411d02e