Add a private API that lists overridable functions #33182

Closed
ngoldbaum opened this issue Feb 11, 2020 · 3 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: tests Issues related to tests (not the torch.testing module) triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ngoldbaum
Contributor

ngoldbaum commented Feb 11, 2020

🚀 Feature

Add a way for library authors who depend on pytorch to get a list of all functions that they would need to implement to fully wrap the API surface that is overridable via __torch_function__.

Motivation

See e.g. #27064 (comment).

Pitch

Currently it's not straightforward for authors of types that implement __torch_function__ to test that they fully cover the API that is overridable via __torch_function__. Wrapping as much of that API as possible matters: once you have written a type that implements __torch_function__, users of that type will see a TypeError whenever the __torch_function__ implementation returns NotImplemented for a function it does not wrap.
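To make the failure mode above concrete, here is a minimal pure-Python sketch of the dispatch behaviour. It does not use PyTorch: `overridable`, `toy_add`, `toy_mul` and `Wrapped` are hypothetical names invented for illustration, and the `__torch_function__(self, func, args, kwargs)` signature is a simplification, not necessarily the one PyTorch uses.

```python
import functools


def overridable(impl):
    """Hypothetical decorator: let argument types intercept calls to `impl`."""

    @functools.wraps(impl)
    def public(*args, **kwargs):
        # Collect arguments whose type defines the override hook.
        overriders = [a for a in args if hasattr(type(a), "__torch_function__")]
        for arg in overriders:
            result = arg.__torch_function__(public, args, kwargs)
            if result is not NotImplemented:
                return result
        if overriders:
            # Every override declined, so the caller sees a TypeError.
            raise TypeError(
                f"no override of {public.__name__!r} accepted these arguments"
            )
        return impl(*args, **kwargs)

    return public


@overridable
def toy_add(a, b):
    return a + b


@overridable
def toy_mul(a, b):
    return a * b


class Wrapped:
    """Toy wrapper type that only knows how to handle toy_add."""

    def __init__(self, value):
        self.value = value

    def __torch_function__(self, func, args, kwargs):
        if func is toy_add:
            left, right = (a.value if isinstance(a, Wrapped) else a for a in args)
            return Wrapped(left + right)
        return NotImplemented  # toy_mul is not wrapped
```

With these definitions, `toy_add(Wrapped(1), 2)` returns a `Wrapped`, while `toy_mul(Wrapped(2), 3)` raises `TypeError` because the only override returned `NotImplemented`. A complete list of overridable functions would let a wrapper author find such gaps before users do.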

It would be nice if there were some private API that gave developers access to the functions that are overridable. It might also be nice to include the function signatures in an easily introspectable format as that makes it easier to write tests for the overrides.
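As a rough sketch of what "function signatures in an easily introspectable format" could look like, the snippet below collects the public callables of a namespace together with their `inspect.Signature` objects. The helper `list_overridable` and the toy namespace are hypothetical and stand in for the real library; nothing here is an existing PyTorch API.

```python
import inspect
import types


def list_overridable(namespace, ignored=frozenset()):
    """Map each public callable in `namespace` to its signature, skipping `ignored`."""
    found = {}
    for name, obj in vars(namespace).items():
        if name.startswith("_") or not callable(obj) or obj in ignored:
            continue
        found[obj] = inspect.signature(obj)
    return found


# Hypothetical stand-ins for library functions.
def add(input, other):
    return input + other


def mean(input, dim=None):
    return sum(input) / len(input)


def is_toy(obj):
    return True


toy = types.SimpleNamespace(add=add, mean=mean, is_toy=is_toy, _helper=add)

signatures = list_overridable(toy, ignored={is_toy})
for func, sig in signatures.items():
    print(func.__name__, sig)
```

One caveat for this approach: `inspect.signature` can raise `ValueError` for callables implemented in C that expose no signature metadata, so a real implementation would need another source of signatures for such functions.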

Alternatives

It might also be sufficient to just have a list of functions in the documentation.

@cpuhrsch
Contributor

On the list of functions in the documentation: I think it's sufficient, if it'll be kept up to date. A private API that's also used in testing could be a good forcing function to make sure that list is always complete.

@ngoldbaum ngoldbaum changed the title Add a private API of overridable functions Add a private API that lists overridable functions Feb 11, 2020
@gchanan gchanan added module: tests Issues related to tests (not the torch.testing module) enhancement Not as big of a feature, but technically not a bug. Should be easy to fix needs research We need to decide whether or not this merits inclusion, based on research world triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Feb 11, 2020
@gchanan gchanan removed the needs research We need to decide whether or not this merits inclusion, based on research world label Feb 18, 2020
@gchanan
Contributor

gchanan commented Feb 18, 2020

we think this is reasonable and would accept a PR that implements it.

@ngoldbaum ngoldbaum self-assigned this Feb 18, 2020
@ngoldbaum
Contributor Author

I'm going to do this but I'll wait for #32799 to land before putting in a PR that fixes this.

ttumiel pushed a commit to ttumiel/pytorch that referenced this issue Mar 4, 2020
…rch#33791)

Summary:
Fixes pytorch#33182

This adds private API functions that developers of types that implement `__torch_function__` can use to ensure full coverage of the subset of the PyTorch API that can be overridden.

I've refactored some of the code in the tests into a new `torch._overrides.get_overridable_functions` function. I've also changed `TENSOR_LIKE_TORCH_OVERRIDES` into `torch._overrides.get_testing_overrides` and `IGNORED_TORCH_FUNCTIONS` into `torch._overrides.get_ignored_functions`. Making these two static global variables in the tests into functions should allow rewriting their implementation to construct their return values instead of just statically defining the return value as is done here. Currently that is blocked on not being able to inspect function signatures of compiled kernels in PyTorch (see pytorch#28233). See the docs I've added for usage examples of these new functions. I also refactored the existing override tests to make use of these new functions, which should be a good forcing function to make sure they're kept up-to-date.
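The kind of coverage check these functions are meant to enable might look like the sketch below. It runs without PyTorch: the `overridable` dictionary is a hand-built stand-in, and its shape (a namespace mapped to a list of functions) is an assumption for illustration, not a documented return type of `torch._overrides.get_overridable_functions`. `missing_overrides` and `HANDLED` are hypothetical names.

```python
def missing_overrides(overridable, implemented):
    """Return sorted names of overridable functions absent from `implemented`."""
    return sorted(
        func.__name__
        for funcs in overridable.values()
        for func in funcs
        if func not in implemented
    )


# Hypothetical stand-ins for library functions.
def add(input, other): ...
def mean(input): ...
def mm(input, mat2): ...


# Assumed shape: namespace -> list of overridable functions.
overridable = {"toy": [add, mean, mm]}

# Overrides that a wrapper type has implemented so far.
HANDLED = {
    add: lambda input, other: "wrapped add",
    mm: lambda input, mat2: "wrapped mm",
}

print(missing_overrides(overridable, HANDLED))
```

A test suite for a wrapper type could assert that this list is empty, so that a newly added overridable function shows up as a test failure rather than as a `TypeError` for users.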

Finally, while working on this I discovered that `TestTorchFunctionOverrides.test_mean` and `TestTorchFunctionOverrides.test_mm` weren't ever being run because they were getting clobbered by the other dynamically generated override tests. I fixed that by renaming the tests and then fixing the actual test code. I've verified that all the subclassing semantics are correct and that the updated test answers are correct. I'm happy to put the fixes to the existing tests in as a separate pull request if that would be easier to review.
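For readers unfamiliar with this kind of clobbering, here is a small sketch of how it can happen, assuming the generated tests are attached to the class with `setattr` (the commit message does not say which mechanism the PyTorch tests use). `HandWritten`, `make_generated` and `attach` are hypothetical names.

```python
class HandWritten:
    """Stand-in for a test class with one hand-written test."""

    def test_mean(self):
        return "hand-written"


def make_generated(name):
    """Build a generated test body for the function called `name`."""

    def test(self):
        return f"generated for {name}"

    return test


def attach(cls, names, prefix="test_"):
    """Attach generated tests to `cls`; return attribute names that were replaced."""
    clobbered = []
    for name in names:
        attr = prefix + name
        if attr in vars(cls):
            clobbered.append(attr)  # an existing test is about to be overwritten
        setattr(cls, attr, make_generated(name))
    return clobbered


replaced = attach(HandWritten, ["mean", "mm"])
print(replaced)
```

After `attach` runs, `HandWritten.test_mean` refers to the generated function and the hand-written body is never executed. Renaming the hand-written tests, as described above, avoids the collision. Checking for existing attributes before calling `setattr` is one way to detect it.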

ping cpuhrsch since the feature request originally came from them.
Pull Request resolved: pytorch#33791

Differential Revision: D20195053

Pulled By: cpuhrsch

fbshipit-source-id: 1585f4e405f5223932b410eae03a288dc8eb627e