docs/source/notes/extending.rst

Some important implications of this implementation are:
- Our native functions are lazily populated as ``torch.ops.{namespace}.{func_name}.{overload_name}`` as callable Python objects to enable easily interacting with them from Python. The ``func`` object given to ``__torch_dispatch__`` is always an entry from this namespace. This namespace can be used to directly call native ops and bypass the usual Python API and binding code.

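For example, the native ``add`` op can be called directly through this namespace. A minimal sketch (``aten.add.Tensor`` is the Tensor-Tensor overload of ``add``)::

    import torch

    a = torch.ones(2)
    b = torch.full((2,), 2.0)

    # The regular Python API.
    out1 = torch.add(a, b)
    # The same native op, called directly from the torch.ops namespace:
    # "aten" is the namespace, "add" the op name and "Tensor" the overload name.
    out2 = torch.ops.aten.add.Tensor(a, b)

    assert torch.equal(out1, out2)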

In a similar way to how ``__torch_function__`` is able to interpose on all of torch's Python API and Tensor methods, ``__torch_dispatch__`` is able to intercept all calls into the aten native API. Note that all methods on Tensors are converted into function calls before entering the dispatcher and thus will appear as function calls here: ``torch.add(a, 2)`` and ``a + 2`` will lead to exactly the same aten call, as the short check after the logging example below illustrates.
Most of these functions are defined in ``native_functions.yaml``, which specifies their properties as well as their backend implementations. Their implementations, alongside the specified features, are then automatically registered via codegen.
Some more exotic functions or features are also registered in other places in the C++ codebase or in user-defined C++ extensions.

You can find many examples of ``__torch_dispatch__``-based subclasses in the `subclass zoo <https://github.com/albanD/subclass_zoo>`_ repository.
Extending all :mod:`torch` API with Modes
-----------------------------------------

Unfortunately, there are functions that do not take any Tensor inputs. This means that the subclass approach described above cannot be used to override the behavior of all of PyTorch's functions. Also, if the use case requires intercepting every function call, changing every Tensor to be a subclass can be overly intrusive.

To address this use case, we introduced the concept of a "Mode". Modes exist for both ``__torch_function__`` and ``__torch_dispatch__`` overrides, are created by subclassing :class:`torch.overrides.TorchFunctionMode` and :class:`torch.utils._python_dispatch.TorchDispatchMode` respectively, and are used as context managers.

To simplify the description of how a mode interacts with subclasses and other modes: whenever the context manager for a mode is entered, every function behaves as if there were an extra Tensor argument at the beginning of the argument list with the mode as its subclass.
This means in particular that all mode handlers will be called before any subclass handler and that the mode corresponding to the innermost context manager will always run first.

It is also important to note that within the handler for a given mode, that specific mode is disabled; it can be re-enabled manually with ``with self:``.
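
For example, here is a small sketch illustrating both points (the ``Named`` mode below is purely illustrative): with two nested modes, the innermost one runs first, and because a mode is disabled inside its own handler, redispatching from there goes to the next mode rather than looping::

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class Named(TorchDispatchMode):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(f"{self.name} sees {func}")
            # This mode is disabled inside its own handler, so the call below
            # reaches the next mode on the stack (or the regular kernel).
            return func(*args, **(kwargs or {}))

    x = torch.ones(2)
    with Named("outer"), Named("inner"):
        x + 1
    # Prints:
    # inner sees aten.add.Tensor
    # outer sees aten.add.Tensor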

Here is a more complete example that defines a logging mode of each type::

    import torch
    from torch.overrides import TorchFunctionMode, resolve_name
    from torch.utils._python_dispatch import TorchDispatchMode

    class FunctionLog(TorchFunctionMode):
        def __torch_function__(self, func, types, args, kwargs=None):
            print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
            return func(*args, **kwargs or {})

    class DispatchLog(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args, kwargs=None):
            print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
            return func(*args, **kwargs or {})

    def f():
        a = torch.rand(10, requires_grad=True)
        b = a * 2
        b.sum().backward()

    print("TorchFunctionMode logging:")
    with FunctionLog():
        f()

    print("TorchDispatchMode logging:")
    with DispatchLog():
        f()

This prints the following output, annotated with extra comments::

    TorchFunctionMode logging:
    Function Log: torch.rand(*(10,), **{'requires_grad': True})
    Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
            0.5970], requires_grad=True), 2), **None)
    Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
            1.1939], grad_fn=<MulBackward0>),), **None)
    # Note that at the Python level, we only see the call to backward but not what happens in the autograd engine.
    Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})

    TorchDispatchMode logging:
    # Here the requires_grad flag from autograd is removed while default arguments were populated.
    Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
    Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
            0.6023], requires_grad=True), 2), **{})
    Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
            1.2046], grad_fn=<MulBackward0>),), **{})
    # Here we don't see the call to backward itself, but only its constituents, starting with the factory function that creates the initial gradient.
    Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
    # This is the backward of the sum
    Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
    Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
    Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
    Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
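
As a quick check of the earlier claim that Tensor methods are converted into plain function calls before reaching the dispatcher, the ``DispatchLog`` mode from the example above can be reused to confirm that ``torch.add(a, 2)`` and ``a + 2`` result in the same aten call (a small sketch; the exact op shown is the one reported on recent PyTorch versions)::

    a = torch.rand(3)
    with DispatchLog():
        torch.add(a, 2)
        a + 2
    # Both statements log the same op, e.g.:
    # Dispatch Log: aten.add.Tensor(*(tensor([...]), 2), **{})
    # Dispatch Log: aten.add.Tensor(*(tensor([...]), 2), **{})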

Contributor: I think you can do more here to explain when someone would use torch function vs torch dispatch.

Collaborator (Author): The two previous sections should already describe these in detail. And I do plan to add an intro with the big "when to use what" diagram. I will do that in a follow-up!

Contributor: ah yes


Writing custom C++ extensions
-----------------------------