-
Notifications
You must be signed in to change notification settings - Fork 25.1k
[fix] ReduceOps throw error if dim is repeated #44281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Should this be made into some reusable helper? So it's more discoverable if a similar check is decided for other ops |
💊 CI failures summary and remediationsAs of commit 5f4c242 (more details on the Dr. CI page):
ci.pytorch.org: 1 failedThis comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 27 times. |
Makes sense. |
aten/src/ATen/native/ReduceOps.cpp
Outdated
@@ -358,6 +358,9 @@ Tensor& prod_out(Tensor& result, const Tensor& self, Dimname dim, | |||
|
|||
Tensor &mean_out_cpu_gpu(Tensor &result, const Tensor &self, IntArrayRef dim, | |||
bool keepdim, c10::optional<ScalarType> opt_dtype) { | |||
auto dim_set = std::set<IntArrayRef::value_type>(dim.cbegin(), dim.cend()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is probably not actually the most efficient way to do the duplicate check, since it involves doing a fairly unnecessary dynamic allocation for the set. Probably quickest when number of dims is small (which it should be usually) is just the quadratic nested loops version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense.
What I had in mind was something like
if (dim.size() < 10) { // 10 is just some heurestic value
// for loop version
} else {
// set version.
}
Let me know if it sounds good and if you approve it what should the heuristic value be?
Thank You!
test/test_torch.py
Outdated
def test_mean_repeated_dim(self, device): | ||
x = torch.randn(3, 3, 3, 3, device=device) | ||
with self.assertRaisesRegex(RuntimeError, r'mean: repeated dimension in `dim` \(\[0, 0\]\)'): | ||
torch.mean(x, dim=(0, 0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mruberry So, as was mentioned in the original issue, there are a bunch of operators which take in a list of dimensions. It seems like it would be useful to easily run a version of this test for all of the operators that do this :>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. I've made a note.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had quickly searched the file below
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
for int[1]
based on the signature of
- func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Codecov Report
@@ Coverage Diff @@
## master #44281 +/- ##
=======================================
Coverage 69.24% 69.24%
=======================================
Files 381 381
Lines 47573 47573
=======================================
+ Hits 32942 32944 +2
+ Misses 14631 14629 -2
Continue to review full report at Codecov.
|
* test more ops. * test repeated negative dim.
torch.mean
throw error if dim is repeatedThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Ooh, this new version is much better, thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
cc @ailzhang on the xla bit |
Gentle Ping :) |
Fixes #44273
TODO