
Have the ability to disable __torch_function__ dispatch for torch.nn.functional functions #55440

Open
jamesr66a opened this issue Apr 7, 2021 · 1 comment
Labels: module: nn, module: __torch_function__, triaged

Comments

jamesr66a (Collaborator) commented Apr 7, 2021

torch.nn.Module instances and torch.* namespace ops represent a pretty good split between high-level (stateful) blocks and low-level (stateless) operations. torch.nn.functional calls seem to just get in the way in the vast majority of cases that I've seen. The fact that these participate in __torch_function__ creates a weird middle layer of abstraction that cannot be easily skipped. I can't even figure out how to skip them manually in a __torch_function__ handler: we can't redispatch into the functional with the layer of tensor-like objects stripped off, because then we would also lose the calls inside the functional, when what we actually want is to intercept those inner calls but not the functional itself. Can we have some way to turn off dispatching to __torch_function__ for torch.nn.functional calls?
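
To make the redispatch problem concrete, here is a minimal sketch (the Wrapper class and its names are hypothetical, purely for illustration, not an API proposal): if the handler strips the tensor-like layer and redispatches into the functional, the ops inside the functional then run on plain tensors and are never intercepted.

    import torch
    import torch.nn.functional as F

    class Wrapper:
        """Hypothetical tensor-like object participating in __torch_function__."""
        def __init__(self, elem):
            self.elem = elem  # plain torch.Tensor payload

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"intercepted: {func.__module__}.{func.__name__}")
            # Strip the wrapper layer and redispatch on plain tensors.  Because
            # the payloads are plain tensors, the calls *inside* the functional
            # (e.g. torch.relu) no longer hit __torch_function__ at all.
            unwrap = lambda a: a.elem if isinstance(a, Wrapper) else a
            return func(*[unwrap(a) for a in args],
                        **{k: unwrap(v) for k, v in kwargs.items()})

    F.relu(Wrapper(torch.randn(3)))
    # prints only: intercepted: torch.nn.functional.relu
    # the low-level op inside F.relu is never seen by the handler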

(Additionally, there's a separate issue where torch.nn.functional calls have this icky TorchScript boolean_dispatch thing.)

cc @hameerabbasi @rgommers @peterbell10 @albanD @mruberry @jbschlosser

jamesr66a (Collaborator, Author) commented

Note: I have been able to hack something up by monkey-patching has_torch_function (and its variants) in torch.nn.functional:

    to_patch = ['has_torch_function', 'has_torch_function_unary', 'has_torch_function_variadic']
    saved = {}
    try:
        def no(*args, **kwargs):
            return False
        # Save the originals, then replace them with a stub that always reports
        # that no __torch_function__ overrides are present.
        for name in to_patch:
            saved[name] = getattr(torch.nn.functional, name)
            setattr(torch.nn.functional, name, no)

        <snip>
    finally:
        # Restore the original dispatch helpers.
        for name in to_patch:
            setattr(torch.nn.functional, name, saved[name])

but this is bad, to say the least.
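
A slightly less fragile way to package the same hack is to scope it with a context manager so the originals are always restored; a sketch along those lines (the helper name here is made up):

    import contextlib
    import torch
    import torch.nn.functional

    @contextlib.contextmanager
    def nn_functional_torch_function_disabled():
        # Same monkey-patch as above, just scoped to a `with` block.
        to_patch = ['has_torch_function', 'has_torch_function_unary',
                    'has_torch_function_variadic']
        saved = {name: getattr(torch.nn.functional, name) for name in to_patch}
        try:
            for name in to_patch:
                setattr(torch.nn.functional, name, lambda *args, **kwargs: False)
            yield
        finally:
            for name in to_patch:
                setattr(torch.nn.functional, name, saved[name])

    # usage:
    # with nn_functional_torch_function_disabled():
    #     out = torch.nn.functional.relu(x)  # skips the __torch_function__ checks in F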

mrshenli added the module: nn and triaged labels on Apr 7, 2021