-
Notifications
You must be signed in to change notification settings - Fork 25k
add dtype checks for scatter/gather family of functions. #38646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add dtype checks for scatter/gather family of functions. #38646
Conversation
598896d
to
81a588a
Compare
💊 CI failures summary and remediationsAs of commit 86b4079 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 11 times. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, can you also please add a test that will catch a RunTime Error with mismatched input dtypes?
77fc948
to
4572a40
Compare
@nikitaved can you resolve the merge conflicts? |
test/test_torch.py
Outdated
# we ignore the case when src is Scalar, as it gets | ||
# cast via src.to<scalar_t>. | ||
if not is_scalar: | ||
with self.assertRaises(RuntimeError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can you use assertRaisesRegexp, to make sure you are catching an error you are expecting to catch and not something else?
Let me know when you resolve the conflicts and it is ready for merging.
4572a40
to
86b4079
Compare
@ngimel , the conflicts resolved, the tests updated. I had to change the order because the test_bound test has a side effect. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
pytorch#38646 added checks for this, but only added tets for the scatter functions.
Adds additional dtype checks for scatter/gather family of functions, namely: 1. Checks whether `index` is of type `Long` 2. Checks whether `src.dtype == self.dtype`. This is a rather involved rework of pytorch#38646
) * add dtype checks for scatter/gather family of functions [1.5.1] Adds additional dtype checks for scatter/gather family of functions, namely: 1. Checks whether `index` is of type `Long` 2. Checks whether `src.dtype == self.dtype`. This is a rather involved rework of #38646 * Adjust test to match both TH and ATen exception patterns
Adds additional dtype checks for scatter/gather family of functions, namely:
index
is of typeLong
src.dtype == self.dtype
.Fixes #38554