From c062d2ce69d3cd74e29007a67f888859568d088f Mon Sep 17 00:00:00 2001
From: Taylor Robie
Date: Wed, 16 Dec 2020 10:16:57 -0800
Subject: [PATCH] Update on "move has_torch_function to C++, and make a special
 case object_has_torch_function"

This PR pulls `__torch_function__` checking entirely into C++, and adds a
special `object_has_torch_function` method for ops which only have one arg,
since this lets us skip tuple construction and unpacking. We can now also do
away with the Python-side fast bailout for `Tensor` (e.g.
`if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors)`)
because it is actually slower than checking with the Python C API. (A usage
sketch follows the patch below.)

Test plan: existing unit tests. Benchmarks are in #48966

[ghstack-poisoned]
---
 torch/_C/__init__.pyi.in                    |  2 +-
 torch/csrc/utils/disable_torch_function.cpp | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 89d8799cb4bc..5192bc242c9b 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -497,7 +497,7 @@ def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
 def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
 def _has_torch_function(Iterable[Any]) -> _bool: ...  # THPModule_has_torch_function
 def _has_torch_function_unary(Any) -> _bool: ...  # THPModule_has_torch_function_unary
-def _has_torch_function_variadic(...: Any) -> _bool: ...  # THPModule_has_torch_function_variadic
+def _has_torch_function_variadic(*args: Any) -> _bool: ...  # THPModule_has_torch_function_variadic
 def _vmapmode_increment_nesting() -> _int: ...  # THPModule_vmapmode_increment_nesting
 def _vmapmode_decrement_nesting() -> _int: ...  # THPModule_vmapmode_decrement_nesting
 def _log_api_usage_once(str) -> None: ...  # LogAPIUsageOnceFromPython
diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp
index 13232d76594a..6dc8526e56c5 100644
--- a/torch/csrc/utils/disable_torch_function.cpp
+++ b/torch/csrc/utils/disable_torch_function.cpp
@@ -179,7 +179,7 @@ auto check_has_torch_function(PyObject* obj) -> bool
 }
 } // namespace torch
 
-inline bool _sequence_has_torch_function(PyObject* args) {
+inline bool sequence_has_torch_function(PyObject* args) {
   Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
   for (Py_ssize_t i = 0; i < nargs; i++) {
     PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
@@ -189,7 +189,7 @@ inline bool _sequence_has_torch_function(PyObject* args) {
   return false;
 }
 
-inline bool _array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
+inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
   for (Py_ssize_t i = 0; i < nargs; i++) {
     if (torch::check_has_torch_function(args[i]))
       return true;
@@ -198,18 +198,18 @@ inline bool _array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
 }
 
 PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) {
-  bool result;
+  bool result; // NOLINT(cppcoreguidelines-init-variables)
   if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
     // Fast path:
     //   If we know that we have a tuple or list, we can skip an INCREF and
     //   DECREF from PySequence_Fast. Core functions will always follow this
     //   convention (almost always tuples), and it shaves ~3.5% off the cost of
     //   the check.
-    result = _sequence_has_torch_function(arg);
+    result = sequence_has_torch_function(arg);
   } else {
     auto args = py::reinterpret_steal<py::object>(
       PySequence_Fast(arg, "expected a sequence"));
-    result = _sequence_has_torch_function(args.ptr());
+    result = sequence_has_torch_function(args.ptr());
   }
 
   if (result)
@@ -226,7 +226,7 @@ PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) {
 }
 
 PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) {
-  if (_array_has_torch_function(args, nargs))
+  if (array_has_torch_function(args, nargs))
     Py_RETURN_TRUE;
   Py_RETURN_FALSE;
 }
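
---

Usage sketch (not part of the patch): a minimal illustration of the calling pattern the commit message describes, assuming the single-argument check and dispatch helper are exposed through `torch.overrides` as in this stack. `my_relu` is a hypothetical operator added only for illustration.

```python
# Hedged sketch: how a Python-level op might use the C++-backed unary check
# instead of the old Python-side bailout. `my_relu` is hypothetical.
import torch
from torch.overrides import has_torch_function_unary, handle_torch_function

def my_relu(t):
    # Old pattern (removed by this stack): build a tuple and bail out in Python,
    #   if type(t) is not torch.Tensor and has_torch_function((t,)): ...
    # New pattern: a single C++ call, no tuple construction or unpacking.
    if has_torch_function_unary(t):
        return handle_torch_function(my_relu, (t,), t)
    return t.clamp_min(0)
```

The unary form exists precisely so that single-argument ops never allocate an argument tuple just to run the check, which is what "skip tuple construction and unpacking" refers to above.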