Skip to content

add dtype checks for scatter/gather family of functions. #38646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

nikitaved
Copy link
Collaborator

Adds additional dtype checks for scatter/gather family of functions, namely:

  1. Checks whether index is of type Long
  2. Checks whether src.dtype == self.dtype.

Fixes #38554

@nikitaved nikitaved requested a review from ngimel May 18, 2020 13:49
@nikitaved nikitaved force-pushed the nikved/scatter_gather_dtype_check branch 2 times, most recently from 598896d to 81a588a Compare May 18, 2020 14:34
@dr-ci
Copy link

dr-ci bot commented May 18, 2020

💊 CI failures summary and remediations

As of commit 86b4079 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 11 times.

Copy link
Collaborator

@ngimel ngimel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, can you also please add a test that will catch a RunTime Error with mismatched input dtypes?

@nikitaved nikitaved force-pushed the nikved/scatter_gather_dtype_check branch from 77fc948 to 4572a40 Compare May 19, 2020 10:51
@ailzhang ailzhang added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 21, 2020
@gchanan
Copy link
Contributor

gchanan commented Jun 3, 2020

@nikitaved can you resolve the merge conflicts?

# we ignore the case when src is Scalar, as it gets
# cast via src.to<scalar_t>.
if not is_scalar:
with self.assertRaises(RuntimeError):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you use assertRaisesRegexp, to make sure you are catching an error you are expecting to catch and not something else?
Let me know when you resolve the conflicts and it is ready for merging.

@nikitaved nikitaved force-pushed the nikved/scatter_gather_dtype_check branch from 4572a40 to 86b4079 Compare June 4, 2020 14:10
@nikitaved
Copy link
Collaborator Author

@ngimel , the conflicts resolved, the tests updated. I had to change the order because the test_bound test has a side effect.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@gchanan merged this pull request in e4f9c74.

gchanan added a commit to gchanan/pytorch that referenced this pull request Jun 9, 2020
pytorch#38646 added checks for this, but only added tets for the scatter functions.
facebook-github-bot pushed a commit that referenced this pull request Jun 9, 2020
Summary:
#38646 added checks for this, but only added tets for the scatter functions.
Pull Request resolved: #39689

Reviewed By: malfet

Differential Revision: D21945524

Pulled By: gchanan

fbshipit-source-id: 8b06856c06d6427b8cd929a1275422a5ed6e11cc
malfet added a commit to malfet/pytorch that referenced this pull request Jun 10, 2020
Adds additional dtype checks for scatter/gather family of functions, namely:
1. Checks whether `index` is of type `Long`
2. Checks whether `src.dtype == self.dtype`.

This is a rather involved rework of pytorch#38646
malfet added a commit that referenced this pull request Jun 10, 2020
)

* add dtype checks for scatter/gather family of functions [1.5.1]

Adds additional dtype checks for scatter/gather family of functions, namely:
1. Checks whether `index` is of type `Long`
2. Checks whether `src.dtype == self.dtype`.

This is a rather involved rework of #38646

* Adjust test to match both TH and ATen exception patterns
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Merged open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Scatter is not checking that input and value type match
7 participants