Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API for listing functions overridable by __torch_function__ #33791

Closed
wants to merge 4 commits into from

Conversation

ngoldbaum
Copy link
Contributor

Fixes #33182

This adds private API functions that developers of types that implement __torch_function__ can use to ensure full coverage of the subset of the PyTorch API that can be overrided.

I've refactored some of the code in the tests into a new torch._overrides.get_overridable_functions function. I've also changed TENSOR_LIKE_TORCH_OVERRIDES into torch._overrides.get_testing_overrides and IGNORED_TORCH_FUNCTIONS into torch._overrides.get_ignored_functions. Making these two static global variables in the tests into functions should allow rewriting their implementation to construct their return values instead of just statically defining the return value as is done here. Currently that is blocked on not being able to inspect function signatures of compiled kernels in PyTorch (see #28233). See the docs I've added for usage examples of these new functions. I also refactored the existing override tests to make use of these new functions, which should be a good forcing function to make sure they're kept up-to-date.

Finally, while working on this I discovered that TestTorchFunctionOverrides.test_mean and TestTorchFunctionOverrides.test_mm weren't ever being run because they were getting clobbered by the other dynamically generated override tests. I fixed that by renaming the tests and then fixing the actual test code. I've verified that all the subclassing semantics is correct and that the updated test answers are correct. I'm happy to put the fixes to the existing tests in as a separate pull request if that would be easier to review.

ping @cpuhrsch since the feature request originally came from them.

@mruberry mruberry added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 26, 2020
@ezyang ezyang requested a review from cpuhrsch February 27, 2020 04:38
@ezyang
Copy link
Contributor

ezyang commented Feb 27, 2020

@cpuhrsch I'm deferring this code review to you. @ngoldbaum let me know if there's anything specific you want me to look at.

test/test_overrides.py Outdated Show resolved Hide resolved
functions, however in practice this is not enough to write tests for all of
these functions without laboriously and manually copying the signature of each
function for each test. To ease this process, the
``torch._overrides.get_testing_overrides`` function returns a dictionary mapping
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see in the test below you exercise the torch.mean function, but there are two overloads for torch.mean

torch.mean(input) → Tensor

and

torch.mean(input, dim, keepdim=False, out=None) → Tensor

How is the second one tests?

Further we have our own argparser, which means the Python signature of a given function might not accurately capture all capabilities.

How is this dealt with?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mean is a C++ kernel, so its handled via the C++ argument parsing, e.g.

if (THPVariable_Check(obj)) {
if (check_has_torch_function(obj)) {
append_overloaded_arg(overloaded_args, obj);
}
return true;
.

I don't think there's a direct test for both signatures in the override tests. In principle we could add it but I'm not sure if it would be worth the extra boilerplate.

Copy link
Contributor

@cpuhrsch cpuhrsch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great! I added a few nits and a general question, but afterwards I think this can be approved and landed even if just treated as a code refactor / abstraction.

@ngoldbaum
Copy link
Contributor Author

Fixed the nits, thank you for the copyediting :)

@dr-ci
Copy link

dr-ci bot commented Feb 27, 2020

💊 CircleCI build failures summary and remediations

As of commit c5f7ea3:

None of the build failures appear to be your fault.

  • 2/2 broken upstream at merge base 095de1e since Feb 27

    Please rebase on the viable/strict branch (expand for instructions)

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    git fetch origin viable/strict
    git rebase --onto viable/strict $(git merge-base origin/master HEAD)
    

    If your commit is older than viable/strict:

    git fetch origin viable/strict
    git rebase viable/strict
    

    Check out the recency history of this "viable master" tracking branch.

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🚧 2 upstream failures recognized by patterns:

These builds matched patterns, but were probably caused by upstream breakages:


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

This comment has been revised 1 time.

Copy link
Contributor

@cpuhrsch cpuhrsch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! I'll adopt this within pytorch/nestedtensor as soon as it lands for testing.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cpuhrsch has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@cpuhrsch merged this pull request in ad2825a.

@mrshenli
Copy link
Contributor

mrshenli commented Mar 4, 2020

Looks like pytorch_linux_xenial_py3_5_test starts to become flaky after this test.

Mar 03 23:09:46 ======================================================================
Mar 03 23:09:46 ERROR: test_sigmoid (__main__.TestTorchFunctionOverride)
Mar 03 23:09:46 ----------------------------------------------------------------------
Mar 03 23:09:46 Traceback (most recent call last):
Mar 03 23:09:46   File "test_overrides.py", line 478, in test
Mar 03 23:09:46     self.assertEqual(func(*func_args), -1)
Mar 03 23:09:46   File "/opt/python/3.5/lib/python3.5/site-packages/torch/nn/functional.py", line 1569, in sigmoid
Mar 03 23:09:46     return input.sigmoid()
Mar 03 23:09:46 AttributeError: 'TensorLike' object has no attribute 'sigmoid'
Mar 03 23:09:46 
Mar 03 23:09:46 ======================================================================
Mar 03 23:09:46 ERROR: test_tanh (__main__.TestTorchFunctionOverride)
Mar 03 23:09:46 ----------------------------------------------------------------------
Mar 03 23:09:46 Traceback (most recent call last):
Mar 03 23:09:46   File "test_overrides.py", line 478, in test
Mar 03 23:09:46     self.assertEqual(func(*func_args), -1)
Mar 03 23:09:46   File "/opt/python/3.5/lib/python3.5/site-packages/torch/nn/functional.py", line 1558, in tanh
Mar 03 23:09:46     return input.tanh()
Mar 03 23:09:46 AttributeError: 'TensorLike' object has no attribute 'tanh'

https://app.circleci.com/jobs/github/pytorch/pytorch/4678835
https://app.circleci.com/jobs/github/pytorch/pytorch/4678098

@ngoldbaum
Copy link
Contributor Author

@mrshenli the fix for that is to add __torch_function__ handling for those two functions. Would you like me to put in a fix?

@mrshenli
Copy link
Contributor

mrshenli commented Mar 4, 2020

Hey @ngoldbaum, thanks a lot for prompt reply.

@mrshenli the fix for that is to add torch_function handling for those two functions. Would you like me to put in a fix?

That will be great! Please send the fix and I will stamp. Thanks!!

@mrshenli
Copy link
Contributor

mrshenli commented Mar 4, 2020

@ngoldbaum BTW, do you know why it is flaky instead of fail consistently. Below is a passing run:

https://app.circleci.com/jobs/github/pytorch/pytorch/4685311

@ngoldbaum
Copy link
Contributor Author

I'm not sure but if I had to guess it's because of dict order randomization on Python3.5 and something else is clobbering those entries. I'm going to see if I can figure out exactly why right now...

@mrshenli
Copy link
Contributor

mrshenli commented Mar 4, 2020

Hey @ngoldbaum, since we need to wait for 5-6 hours for the CI to pass before we can land the fix, do you mind if I revert this PR for now, and you can send in the resubmit together with the fix?

@ngoldbaum
Copy link
Contributor Author

Go ahead, sorry for the trouble!

@mrshenli
Copy link
Contributor

mrshenli commented Mar 4, 2020

Thanks! No worries :)

ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
…rch#33791)

Summary:
Fixes pytorch#33182

This adds private API functions that developers of types that implement `__torch_function__` can use to ensure full coverage of the subset of the PyTorch API that can be overrided.

I've refactored some of the code in the tests into a new `torch._overrides.get_overridable_functions` function. I've also changed `TENSOR_LIKE_TORCH_OVERRIDES` into `torch._overrides.get_testing_overrides` and `IGNORED_TORCH_FUNCTIONS` into `torch._overrides.get_ignored_functions`. Making these two static global variables in the tests into functions should allow rewriting their implementation to construct their return values instead of just statically defining the return value as is done here. Currently that is blocked on not being able to inspect function signatures of compiled kernels in PyTorch (see pytorch#28233). See the docs I've added for usage examples of these new functions. I also refactored the existing override tests to make use of these new functions, which should be a good forcing function to make sure they're kept up-to-date.

Finally, while working on this I discovered that `TestTorchFunctionOverrides.test_mean` and `TestTorchFunctionOverrides.test_mm` weren't ever being run because they were getting clobbered by the other dynamically generated override tests. I fixed that by renaming the tests and then fixing the actual test code. I've verified that all the subclassing semantics is correct and that the updated test answers are correct. I'm happy to put the fixes to the existing tests in as a separate pull request if that would be easier to review.

ping cpuhrsch since the feature request originally came from them.
Pull Request resolved: pytorch#33791

Differential Revision: D20195053

Pulled By: cpuhrsch

fbshipit-source-id: 1585f4e405f5223932b410eae03a288dc8eb627e
facebook-github-bot pushed a commit that referenced this pull request Mar 12, 2020
…n__" (#34240)

Summary:
This is a redo of #33791, which was reverted because it introduced a flaky test. The test was flaky and only flaky on Python3.5 because of dict order randomization.

I've fixed the issue with tests clobbering each other in b539fec and removed the override tests for `torch.nn.functional.tanh` and `torch.nn.functional.sigmoid`, which are deprecated and shouldn't be overridable in e0d7402. I also verified that no more test clobbering is happening.
Pull Request resolved: #34240

Differential Revision: D20252442

Pulled By: cpuhrsch

fbshipit-source-id: 069568e342a41c90e1dc76cbf85ba4aed47f24be
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Merged open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a private API that lists overridable functions
7 participants