__torch_function__ mode #75154
Conversation
Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
cc'ing @mruberry as you may find the API useful
```cpp
@@ -278,6 +278,12 @@ auto handle_torch_function_no_python_arg_parser(
  }
  if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
    for (auto& arg : overloaded_args) {
      // See https://github.com/pytorch/pytorch/issues/63767
      if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(arg)) {
```
This uses `torch_function`, so it needs to be below where that variable is set.
`__torch_function__` mode brings `enable_python_mode` functionality to the `__torch_function__` level of overrides (at the Python API level). I intend to use this to collect test data for validating abstract interpreters we plan to write for operators, but there are many other possible applications: for example, tracers like FX could use this to track factory function calls without monkey-patching the PyTorch namespace, and the mode could be used to vary backward semantics of functions without having to add new arguments to them (see https://github.com/albanD/subclass_zoo/blob/main/inner_autograd_tensor.py).

This PR also gives a design for mode stacks which is implemented in Python (rather than in some sort of C++ infrastructure). My goal here was to trade off performance so that I could simplify the C++ bindings (we only expose `_set_torch_function_mode` and `_get_torch_function_mode` in the bindings) and iterate quickly on a more sophisticated mode stack design. The design here is oriented around TorchFunctionMode objects, which are chained together via object composition. If this design is accepted, we should also apply it to the existing `__torch_dispatch__` mode.

**What is `__torch_function__` mode?** Check the documentation for TorchFunctionMode in torch/overrides.py.

**How does it work?** There are two parts to the implementation: a simple but non-compositional C++ API, and a compositional Python API built on top of the C++ API.

In C++:

* We add a new Python mode object to PythonTorchFunctionTLS; the object doesn't have any type constraints, but we will assume that it has a `__torch_function__` method defined on it. (Unlike `python_eager_mode`, this object need not be a class; it can be an honest-to-goodness instance of a class.)
* We modify `has_torch_function` to return true if the mode is non-null. In some cases this is wrong (specifically when we are accumulating overloaded arguments), so those sites get a special argument to say "ignore modes when testing."
I didn't do a very careful audit of all the sites, so this is probably where we are most likely to have bugs in the implementation.

* When handling `__torch_function__`, we check if this mode is set. If it is, we first call `__torch_function__` on the mode to attempt to handle the operation (that is to say, we process modes first, and then concrete values). Unusually, we also reset the mode to be empty before calling back into `__torch_function__` (see #75166 for the back story). If the mode returns NotImplemented, we attempt to handle the operator with the subclasses instead. I also fixed a bug in the error reporting where pointer values rather than actual reprs were printed.

In Python:

* We add a new class hierarchy, TorchFunctionMode, representing objects that can handle `__torch_function__` and can chain with each other (allowing for nested modes). These classes have a metaclass, TorchFunctionModeMeta, which is (1) responsible for adding chaining functionality without requiring users to manually record `inner` themselves, and (2) responsible for resetting the current mode (which is null at the time `__torch_function__` is called) to the mode of the *inner* mode class. We represent `__torch_function__` handlers as objects, not classes, so that a single handler can be associated with arbitrary state (most notably the `inner` property; but you can also define parametrized modes--see the example in the test suite).
* We then provide two user-visible functions for working with the mode stack: `push_torch_function_mode`, the recommended function, which is compositional (allowing for nested modes), and `enable_torch_function_mode`, which is non-compositional (good for simple use cases, or swapping out a known mode for another mode). Check the docblocks in the PR for more details.
* For symmetry with `__torch_dispatch__` mode, we also accept a Tensor subclass as the mode argument; this works in much the same way a non-compositional TorchFunctionMode would be applied.
**Were there alternatives?**

* `push_torch_function_mode` could have been implemented as a C++-level concept, representing the stack explicitly in C++. This would have involved writing quite a lot of complex C++ binding code, and when I determined that I could implement all of this in Python using only two simple bindings, I opted for this strategy for development speed. It may be useful to do it this way in the future for performance reasons.
* This PR introduces a new concept, TorchFunctionMode, whereas the existing `__torch_dispatch__` mode reused the class object for specifying the mode. I think introducing a new concept here is warranted for the following reasons: (1) `enable_python_mode` awkwardly requires you to subclass Tensor to write your mode, even if you never intend to instantiate your tensor; (2) there is no way to add extra metadata to a class without dynamically generating a class on the fly, and this metadata is required for compositional use cases (aka `inner`, a la functorch); (3) even if you were able to dynamically generate the class, working up and sideways the class hierarchy with super() and the current class is awkward, and people are unlikely to do it correctly. This is a case of favoring composition over inheritance, giving us an easier-to-understand API (even if it is not parsimonious).

**Examples?** Check the test cases in TestTorchFunctionMode in test/test_overrides.py.

Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
@pytorchbot merge this
Hey @ezyang.
Summary: Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: #75154 Approved by: https://github.com/albanD, https://github.com/zou3519 Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/31c86625cce02bbb37b03781d21dca7241bcad98 Reviewed By: b0noI Differential Revision: D35469410 Pulled By: ezyang fbshipit-source-id: a3c8e9191a2e4145b1914cbe404ff8d2b1bd0316
`enable_python_mode`/`enable_torch_function_mode` (torch_function_mode introduced [here](#75154)) are the simplest versions of the modes: they don't support nesting, but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR). This makes their behaviors match, which means:

- the mode is disabled during the call to torch_dispatch; this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion, and will be necessary for nested modes
- nested modes are "allowed" (they don't break) if the mode passed to the instance is the same (it's treated as a no-op in this case); otherwise, it will error like before
- adds some `kwargs` that let a user override the current mode with a new one (not the same as nesting, because it ignores the outer context manager)
- small string differences so the error messages are the same between the two

[ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now-updated `enable_python_mode`. Users will only be able to nest modes using push, not enable.

**Longer note on what it means to update enable_python_mode**

`enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes: they don't support nesting, but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR). This makes their behaviors match, which means:

- the mode is disabled during the call to torch_dispatch; this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion, and will be necessary for nested modes
- nested modes are "allowed" (they don't break) if the mode passed to the instance is the same (it's treated as a no-op in this case); otherwise, it will error like before
- adds some `kwargs` that let a user override the current mode with a new one (not the same as nesting, because it ignores the outer context manager)
- small string differences so the error messages are the same between the two

[ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode), (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable. **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
This adds the ability to nest python_modes with `push_python_mode`, based on the torch_function design introduced [here](#75154). In order to push python modes, we needed to (1) introduce the concept of PythonMode (a la TorchFunctionMode) that defines a torch_dispatch but isn't a tensor subclass, (2) update enable_python_mode to match torch function mode (longer explanation below), and (3) introduce `push_python_mode` alongside the now updated `enable_python_mode`. Users will only be able to nest modes using push but not enable **Longer note on what it means to update enable_python_mode** `enable_python_mode`/`enable_torch_function_mode` are the simplest versions of the modes that don't support nesting but do support passing a tensor subclass as the mode. The current implementation of `enable_python_mode` didn't match the functionality of `enable_torch_function_mode` (which limited the ability to add nesting in the form of `push_python_mode` in the next PR) This makes their behaviors match which means: - the mode is disabled during the call to torch_dispatch, this removes the `no_dispatch` calls that needed to be everywhere in order to prevent infinite recursion and will be necessary for nested modes - nested modes are "allowed" (doesn't break) if the mode passed to the instance is the same (it's treated as a noop in this case). Otherwise, it will error like before - Adds some `kwargs` that let a user override a current mode to a new one (not the same as nesting because it ignores the outer context manager) - small string differences so the error messages are the same between the two [ghstack-poisoned]
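The nesting semantics described above (re-entering with the same mode is a no-op; the mode is unset while its handler runs) can be illustrated with a small self-contained sketch. This is a toy stand-in, not PyTorch code: `enable_mode`, `dispatch`, and `LoggingMode` are invented names.

```python
import contextlib
import threading

# Toy sketch (not PyTorch's implementation) of the semantics described above:
# (1) re-entering with the *same* mode is a no-op, a different mode errors;
# (2) the mode is unset while its handler runs, preventing infinite recursion.
_state = threading.local()

@contextlib.contextmanager
def enable_mode(mode):
    current = getattr(_state, "mode", None)
    if current is not None:
        if current is mode:
            yield              # same mode nested again: treat as a no-op
            return
        raise RuntimeError("a different mode is already enabled")
    _state.mode = mode
    try:
        yield
    finally:
        _state.mode = None

def dispatch(func, *args):
    mode = getattr(_state, "mode", None)
    if mode is None:
        return func(*args)
    _state.mode = None         # disable the mode during its own handler
    try:
        return mode.handle(func, *args)
    finally:
        _state.mode = mode

class LoggingMode:
    def __init__(self):
        self.log = []
    def handle(self, func, *args):
        self.log.append(func.__name__)
        return dispatch(func, *args)   # mode is off here, so no recursion

mode = LoggingMode()
with enable_mode(mode):
    with enable_mode(mode):    # nesting the same mode: allowed, no-op
        result = dispatch(abs, -3)
print(result, mode.log)        # 3 ['abs']
```

Without the disable-during-handler step, `LoggingMode.handle` calling back into `dispatch` would recurse forever, which is exactly why the `no_dispatch` calls could be removed.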
Stack from ghstack (oldest at bottom):
`__torch_function__` mode brings `enable_python_mode` functionality to the `__torch_function__` level of overrides (at the Python API level). I intend to use this to collect test data for validating abstract interpreters we plan to write for operators, but there are many other possible applications: for example, tracers like FX could use this to track factory function calls without monkey-patching the PyTorch namespace, and the mode could be used to vary backward semantics of functions without having to add new arguments to them (see https://github.com/albanD/subclass_zoo/blob/main/inner_autograd_tensor.py).

This PR also gives a design for mode stacks which is implemented in Python (rather than in some sort of C++ infrastructure). My goal here was to trade off performance so that I could simplify the C++ bindings (we only expose `_set_torch_function_mode` and `_get_torch_function_mode` in the bindings) and iterate quickly on a more sophisticated mode stack design. The design here is oriented around TorchFunctionMode objects, which are chained together via object composition. If this design is accepted, we should also apply it to the existing `__torch_dispatch__` mode.

**What is `__torch_function__` mode?** Check the documentation for TorchFunctionMode in torch/overrides.py.

**How does it work?** There are two parts to the implementation: a simple but non-compositional C++ API, and a compositional Python API built on top of the C++ API.
In C++:

* We add a new Python mode object to PythonTorchFunctionTLS; the object doesn't have any type constraints, but we will assume that it has a `__torch_function__` method defined on it. (Unlike `python_eager_mode`, this object need not be a class; it can be an honest-to-goodness instance of a class.)
* We modify `has_torch_function` to return true if the mode is non-null. In some cases this is wrong (specifically when we are accumulating overloaded arguments), so those sites get a special argument to say "ignore modes when testing." I didn't do a very careful audit of all the sites, so this is probably where we are most likely to have bugs in the implementation.
* When handling `__torch_function__`, we check if this mode is set. If it is, we first call `__torch_function__` on the mode to attempt to handle the operation (that is to say, we process modes first, and then concrete values). Unusually, we also reset the mode to be empty before calling back into `__torch_function__` (see #75166, "Disable Python mode (torch dispatch mode) inside of mode-induced `__torch_dispatch__` call", for the back story). If the mode returns NotImplemented, we attempt to handle the operator with the subclasses instead. I fixed a bug in the error reporting where pointer values rather than actual reprs were printed.

In Python:
* We provide TorchFunctionMode classes which implement `__torch_function__` and can chain with each other (allowing for nested modes). These classes have a metaclass, TorchFunctionModeMeta, which is (1) responsible for adding chaining functionality without requiring users to manually record `inner` themselves, and (2) responsible for resetting the current mode (which is null at the time `__torch_function__` is called) to the mode of the inner mode class. We represent `__torch_function__` handlers as objects, not classes, so that a single handler can be associated with arbitrary state (most notably the `inner` property; but you can also define parametrized modes--see the example in the test suite).
* `push_torch_function_mode` is the recommended function and is compositional (allowing for nested modes); we also provide `enable_torch_function_mode`, which is non-compositional (good for simple use cases, or for swapping out a known mode with another mode). Check the docblocks in the PR for more details.
* As with `__torch_dispatch__` mode, we also accept a Tensor subclass passed as the mode; this works in much the same way a non-compositional TorchFunctionMode would be applied.

**Were there alternatives?**
* `push_torch_function_mode` could have been implemented as a C++ level concept, representing the stack explicitly in C++. This would have involved writing quite a lot of complex C++ binding code, and when I determined that I could implement all of this in Python using only two simple bindings, I opted for this strategy for development speed. It may be useful to do it this way in the future for performance reasons.
* The old `__torch_dispatch__` mode reused the class object for specifying the mode. I think introducing a new concept here is warranted for the following reasons: (1) `enable_python_mode` awkwardly requires you to subclass Tensor to write your mode, even if you never intend to instantiate your tensor; (2) there is no way to add extra metadata to a class without dynamically generating a class on the fly, and this metadata is required for compositional use cases (i.e., `inner`, a la functorch); (3) even if you were able to dynamically generate the class, working up and sideways the class hierarchy with super() and CurrentClass is awkward, and people are unlikely to do it correctly. This is a case of favoring composition over inheritance, giving us an easier-to-understand API (even if it is not parsimonious).

**Examples?** Check the test cases in TestTorchFunctionMode in test/test_overrides.py.
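The composition-over-inheritance point can be made concrete with a toy sketch: modes chain through an `inner` attribute held on instances, rather than through the class hierarchy and super(). `Tag` and `Base` are invented names for illustration; a `Tag` carries per-instance state (its label), i.e. a parametrized mode.

```python
# Toy sketch of composition over inheritance: modes chain through an `inner`
# attribute instead of the class hierarchy. `Tag` and `Base` are invented
# names; a Tag carries per-instance state (its label) -- a parametrized mode.
class Base:
    def __torch_function__(self, func, args):
        return func(*args)         # innermost: run the real function

class Tag:
    def __init__(self, label, inner):
        self.label = label         # arbitrary per-mode state
        self.inner = inner         # next mode in the chain
    def __torch_function__(self, func, args):
        result = self.inner.__torch_function__(func, args)
        return f"{self.label}({result})"

# The outermost mode runs first and delegates inward through `inner`.
stack = Tag("outer", Tag("inner", Base()))
print(stack.__torch_function__(max, (1, 2)))   # outer(inner(2))
```

Because chaining is plain attribute access on objects, no dynamic class generation or super() gymnastics are needed to compose modes.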
Signed-off-by: Edward Z. Yang ezyang@fb.com