Update on "move has_torch_function to C++, and make a special case object_has_torch_function"


This PR pulls `__torch_function__` checking entirely into C++, and adds a special `object_has_torch_function` method for ops which only have one arg, since this lets us skip tuple construction and unpacking. We can now also do away with the Python-side fast bailout for `Tensor` (e.g. `if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors)`) because it is actually slower than checking with the Python C API.
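For context, here is a minimal sketch of the call-site pattern this affects, using the public wrappers in `torch.overrides` (the function `my_cat` and its body are illustrative, not code from this PR):

import torch
from torch.overrides import handle_torch_function, has_torch_function

def my_cat(tensors, dim=0):
    # Previously, call sites guarded the (then Python-level) check with a bailout like:
    #   if any(type(t) is not torch.Tensor for t in tensors) and has_torch_function(tensors): ...
    # With the check implemented in C++, the guard costs more than it saves,
    # so the call site can ask the C++ check directly:
    if has_torch_function(tensors):
        return handle_torch_function(my_cat, tensors, tensors, dim=dim)
    return torch.cat(tensors, dim=dim)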

Test plan: Existing unit tests. Benchmarks are in #48966

[ghstack-poisoned]
Taylor Robie committed Dec 16, 2020
1 parent bd0ba4b commit c062d2c
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
@@ -497,7 +497,7 @@ def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines
 def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
 def _has_torch_function(Iterable[Any]) -> _bool: ... # THPModule_has_torch_function
 def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary
-def _has_torch_function_variadic(...: Any) -> _bool: ... # THPModule_has_torch_function_variadic
+def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic
 def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
 def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
 def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
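For illustration, the three entry points declared above can be exercised from Python as sketched below (a hedged sketch: the `_has_torch_function*` functions are internal `torch._C` bindings rather than a public API, and `MyTensorLike` is a made-up class, not part of this PR):

import torch

class MyTensorLike:
    # Any type that defines __torch_function__ participates in dispatch.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

t = torch.zeros(2)
m = MyTensorLike()

print(torch._C._has_torch_function([t, m]))         # expected True: m defines __torch_function__
print(torch._C._has_torch_function_unary(t))        # expected False: plain Tensor
print(torch._C._has_torch_function_variadic(t, m))  # expected True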
12 changes: 6 additions & 6 deletions torch/csrc/utils/disable_torch_function.cpp
@@ -179,7 +179,7 @@ auto check_has_torch_function(PyObject* obj) -> bool
 }
 } // namespace torch
 
-inline bool _sequence_has_torch_function(PyObject* args) {
+inline bool sequence_has_torch_function(PyObject* args) {
   Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
   for (Py_ssize_t i = 0; i < nargs; i++) {
     PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
@@ -189,7 +189,7 @@ inline bool _sequence_has_torch_function(PyObject* args) {
   return false;
 }
 
-inline bool _array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
+inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
   for (Py_ssize_t i = 0; i < nargs; i++) {
     if (torch::check_has_torch_function(args[i]))
       return true;
@@ -198,18 +198,18 @@ inline bool _array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
 }
 
 PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) {
-  bool result;
+  bool result; // NOLINT(cppcoreguidelines-init-variables)
   if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
     // Fast path:
     // If we know that we have a tuple or list, we can skip an INCREF and
     // DECREF from PySequence_Fast. Core functions will always follow this
     // convention (almost always tuples), and it shaves ~3.5% off the cost of
     // the check.
-    result = _sequence_has_torch_function(arg);
+    result = sequence_has_torch_function(arg);
   } else {
     auto args = py::reinterpret_steal<py::object>(
       PySequence_Fast(arg, "expected a sequence"));
-    result = _sequence_has_torch_function(args.ptr());
+    result = sequence_has_torch_function(args.ptr());
   }
 
   if (result)
@@ -226,7 +226,7 @@ PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) {
 }
 
 PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) {
-  if (_array_has_torch_function(args, nargs))
+  if (array_has_torch_function(args, nargs))
     Py_RETURN_TRUE;
 
   Py_RETURN_FALSE;
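The tuple-construction overhead mentioned in the commit message can be spot-checked with a small timing harness like the one below (illustrative only; these are not the benchmarks from #48966, and the private `torch._C` names are the ones declared in the stub above):

import timeit
import torch

t = torch.zeros(2)

# Compare the sequence-based check (which requires building a tuple per call)
# with the unary entry point that takes the object directly.
tuple_check = timeit.timeit(lambda: torch._C._has_torch_function((t,)), number=1_000_000)
unary_check = timeit.timeit(lambda: torch._C._has_torch_function_unary(t), number=1_000_000)

print(f"_has_torch_function((t,)):    {tuple_check:.3f} s / 1e6 calls")
print(f"_has_torch_function_unary(t): {unary_check:.3f} s / 1e6 calls")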
