Add documentation for torch.overrides submodule. (#48170)
Summary:
Fixes #48087

Pull Request resolved: #48170

Reviewed By: ejguan

Differential Revision: D25220942

Pulled By: ezyang

fbshipit-source-id: a2b7f7b565f5e77173d8ce2fe9676a8131f929b6
hameerabbasi authored and facebook-github-bot committed Nov 30, 2020
1 parent 42e7cdc commit 4e15877
Showing 5 changed files with 192 additions and 35 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -64,6 +64,7 @@ Features described in this documentation are classified by release status:
torch.hub <hub>
torch.jit <jit>
torch.linalg <linalg>
torch.overrides
nn.init
onnx
optim
6 changes: 4 additions & 2 deletions docs/source/notes/extending.rst
Expand Up @@ -247,6 +247,8 @@ This is how a ``Linear`` module can be implemented::
self.input_features, self.output_features, self.bias is not None
)

.. _extending-torch:

Extending :mod:`torch`
----------------------

@@ -605,7 +607,7 @@ provides a developer-facing API for ensuring full support for
changes without warning in the future.

First, to get a listing of all overridable functions, use
``torch.overrides._get_overridable_functions``. This returns a dictionary whose
``torch.overrides.get_overridable_functions``. This returns a dictionary whose
keys are namespaces in the ``PyTorch`` Python API and whose values are a list of
functions in that namespace that can be overridden. For example, let's print the
names of the first 5 functions in ``torch.nn.functional`` that can be
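
(The collapsed diff lines here contain the example itself; a minimal sketch of what it shows, with the exact function names depending on the PyTorch build, might be:)

>>> import torch
>>> funcs = torch.overrides.get_overridable_functions()[torch.nn.functional]
>>> print([f.__name__ for f in funcs[:5]])  # first five overridable names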
@@ -622,7 +624,7 @@ This listing of functions makes it possible to iterate over all overridable
functions, however in practice this is not enough to write tests for all of
these functions without laboriously and manually copying the signature of each
function for each test. To ease this process, the
``torch.overrides._get_testing_overrides`` function returns a dictionary mapping
``torch.overrides.get_testing_overrides`` function returns a dictionary mapping
overridable functions in the ``PyTorch`` API to dummy lambda functions that have
the same signature as the original function but unconditionally return -1. These
functions are most useful to use with ``inspect`` to analyze the function
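
(A minimal sketch of the workflow this section describes, relying on the documented behavior that each dummy override has the real signature and unconditionally returns -1:)

>>> import inspect, torch
>>> dummy_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>
>>> dummy_add(1, 2)
-1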
27 changes: 27 additions & 0 deletions docs/source/torch.overrides.rst
@@ -0,0 +1,27 @@
.. currentmodule:: torch.overrides

torch.overrides
---------------

This module exposes various helper functions for the ``__torch_function__``
protocol. See :ref:`extending-torch` for more detail on the
``__torch_function__`` protocol.

Functions
~~~~~~~~~

.. autofunction:: get_ignored_functions

.. autofunction:: get_overridable_functions

.. autofunction:: get_testing_overrides

.. autofunction:: handle_torch_function

.. autofunction:: has_torch_function

.. autofunction:: is_tensor_like

.. autofunction:: is_tensor_method_or_property

.. autofunction:: wrap_torch_function
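
(A minimal sketch exercising the helpers documented above, relying only on the behaviors their docstrings describe:)

>>> import torch
>>> torch.add in torch.overrides.get_overridable_functions()[torch]
True
>>> torch.overrides.is_tensor_method_or_property(torch.Tensor.add)
True
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True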
18 changes: 17 additions & 1 deletion test/test_overrides.py
@@ -822,7 +822,7 @@ def test_gradcheck(self):
})

class TestNamedTuple(TestCase):
"Regression test for gh-47090"
""" Regression test for gh-47090 """
def test_max(self):
x = torch.tensor([1, 2])
xs = x.as_subclass(SubTensor2)
@@ -838,5 +838,21 @@ def test_newones(self):
n = t.new_ones((1, 2))
self.assertEqual(type(n), SubTensor2)

class TestWrapTorchFunction(TestCase):
    def test_wrap_torch_function(self):
        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs):
                return -1

        def dispatcher(a):
            return (a,)

        @torch.overrides.wrap_torch_function(dispatcher)
        def f(a):
            return a

        self.assertEqual(f(A()), -1)

if __name__ == '__main__':
unittest.main()
175 changes: 143 additions & 32 deletions torch/overrides.py
@@ -1,14 +1,16 @@
"""
Python implementation of __torch_function__
Python implementation of ``__torch_function__``
While most of the torch API and handling for __torch_function__ happens
While most of the torch API and handling for ``__torch_function__`` happens
at the C++ level, some of the torch API is written in Python so we need
python-level handling for __torch_function__ overrides as well. The main
python-level handling for ``__torch_function__`` overrides as well. The main
developer-facing functionality in this file are handle_torch_function and
has_torch_function. See torch/functional.py and test/test_overrides.py
for usage examples.
NOTE: heavily inspired by NumPy's ``__array_function__`` (see:
Note
----
heavily inspired by NumPy's ``__array_function__`` (see:
https://github.com/pytorch/pytorch/issues/24015 and
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
)
@@ -28,15 +30,35 @@
import torch
from torch._C import _is_torch_function_enabled, _disabled_torch_function_impl

__all__ = [
"get_ignored_functions",
"get_overridable_functions",
"get_testing_overrides",
"handle_torch_function",
"has_torch_function",
"is_tensor_like",
"is_tensor_method_or_property",
"wrap_torch_function",
]

@functools.lru_cache(None)
def get_ignored_functions() -> Set[Callable]:
"""Return public functions that cannot be overridden by __torch_function__
"""
Return public functions that cannot be overridden by ``__torch_function__``.
Returns
-------
A tuple of functions that are publicly available in the torch API but cannot
be overridden with __torch_function__. Mostly this is because none of the
arguments of these functions are tensors or tensor-likes.
Tuple[Callable]
A tuple of functions that are publicly available in the torch API but cannot
be overridden with ``__torch_function__``. Mostly this is because none of the
arguments of these functions are tensors or tensor-likes.
Examples
--------
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
"""
Tensor = torch.Tensor
return {
@@ -188,12 +210,20 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Returns
-------
A dictionary that maps overridable functions in the PyTorch API to
lambda functions that have the same signature as the real function
and unconditionally return -1. These lambda functions are useful
for testing API coverage for a type that defines __torch_function__.
Dict[Callable, Callable]
A dictionary that maps overridable functions in the PyTorch API to
lambda functions that have the same signature as the real function
and unconditionally return -1. These lambda functions are useful
for testing API coverage for a type that defines ``__torch_function__``.
Examples
--------
>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
"""
# Every function in the PyTorch API that can be overridden needs an entry
# in this dict.
#
# Optimally we would use inspect to get the function signature and define
@@ -979,6 +1009,41 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
ret.update(ret2)
return ret

def wrap_torch_function(dispatcher: Callable):
"""Wraps a given function with ``__torch_function__`` -related functionality.
Parameters
----------
dispatcher: Callable
A callable that returns an iterable of Tensor-likes passed into the function.
Note
----
This decorator may reduce the performance of your code. Generally, it's enough to express
your code as a series of functions that, themselves, support __torch_function__. If you
find yourself in the rare situation where this is not the case, e.g. if you're wrapping a
low-level library and you also need it to work for Tensor-likes, then this function is available.
Examples
--------
>>> def dispatcher(a): # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
... def func(a): # This will make func dispatchable by __torch_function__
...     return a + 0
"""
    def inner(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            relevant_args = dispatcher(*args, **kwargs)
            if has_torch_function(relevant_args):
                return handle_torch_function(func, relevant_args, *args, **kwargs)

            return func(*args, **kwargs)

        return wrapped

    return inner
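
(A hedged end-to-end sketch of the decorator added above; ``Interceptor`` and ``double`` are hypothetical names, not part of the commit:)

>>> import torch
>>> class Interceptor:
...     @classmethod
...     def __torch_function__(cls, func, types, args=(), kwargs=None):
...         return -1  # intercept every dispatched call
>>> def dispatcher(a):
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
... def double(a):
...     return a * 2
>>> double(3)  # no Tensor-likes among the arguments: plain path runs
6
>>> double(Interceptor())  # dispatch goes through __torch_function__
-1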

def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
"""Returns a list of arguments on which to call __torch_function__.
@@ -1004,18 +1069,15 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
Returns
-------
overloaded_types : collection of types
Types of arguments from relevant_args with __torch_function__ methods.
overloaded_args : list
Arguments from relevant_args on which to call __torch_function__
methods, in the order in which they should be called.
.. _NEP-0018:
https://numpy.org/neps/nep-0018-array-function-protocol.html
"""
# Runtime is O(num_arguments * num_unique_types)
overloaded_types = []
overloaded_types = set()
overloaded_args = []
for arg in relevant_args:
arg_type = type(arg)
@@ -1026,7 +1088,7 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
# Create lists explicitly for the first type (usually the only one
# done) to avoid setting up the iterator for overloaded_args.
if overloaded_types:
overloaded_types.append(arg_type)
overloaded_types.add(arg_type)
# By default, insert argument at the end, but if it is
# subclass of another argument, insert it before that argument.
# This ensures "subclasses before superclasses".
@@ -1037,15 +1099,15 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
break
overloaded_args.insert(index, arg)
else:
overloaded_types = [arg_type]
overloaded_types = {arg_type}
overloaded_args = [arg]

return overloaded_args
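
(A hedged sketch of the "subclasses before superclasses" ordering implemented above; ``Base`` and ``Child`` are hypothetical Tensor-likes:)

>>> import torch
>>> class Base:
...     @classmethod
...     def __torch_function__(cls, func, types, args=(), kwargs=None):
...         return NotImplemented
>>> class Child(Base):
...     pass
>>> args = torch.overrides._get_overloaded_args([Base(), Child()])
>>> [type(a).__name__ for a in args]  # the subclass is moved to the front
['Child', 'Base']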


def handle_torch_function(
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
"""Implement a function with checks for __torch_function__ overrides.
"""Implement a function with checks for ``__torch_function__`` overrides.
See torch::autograd::handle_torch_function for the equivalent of this
function in the C++ implementation.
@@ -1065,13 +1127,20 @@ def handle_torch_function(
Returns
-------
Result from calling `implementation()` or an `__torch_function__`
method, as appropriate.
object
Result from calling ``implementation`` or an ``__torch_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
Example
-------
>>> def func(a):
...     if type(a) is not torch.Tensor: # This will make func dispatchable by __torch_function__
...         return handle_torch_function(func, (a,), a)
...     return a + 0
"""
# Check for __torch_function__ methods.
overloaded_args = _get_overloaded_args(relevant_args)
@@ -1093,7 +1162,9 @@
.format(func_name, [type(arg) for arg in overloaded_args]))
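
(To make the docstring's sketch concrete, a hedged example; ``Negator`` is a hypothetical Tensor-like:)

>>> import torch
>>> class Negator:
...     @classmethod
...     def __torch_function__(cls, func, types, args=(), kwargs=None):
...         return -1
>>> def func(a):
...     if type(a) is not torch.Tensor:  # make func dispatchable
...         return torch.overrides.handle_torch_function(func, (a,), a)
...     return a + 0
>>> func(Negator())
-1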

def has_torch_function(relevant_args: Iterable[Any]) -> bool:
"""Check for __torch_function__ implementations in the elements of an iterable
"""Check for __torch_function__ implementations in the elements of an iterable.
Considers exact ``Tensor`` s non-dispatchable.
Arguments
---------
@@ -1102,8 +1173,14 @@ def has_torch_function(relevant_args: Iterable[Any]) -> bool:
Returns
-------
True if any of the elements of relevant_args have __torch_function__
implementations, False otherwise.
bool
True if any of the elements of relevant_args have __torch_function__
implementations, False otherwise.
See Also
--------
torch.is_tensor_like
Checks if something is a Tensor-like, including an exact ``Tensor``.
"""
return _is_torch_function_enabled() and any(
type(a) is not torch.Tensor and
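
(A short sketch of the dispatchability check; ``MyTensorLike`` is hypothetical, and exact ``Tensor`` instances are reported as non-dispatchable:)

>>> import torch
>>> torch.overrides.has_torch_function((torch.tensor([1.0]),))
False
>>> class MyTensorLike:
...     def __torch_function__(self, func, types, args=(), kwargs=None):
...         return NotImplemented
>>> torch.overrides.has_torch_function((MyTensorLike(),))
True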
@@ -1118,8 +1195,9 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
Returns
-------
A dictionary that maps namespaces that contain overridable functions
to functions in that namespace that can be overridden.
Dict[Any, List[Callable]]
A dictionary that maps namespaces that contain overridable functions
to functions in that namespace that can be overridden.
"""
overridable_funcs = collections.defaultdict(list)
tested_namespaces = [
@@ -1175,7 +1253,7 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
return overridable_funcs

@functools.lru_cache(None)
def get_tensor_methods() -> Set[Callable]:
def _get_tensor_methods() -> Set[Callable]:
""" Returns a set of the overridable methods on ``torch.Tensor`` """
overridable_funcs = get_overridable_functions()
methods = set(overridable_funcs[torch.Tensor])
@@ -1195,14 +1273,47 @@ def is_tensor_method_or_property(func: Callable) -> bool:
1. Methods/properties sometimes don't contain a `__module__` slot.
2. They require that the first passed-in argument is an instance
of ``torch.Tensor``.
Examples
--------
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
"""
return func in get_tensor_methods() or func.__name__ == "__get__"
return func in _get_tensor_methods() or func.__name__ == "__get__"

def is_tensor_like(inp):
"""
Returns ``True`` if the passed-in input is a tensor-like.
Returns ``True`` if the passed-in input is a Tensor-like.
Currently, this occurs whenever there's a ``__torch_function__``
attribute on the input.
attribute on the type of the input.
Examples
--------
A subclass of tensor is generally a Tensor-like.
>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True
Built-in or user types aren't usually Tensor-like.
>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False
But, they can be made Tensor-like by implementing __torch_function__.
>>> class TensorLike:
...     def __torch_function__(self, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
"""
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
return type(inp) is torch.Tensor or hasattr(type(inp), "__torch_function__")
