Skip to content

Commit

Permalink
Make functions documented in notes/extending public.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Nov 24, 2020
1 parent 933b837 commit 32eed06
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 21 deletions.
12 changes: 6 additions & 6 deletions docs/source/notes/extending.rst
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,8 @@ functions in that namespace that can be overriden. For example, let's print the
names of the first 5 functions in ``torch.nn.functional`` that can be
overriden::

>>> from torch.overrides import _get_overridable_functions
>>> func_dict = _get_overridable_functions()
>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
Expand All @@ -631,16 +631,16 @@ functions are most useful to use with ``inspect`` to analyze the function
signature of the original ``PyTorch`` function::

>>> import inspect
>>> from torch.overrides import _get_testing_overrides
>>> override_dict = _get_testing_overrides()
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>

Finally, ``torch.overrides._get_ignored_functions`` returns a tuple of functions
Finally, ``torch.overrides.get_ignored_functions`` returns a tuple of functions
that explicitly cannot be overrided by ``__torch_function__``. This list can be
useful to confirm that a function that isn't present in the dictionary returned
by ``_get_overridable_functions`` cannot be overriden.
by ``get_overridable_functions`` cannot be overriden.


Writing custom C++ extensions
Expand Down
6 changes: 6 additions & 0 deletions docs/source/torch.overrides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ protocol. See :ref:`extending-torch` for more detail on the
Functions
~~~~~~~~~

.. autofunction:: get_ignored_functions

.. autofunction:: get_overridable_functions

.. autofunction:: get_testing_overrides

.. autofunction:: handle_torch_function

.. autofunction:: has_torch_function
Expand Down
10 changes: 5 additions & 5 deletions test/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torch.overrides import (
handle_torch_function,
has_torch_function,
_get_overridable_functions,
_get_testing_overrides,
get_overridable_functions,
get_testing_overrides,
is_tensor_method_or_property
)

Expand Down Expand Up @@ -314,8 +314,8 @@ def decorator(func):
def generate_tensor_like_torch_implementations():
torch_vars = vars(torch)
untested_funcs = []
testing_overrides = _get_testing_overrides()
for namespace, funcs in _get_overridable_functions().items():
testing_overrides = get_testing_overrides()
for namespace, funcs in get_overridable_functions().items():
for func in funcs:
if func not in testing_overrides:
untested_funcs.append("{}.{}".format(namespace, func.__name__))
Expand Down Expand Up @@ -620,7 +620,7 @@ def test(self):

return test

for func, override in _get_testing_overrides().items():
for func, override in get_testing_overrides().items():
test_method = test_generator(func, override)
if func.__name__ == "__get__":
# Note: properties and __get__
Expand Down
23 changes: 13 additions & 10 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from torch._C import _is_torch_function_enabled, _disabled_torch_function_impl

__all__ = [
"get_ignored_functions",
"get_overridable_functions",
"get_testing_overrides",
"handle_torch_function",
"has_torch_function",
"is_tensor_like",
Expand All @@ -39,7 +42,7 @@
]

@functools.lru_cache(None)
def _get_ignored_functions() -> Set[Callable]:
def get_ignored_functions() -> Set[Callable]:
"""
Return public functions that cannot be overridden by ``__torch_function__``.
Expand Down Expand Up @@ -202,7 +205,7 @@ def _get_ignored_functions() -> Set[Callable]:


@functools.lru_cache(None)
def _get_testing_overrides() -> Dict[Callable, Callable]:
def get_testing_overrides() -> Dict[Callable, Callable]:
"""Return a dict containing dummy overrides for all overridable functions
Returns
Expand All @@ -216,9 +219,9 @@ def _get_testing_overrides() -> Dict[Callable, Callable]:
Examples
--------
>>> import inspect
>>> my_lambda = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.getfullargspec(my_lambda)
FullArgSpec(...)
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
"""
# Every function in the PyTorchAPI that can be overriden needs an entry
# in this dict.
Expand Down Expand Up @@ -973,7 +976,7 @@ def _get_testing_overrides() -> Dict[Callable, Callable]:
}

ret2 = {}
ignored = _get_ignored_functions()
ignored = get_ignored_functions()

for k, v in ret.items():
# Generate methods like __add__ and add_ by default from add
Expand Down Expand Up @@ -1184,7 +1187,7 @@ def has_torch_function(relevant_args: Iterable[Any]) -> bool:
)

@functools.lru_cache(None)
def _get_overridable_functions() -> Dict[Any, List[Callable]]:
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
Returns
Expand Down Expand Up @@ -1238,18 +1241,18 @@ def _get_overridable_functions() -> Dict[Any, List[Callable]]:
continue

# cannot be overriden by __torch_function__
if func in _get_ignored_functions():
if func in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override")
assert func not in _get_testing_overrides(), msg.format(namespace, func.__name__)
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
continue
overridable_funcs[namespace].append(func)
return overridable_funcs

@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
""" Returns a set of the overridable methods on ``torch.Tensor`` """
overridable_funcs = _get_overridable_functions()
overridable_funcs = get_overridable_functions()
methods = set(overridable_funcs[torch.Tensor])
return methods

Expand Down

0 comments on commit 32eed06

Please sign in to comment.