-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
move has_torch_function to C++, and make a special case object_has_torch_function #48965
Changes from 19 commits
bfeb122
a4ccaae
5e6c487
7e6d692
05697e7
41c53f1
94c3194
73ef791
a39a934
cdb03d0
c54569a
2c8c86e
bd0ba4b
c062d2c
166ae67
5061aef
d2828bb
a171679
71dd054
ba09e88
8c7c8f4
ab684f6
3e9b809
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
#include <torch/csrc/utils/disable_torch_function.h> | ||
#include <torch/csrc/utils/pybind.h> | ||
#include <torch/csrc/Exceptions.h> | ||
#include <torch/csrc/utils/python_strings.h> | ||
|
||
namespace torch { | ||
static thread_local bool enable_torch_function = true; | ||
|
@@ -125,3 +126,111 @@ PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) { | |
return result; | ||
END_HANDLE_TH_ERRORS | ||
} | ||
|
||
// Makes sure that we don't check for __torch_function__ on basic Python types.
// Instances of these builtins can never carry a `__torch_function__` override,
// so callers may skip the (comparatively expensive) attribute lookup for them.
static bool is_basic_python_type(PyTypeObject *tp)
{
  const PyTypeObject* const builtin_types[] = {
      /* Basic number types */
      &PyBool_Type,
      &PyLong_Type,
      &PyFloat_Type,
      &PyComplex_Type,

      /* Basic sequence types */
      &PyList_Type,
      &PyTuple_Type,
      &PyDict_Type,
      &PySet_Type,
      &PyFrozenSet_Type,
      &PyUnicode_Type,
      &PyBytes_Type,

      /* other builtins */
      &PySlice_Type,
      Py_TYPE(Py_None),
      Py_TYPE(Py_Ellipsis),
      Py_TYPE(Py_NotImplemented),
  };

  for (const PyTypeObject* builtin : builtin_types) {
    if (tp == builtin) {
      return true;
    }
  }

  // Module objects are also exempt from the __torch_function__ check.
  return PyModule_Check(tp) != 0;
}
|
||
// True when `obj` exposes a usable `__torch_function__` attribute, i.e. the
// attribute exists and has not been replaced by the "disabled" sentinel.
inline bool has_torch_function_attr(PyObject* obj) {
  auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
  PyObject* fn = attr.ptr();
  if (fn == nullptr) {
    return false;
  }
  return fn != torch::disabled_torch_function;
}
|
||
namespace torch { | ||
auto check_has_torch_function(PyObject* obj) -> bool | ||
{ | ||
PyTypeObject *tp = Py_TYPE(obj); | ||
return ( | ||
!THPVariable_CheckTypeExact(tp) && | ||
!is_basic_python_type(tp) && | ||
torch::torch_function_enabled() && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. huh, I would have expected There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main reason is that I expect |
||
has_torch_function_attr(obj) | ||
); | ||
} | ||
} // namespace torch | ||
|
||
// Scans a tuple/list (or PySequence_Fast result) and reports whether any
// element overrides __torch_function__. Stops at the first match.
inline bool sequence_has_torch_function(PyObject* args) {
  const Py_ssize_t size = PySequence_Fast_GET_SIZE(args);
  for (Py_ssize_t idx = 0; idx < size; ++idx) {
    PyObject* item = PySequence_Fast_GET_ITEM(args, idx);
    if (torch::check_has_torch_function(item)) {
      return true;
    }
  }
  return false;
}
|
||
// Reports whether any element of a raw C array of `nargs` Python objects
// overrides __torch_function__. Stops checking once a match is found.
inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
  bool found = false;
  for (Py_ssize_t idx = 0; idx < nargs && !found; ++idx) {
    found = torch::check_has_torch_function(args[idx]);
  }
  return found;
}
|
||
// Python binding: `torch.overrides.has_torch_function(iterable)`.
// Returns True iff any element of `arg` overrides __torch_function__.
PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) {
  bool result; // NOLINT(cppcoreguidelines-init-variables)
  if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
    // Fast path:
    //   If we know that we have a tuple or list, we can skip an INCREF and
    //   DECREF from PySequence_Fast. Core functions will always follow this
    //   convention (almost always tuples), and it shaves ~3.5% off the cost of
    //   the check.
    result = sequence_has_torch_function(arg);
  } else {
    auto args = py::reinterpret_steal<py::object>(
      PySequence_Fast(arg, "expected a sequence"));
    if (args.ptr() == nullptr) {
      // PySequence_Fast failed (e.g. `arg` is not iterable). A Python
      // exception (TypeError) is already set; propagate it instead of
      // dereferencing the null result below.
      return nullptr;
    }
    result = sequence_has_torch_function(args.ptr());
  }

  if (result) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
|
||
// Single-argument specialization of `THPModule_has_torch_function`: checks
// `obj` directly, skipping all sequence handling.
PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) {
  if (!torch::check_has_torch_function(obj)) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
|
||
// Vectorcall-style entry point: `args` is a raw C array of `nargs` objects,
// so no sequence object ever needs to be materialized.
PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) {
  if (!array_has_torch_function(args, nargs)) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function looks unchanged but the functions below seem to have been changed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, this one was just moved from
torch/csrc/utils/python_arg_parser.h
The functions below were modified to: (a) bail out faster — e.g. don't do attribute checks on known basic types — and (b) handle checking of multiple Python values more efficiently, which generally means being as lazy as possible with Python containers. (For example, `PySequence_Fast` does an extra refcount bump and decref because it has no way of knowing that we'll keep `args` alive until it's done.)