move has_torch_function to C++, and make a special case object_has_torch_function #48965

Closed · wants to merge 23 commits
Changes from 19 commits
Commits (23):
bfeb122 · move has_torch_function to C++, and make a special case object_has_to… · Dec 7, 2020
a4ccaae · Update on "move has_torch_function to C++, and make a special case ob… · Dec 8, 2020
5e6c487 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 8, 2020
7e6d692 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 8, 2020
05697e7 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 8, 2020
41c53f1 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 8, 2020
94c3194 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 11, 2020
73ef791 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 11, 2020
a39a934 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 12, 2020
cdb03d0 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 13, 2020
c54569a · Update on "move has_torch_function to C++, and make a special case ob… · Dec 13, 2020
2c8c86e · Update on "move has_torch_function to C++, and make a special case ob… · Dec 16, 2020
bd0ba4b · Update on "move has_torch_function to C++, and make a special case ob… · Dec 16, 2020
c062d2c · Update on "move has_torch_function to C++, and make a special case ob… · Dec 16, 2020
166ae67 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 16, 2020
5061aef · Update on "move has_torch_function to C++, and make a special case ob… · Dec 16, 2020
d2828bb · Update on "move has_torch_function to C++, and make a special case ob… · Dec 17, 2020
a171679 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 18, 2020
71dd054 · Update on "move has_torch_function to C++, and make a special case ob… · Dec 29, 2020
ba09e88 · Update on "move has_torch_function to C++, and make a special case ob… · Jan 6, 2021
8c7c8f4 · Update on "move has_torch_function to C++, and make a special case ob… · Jan 6, 2021
ab684f6 · Update on "move has_torch_function to C++, and make a special case ob… · Jan 10, 2021
3e9b809 · Update on "move has_torch_function to C++, and make a special case ob… · Jan 10, 2021
12 changes: 9 additions & 3 deletions torch/_C/__init__.pyi.in
@@ -4,9 +4,10 @@ import torch
from torch import Tensor
from enum import Enum
from pathlib import Path
-from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple,
-Optional, overload, Sequence, Tuple, TypeVar, Type, Union, Generic,
-Set, AnyStr)
+from typing import (
+Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
+NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
+Generic, Set, AnyStr)
from torch._six import inf

from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
@@ -494,6 +495,9 @@ def _get_qengine() -> _int: ... # THPModule_qEngine
def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
def _has_torch_function(args: Iterable[Any]) -> _bool: ... # THPModule_has_torch_function
def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary
def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic
def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
@@ -632,6 +636,8 @@ class _TensorBase(object):
_version: _int
_base: Optional[Tensor]
grad_fn: Any
_grad: Optional[Tensor]
_backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
${tensor_method_hints}

# Defined in torch/csrc/multiprocessing/init.cpp
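The stub additions above declare the Python-visible surface of the new checks. A minimal usage sketch, assuming the bindings behave as declared (MyTensor is a hypothetical class used only for illustration):

```python
import torch
from torch._C import (
    _has_torch_function,
    _has_torch_function_unary,
    _has_torch_function_variadic,
)

class MyTensor:
    # Any object exposing a __torch_function__ other than the disabled
    # sentinel should be detected by the C++ checks.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

t = torch.ones(2)
m = MyTensor()

print(_has_torch_function((t,)))           # False: exact Tensors bail out early
print(_has_torch_function((t, m)))         # True: m defines __torch_function__
print(_has_torch_function_unary(m))        # True: single-object special case
print(_has_torch_function_variadic(t, m))  # True: args passed without a tuple
```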
4 changes: 4 additions & 0 deletions torch/csrc/Module.cpp
@@ -40,6 +40,7 @@
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
@@ -629,6 +630,9 @@ static PyMethodDef TorchMethods[] = {
{"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
{"_is_torch_function_enabled", THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr},
{"_disabled_torch_function_impl", THPModule_disable_torch_function, METH_VARARGS, nullptr},
{"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
{"_has_torch_function_unary", THPModule_has_torch_function_unary, METH_O, nullptr},
{"_has_torch_function_variadic", MAYBE_WRAP_FASTCALL(THPModule_has_torch_function_variadic), MAYBE_METH_FASTCALL, nullptr},
{nullptr, nullptr, 0, nullptr}
};

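Note the calling conventions: _has_torch_function and _has_torch_function_unary are registered as METH_O (exactly one argument), while _has_torch_function_variadic uses METH_FASTCALL (or the Python 3.6 shim from python_compat.h, shown later in this diff) so callers avoid allocating an argument tuple. These bindings back the guard pattern used throughout torch's Python functional wrappers; a sketch of that pattern, where my_op is a hypothetical operator and the torch.overrides helpers are assumed to forward to the bindings registered above:

```python
import torch
from torch.overrides import has_torch_function, handle_torch_function

def my_op(x, y):
    # Cheap C++ check first: only when some argument overrides
    # __torch_function__ do we enter the Python dispatch machinery.
    if has_torch_function((x, y)):
        return handle_torch_function(my_op, (x, y), x, y)
    return x + y
```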
11 changes: 7 additions & 4 deletions torch/csrc/autograd/python_variable.h
@@ -24,14 +24,17 @@ THP_API PyObject *ParameterClass;
bool THPVariable_initModule(PyObject *module);
THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var);

-static inline bool THPVariable_CheckExact(PyObject *obj) {
-auto obj_py_type = Py_TYPE(obj);
+static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
return (
-obj_py_type == (PyTypeObject*)THPVariableClass ||
-obj_py_type == (PyTypeObject*)ParameterClass
+tp == (PyTypeObject*)THPVariableClass ||
+tp == (PyTypeObject*)ParameterClass
);
}

+inline bool THPVariable_CheckExact(PyObject *obj) {
+return THPVariable_CheckTypeExact(Py_TYPE(obj));
+}

inline bool THPVariable_Check(PyObject *obj)
{
return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass);
109 changes: 109 additions & 0 deletions torch/csrc/utils/disable_torch_function.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_strings.h>

namespace torch {
static thread_local bool enable_torch_function = true;
@@ -125,3 +126,111 @@ PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) {
return result;
END_HANDLE_TH_ERRORS
}

// Makes sure that we don't check for __torch_function__ on basic Python types
static bool is_basic_python_type(PyTypeObject *tp)

Contributor: This function looks unchanged, but the functions below seem to have been changed.

Author: Indeed, this one was just moved from torch/csrc/utils/python_arg_parser.h. The functions below were modified to:
a) Bail out faster (e.g. don't do attr checks on known types).
b) More efficiently handle checking of multiple Python values, which generally means trying to be as lazy as possible with Python containers. (e.g. PySequence_Fast does an extra refcount bump and decref because it has no way of knowing that we'll keep args alive until it's done.)

{
return (
/* Basic number types */
tp == &PyBool_Type ||

tp == &PyLong_Type ||
tp == &PyFloat_Type ||
tp == &PyComplex_Type ||

/* Basic sequence types */
tp == &PyList_Type ||
tp == &PyTuple_Type ||
tp == &PyDict_Type ||
tp == &PySet_Type ||
tp == &PyFrozenSet_Type ||
tp == &PyUnicode_Type ||
tp == &PyBytes_Type ||

/* other builtins */
tp == &PySlice_Type ||
tp == Py_TYPE(Py_None) ||
tp == Py_TYPE(Py_Ellipsis) ||
tp == Py_TYPE(Py_NotImplemented) ||

PyModule_Check(tp) ||
/* sentinel to swallow trailing || */
false
);
}

inline bool has_torch_function_attr(PyObject* obj) {
auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
return (
attr.ptr() != nullptr &&
attr.ptr() != torch::disabled_torch_function);
}

namespace torch {
auto check_has_torch_function(PyObject* obj) -> bool
{
PyTypeObject *tp = Py_TYPE(obj);
return (
!THPVariable_CheckTypeExact(tp) &&
!is_basic_python_type(tp) &&
torch::torch_function_enabled() &&

Contributor: Huh, I would have expected torch::torch_function_enabled to be the first thing to test. Is it more expensive than I thought?

Author: The main reason is that I expect !THPVariable_CheckTypeExact(tp) && !is_basic_python_type(tp) to be false in most cases (particularly in normal eager use), while torch::torch_function_enabled() is normally true, so it was a question of likely_false && likely_true instead of the other way around. I think I checked and it saved an instruction or two, although between branch predictors and instruction parallelism I don't know if it actually matters on a real chip. And because sequence_has_torch_function calls check_has_torch_function, you could wind up doing the check a whole two to three times! (Alas, we're way under what I could hope to A/B with wall time here.) It probably doesn't matter, but this part of the path is so hot that I'm paranoid.

has_torch_function_attr(obj)
);
}
} // namespace torch

inline bool sequence_has_torch_function(PyObject* args) {
Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
for (Py_ssize_t i = 0; i < nargs; i++) {
PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
if (torch::check_has_torch_function(obj)) {
return true;
}
}
return false;
}

inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
for (Py_ssize_t i = 0; i < nargs; i++) {
if (torch::check_has_torch_function(args[i])) {
return true;
}
}
return false;
}

PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) {
bool result; // NOLINT(cppcoreguidelines-init-variables)
if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
// Fast path:
// If we know that we have a tuple or list, we can skip an INCREF and
// DECREF from PySequence_Fast. Core functions will always follow this
// convention (almost always tuples), and it shaves ~3.5% off the cost of
// the check.
result = sequence_has_torch_function(arg);
} else {
auto args = py::reinterpret_steal<py::object>(
PySequence_Fast(arg, "expected a sequence"));
result = sequence_has_torch_function(args.ptr());
}

if (result) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}

PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) {
// Special case `THPModule_has_torch_function` for the single arg case.
if (torch::check_has_torch_function(obj)) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}

PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) {
if (array_has_torch_function(args, nargs)) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
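
Two properties of check_has_torch_function above are easy to miss: basic Python types and exact Tensor instances never reach the attribute lookup, and an object whose __torch_function__ is the disabled sentinel is reported as having no override. A sketch of those semantics, assuming torch._C._disabled_torch_function_impl is the sentinel that has_torch_function_attr compares against (both classes are hypothetical):

```python
import torch
from torch._C import _has_torch_function_unary

class Plain:
    pass  # no __torch_function__ attribute at all

class OptedOut(torch.Tensor):
    # Assumed opt-out pattern: assigning the disabled sentinel makes the
    # subclass look override-free to the checks above.
    __torch_function__ = torch._C._disabled_torch_function_impl

print(_has_torch_function_unary(3.14))        # False: basic type, early bail-out
print(_has_torch_function_unary(Plain()))     # False: attribute lookup fails
print(_has_torch_function_unary(OptedOut()))  # False: sentinel is ignored
```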
6 changes: 5 additions & 1 deletion torch/csrc/utils/disable_torch_function.h
@@ -9,8 +9,12 @@ namespace torch {
bool torch_function_enabled();
PyObject* disabled_torch_function_impl();
void set_disabled_torch_function_impl(PyObject* value);
bool check_has_torch_function(PyObject* obj);
}

PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused);
PyObject* THPModule_DisableTorchFunctionType();
PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args);
PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg);
PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj);
PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs);
2 changes: 1 addition & 1 deletion torch/csrc/utils/python_arg_parser.cpp
@@ -390,7 +390,7 @@ bool is_float_or_complex_list(PyObject* obj) {
}

auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
if (size > 0) {
PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
return false;
121 changes: 0 additions & 121 deletions torch/csrc/utils/python_arg_parser.h
@@ -656,127 +656,6 @@ inline PyObject* PythonArgs::pyobject(int i) {
return args[i];
}

/*
* Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
*
* Stripped down version of PyObject_GetAttrString,
* avoids lookups for None, tuple, and List objects,
* and doesn't create a PyErr since this code ignores it.
*
* This can be much faster then PyObject_GetAttrString where
* exceptions are not used by caller.
*
* 'obj' is the object to search for attribute.
*
* 'name' is the attribute to search for.
*
* Returns a py::object wrapping the return value. If the attribute lookup failed
* the value will be NULL.
*
*/

static py::object PyObject_FastGetAttrString(PyObject *obj, char *name)
{
PyTypeObject *tp = Py_TYPE(obj);
PyObject *res = (PyObject *)NULL;

/* Attribute referenced by (char *)name */
if (tp->tp_getattr != NULL) {
res = (*tp->tp_getattr)(obj, name);
if (res == NULL) {
PyErr_Clear();
}
}
/* Attribute referenced by (PyObject *)name */
else if (tp->tp_getattro != NULL) {
PyObject *w = THPUtils_internString(name);
if (w == NULL) {
return py::object();
}
res = (*tp->tp_getattro)(obj, w);
Py_DECREF(w);
if (res == NULL) {
PyErr_Clear();
}
}
return py::reinterpret_steal<py::object>(res);
}

// Makes sure that we don't check for __torch_function__ on basic Python types
static bool _is_basic_python_type(PyTypeObject *tp)
{
return (
/* Basic number types */
tp == &PyBool_Type ||

tp == &PyLong_Type ||
tp == &PyFloat_Type ||
tp == &PyComplex_Type ||

/* Basic sequence types */
tp == &PyList_Type ||
tp == &PyTuple_Type ||
tp == &PyDict_Type ||
tp == &PySet_Type ||
tp == &PyFrozenSet_Type ||
tp == &PyUnicode_Type ||
tp == &PyBytes_Type ||

/* other builtins */
tp == &PySlice_Type ||
tp == Py_TYPE(Py_None) ||
tp == Py_TYPE(Py_Ellipsis) ||
tp == Py_TYPE(Py_NotImplemented) ||

PyModule_Check(tp) ||
/* sentinel to swallow trailing || */
false
);
}

/*
* Lookup a special method, following the python approach of looking up
* on the type object, rather than on the instance itself.
*
* Assumes that the special method is a torch-specific one, so does not
* look at builtin types, nor does it look at a base Tensor.
*
* If no special method is found, return NULL, otherwise returns a new
* reference to the function object
*
* In future, could be made more like _Py_LookupSpecial
*/

static py::object PyTorch_LookupSpecial(PyObject *obj, char* name)
{
if (THPVariable_CheckExact(obj)) {
return py::object();
}
PyTypeObject *tp = Py_TYPE(obj);
if (_is_basic_python_type(tp)) {
return py::object();
}
return PyObject_FastGetAttrString((PyObject *)tp, name);
}

/*
* Checks if obj has a __torch_function__ implementation
*
* Returns true if an implementation is found and false otherwise
*
*/
static auto check_has_torch_function(PyObject* obj) -> bool
{
if (!torch_function_enabled()) {
return false;
}
py::object method = PyTorch_LookupSpecial(obj, "__torch_function__");
if(method.ptr() != nullptr && method.ptr() != disabled_torch_function_impl()){
return true;
}
return false;
}

/*
*
* Handle __torch_function__ overrides if we know that there are overloaded
27 changes: 27 additions & 0 deletions torch/csrc/utils/python_compat.h
@@ -2,6 +2,33 @@

#include <torch/csrc/python_headers.h>

#if PY_VERSION_HEX < 0x03070000
// METH_FASTCALL was introduced in Python 3.7, so we wrap _PyCFunctionFast
// signatures for earlier versions.

template <PyObject* (*f)(PyObject*, PyObject *const *, Py_ssize_t)>
PyObject* maybe_wrap_fastcall(PyObject *module, PyObject *args) {
return f(
module,

// _PyTuple_ITEMS
// Because this is only a compat shim for Python 3.6, we don't have
// to worry about the representation changing.
((PyTupleObject *)args)->ob_item,
PySequence_Fast_GET_SIZE(args)
);
}

#define MAYBE_METH_FASTCALL METH_VARARGS
#define MAYBE_WRAP_FASTCALL(f) maybe_wrap_fastcall<f>

#else

#define MAYBE_METH_FASTCALL METH_FASTCALL
#define MAYBE_WRAP_FASTCALL(f) (PyCFunction)(void(*)(void))f

#endif

// PyPy 3.6 does not yet have PySlice_Unpack
#if PY_VERSION_HEX < 0x03060100 || defined(PYPY_VERSION)
