-
Couldn't load subscription status.
- Fork 25.7k
Add the Mode section in the extending doc #110073
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -890,7 +890,7 @@ Some important implications of this implementation are: | |
| - Our native functions are lazily populated as ``torch.ops.{namespace}.{func_name}.{overload_name}`` as callable Python objects to enable easily interacting with them from Python. The ``func`` object given to ``__torch_dispatch__`` is always an entry from this namespace. This namespace can be used to directly call native ops and bypass the usual Python API and binding code. | ||
|
|
||
|
|
||
| In a similar way where ``__torch_function__`` is able to interpose on all of torch's Python API and Tensor methods, ``__torch_dispatch__`` is able to intercept all calls into the aten native API. Note that all methods on Tensors are converted into function calls before entering the dispatcher and thus will appear as function calls here: ``torch.add(a, 2)`` and ``a + 2`` will lead to exactly the same aten call. | ||
| In a similar way where ``__torch_function__`` is able to interpose on all of torch's Python API and Tensor methods, ``__torch_dispatch__`` is able intercepting all calls into the aten native API. Note that all methods on Tensors are converted into function calls before entering the dispatcher and thus will appear as function calls here: ``torch.add(a, 2)`` and ``a + 2`` will lead to exactly the same aten call. | ||
| Most of these functions are defined in ``native_functions.yaml`` which specifies the properties of these functions as well as their backend implementation. Their implementation alongside specified features are then automatically registered via codegen. | ||
| Some more exotic functions or features are also registered in other places in the C++ codebase or in user-defined C++ extensions. | ||
|
|
||
|
|
@@ -903,11 +903,70 @@ You can find many examples of ``__torch_dispatch__``-based subclasses in the [su | |
| Extending all :mod:`torch` API with Modes | ||
| ----------------------------------------- | ||
|
|
||
| TODO Q: what about functions that don't take tensor inputs? | ||
| Unfortunately, there are functions that do not take Tensor inputs. This means that the subclass approach described above cannot be used to override the behavior of all of PyTorch's functions. Also, if the use case requires to intercept every function call, changing every Tensor to be a subclass can be overly intrusive. | ||
|
|
||
| To address this use case, we introduced the concept of "Mode". These exist for ``__torch_function__`` and ``__torch_dispatch__`` overrides, are created by subclassing respectively :class:`torch.overrides.TorchFunctionMode` and :class:`torch.utils._python_dispatch.TorchDispatchMode`, and are used as a context manager. | ||
|
|
||
| To simplify the description of how it interacts with subclasses and other modes, whenever the context manager for a mode is entered, every function behaves as if there was an extra Tensor argument at the beginning of the argument list with the mode as a subclass. | ||
| This means in particular that all modes handlers will be called before any subclass handler and that modes corresponding to the inner context manager will always run first. | ||
|
|
||
| It is also important to note that within a given mode handler, this specific mode is disabled and can be re-enabled manually by doing ``with self:``. | ||
|
|
||
| Here is an example that shows logging modes of each type:: | ||
|
|
||
| import torch | ||
| from torch.overrides import TorchFunctionMode, resolve_name | ||
| from torch.utils._python_dispatch import TorchDispatchMode | ||
|
|
||
| class FunctionLog(TorchFunctionMode): | ||
| def __torch_function__(self, func, types, args, kwargs=None): | ||
| print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})") | ||
| return func(*args, **kwargs or {}) | ||
|
|
||
| class DispatchLog(TorchDispatchMode): | ||
| def __torch_dispatch__(self, func, types, args, kwargs=None): | ||
| print(f"Dispatch Log: {func}(*{args}, **{kwargs})") | ||
| return func(*args, **kwargs or {}) | ||
|
|
||
| def f(): | ||
| a = torch.rand(10, requires_grad=True) | ||
| b = a * 2 | ||
| b.sum().backward() | ||
|
|
||
| print("TorchFunctionMode logging:") | ||
| with FunctionLog(): | ||
| f() | ||
|
|
||
| print("TorchDispatchMode logging:") | ||
| with DispatchLog(): | ||
| f() | ||
|
|
||
| Which prints the following, with extra comments:: | ||
|
|
||
| TorchFunctionMode logging: | ||
| Function Log: torch.rand(*(10,), **{'requires_grad': True}) | ||
| Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624, | ||
| 0.5970], requires_grad=True), 2), **None) | ||
| Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249, | ||
| 1.1939], grad_fn=<MulBackward0>),), **None) | ||
| # Note that at the python level, we only see the call to backward but not what happens in the autograd engine. | ||
| Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None}) | ||
|
|
||
| TorchDispatchMode logging: | ||
| # Here the requires_grad flag from autograd is removed while default arguments were populated. | ||
| Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False}) | ||
| Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948, | ||
| 0.6023], requires_grad=True), 2), **{}) | ||
| Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897, | ||
| 1.2046], grad_fn=<MulBackward0>),), **{}) | ||
| # Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient. | ||
| Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format}) | ||
| # This is the backward of the sum | ||
| Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{}) | ||
| Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{}) | ||
| Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{}) | ||
| Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{}) | ||
|
|
||
| TODO Intro to the concept of mode | ||
|
||
|
|
||
| TODO Example of logging mode | ||
|
|
||
| Writing custom C++ extensions | ||
| ----------------------------- | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.