move has_torch_function to C++, and make a special case object_has_torch_function (#48965)

Summary:
Pull Request resolved: #48965

This PR pulls `__torch_function__` checking entirely into C++, and adds a special `object_has_torch_function` method for ops that only have one arg, since this lets us skip tuple construction and unpacking. We can also now do away with the Python-side fast bailouts for `Tensor` (e.g. `if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors)`), because they're actually slower than checking with the Python C API.
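
As a rough, hypothetical sketch (not part of this diff) of the pattern the change targets: an op wrapper can make a single call into the new C++ binding instead of a Python-level type scan plus tuple packing. `relu_like` is an invented example; `_has_torch_function_unary` is the binding added by this PR, and `handle_torch_function` is the existing handler in `torch.overrides`.

```python
# Hypothetical wrapper illustrating the unary fast path added in this PR.
import torch
from torch._C import _has_torch_function_unary
from torch.overrides import handle_torch_function

def relu_like(t):
    # Old style: `if type(t) is not Tensor and has_torch_function((t,))` built
    # a tuple and scanned it in Python. The new check is one C API call.
    if _has_torch_function_unary(t):
        return handle_torch_function(relu_like, (t,), t)
    return t.clamp(min=0)
```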

Test Plan: Existing unit tests. Benchmarks are in #48966

Reviewed By: ezyang

Differential Revision: D25590732

Pulled By: robieta

fbshipit-source-id: 6bd74788f06cdd673f3a2db898143d18c577eb42
Taylor Robie authored and facebook-github-bot committed Jan 11, 2021
1 parent 632a440 commit d31a760
Showing 13 changed files with 468 additions and 421 deletions.
12 changes: 9 additions & 3 deletions torch/_C/__init__.pyi.in
@@ -4,9 +4,10 @@ import torch
from torch import Tensor
from enum import Enum
from pathlib import Path
- from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple,
-                     Optional, overload, Sequence, Tuple, TypeVar, Type, Union, Generic,
-                     Set, AnyStr)
+ from typing import (
+     Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
+     NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
+     Generic, Set, AnyStr)
from torch._six import inf

from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
@@ -498,6 +499,9 @@ def _get_qengine() -> _int: ... # THPModule_qEngine
def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
def _has_torch_function(args: Iterable[Any]) -> _bool: ... # THPModule_has_torch_function
def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary
def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic
def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
@@ -636,6 +640,8 @@ class _TensorBase(object):
_version: _int
_base: Optional[Tensor]
grad_fn: Any
_grad: Optional[Tensor]
_backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
${tensor_method_hints}

# Defined in torch/csrc/multiprocessing/init.cpp
4 changes: 4 additions & 0 deletions torch/csrc/Module.cpp
@@ -40,6 +40,7 @@
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
@@ -629,6 +630,9 @@ static PyMethodDef TorchMethods[] = {
{"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
{"_is_torch_function_enabled", THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr},
{"_disabled_torch_function_impl", THPModule_disable_torch_function, METH_VARARGS, nullptr},
{"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
{"_has_torch_function_unary", THPModule_has_torch_function_unary, METH_O, nullptr},
{"_has_torch_function_variadic", MAYBE_WRAP_FASTCALL(THPModule_has_torch_function_variadic), MAYBE_METH_FASTCALL, nullptr},
{nullptr, nullptr, 0, nullptr}
};

11 changes: 7 additions & 4 deletions torch/csrc/autograd/python_variable.h
@@ -24,18 +24,21 @@ THP_API PyObject *ParameterClass;
bool THPVariable_initModule(PyObject *module);
THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var);

- static inline bool THPVariable_CheckExact(PyObject *obj) {
+ static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
  // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
  // (A subclass could have different semantics.) The one exception is
  // Parameter, which is used for Python bookkeeping but is equivalent to
  // Tensor as far as C++ is concerned.
-   auto obj_py_type = Py_TYPE(obj);
  return (
-     obj_py_type == (PyTypeObject*)THPVariableClass ||
-     obj_py_type == (PyTypeObject*)ParameterClass
+     tp == (PyTypeObject*)THPVariableClass ||
+     tp == (PyTypeObject*)ParameterClass
  );
}

+ static inline bool THPVariable_CheckExact(PyObject *obj) {
+   return THPVariable_CheckTypeExact(Py_TYPE(obj));
+ }

inline bool THPVariable_Check(PyObject *obj)
{
return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass);
109 changes: 109 additions & 0 deletions torch/csrc/utils/disable_torch_function.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_strings.h>

namespace torch {
static thread_local bool enable_torch_function = true;
@@ -125,3 +126,111 @@ PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) {
return result;
END_HANDLE_TH_ERRORS
}

// Makes sure that we don't check for __torch_function__ on basic Python types
static bool is_basic_python_type(PyTypeObject *tp)
{
return (
/* Basic number types */
tp == &PyBool_Type ||

tp == &PyLong_Type ||
tp == &PyFloat_Type ||
tp == &PyComplex_Type ||

/* Basic sequence types */
tp == &PyList_Type ||
tp == &PyTuple_Type ||
tp == &PyDict_Type ||
tp == &PySet_Type ||
tp == &PyFrozenSet_Type ||
tp == &PyUnicode_Type ||
tp == &PyBytes_Type ||

/* other builtins */
tp == &PySlice_Type ||
tp == Py_TYPE(Py_None) ||
tp == Py_TYPE(Py_Ellipsis) ||
tp == Py_TYPE(Py_NotImplemented) ||

PyModule_Check(tp) ||
/* sentinel to swallow trailing || */
false
);
}

inline bool has_torch_function_attr(PyObject* obj) {
auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
return (
attr.ptr() != nullptr &&
attr.ptr() != torch::disabled_torch_function);
}

namespace torch {
auto check_has_torch_function(PyObject* obj) -> bool
{
PyTypeObject *tp = Py_TYPE(obj);
return (
!THPVariable_CheckTypeExact(tp) &&
!is_basic_python_type(tp) &&
torch::torch_function_enabled() &&
has_torch_function_attr(obj)
);
}
} // namespace torch

inline bool sequence_has_torch_function(PyObject* args) {
Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
for (Py_ssize_t i = 0; i < nargs; i++) {
PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
if (torch::check_has_torch_function(obj)) {
return true;
}
}
return false;
}

inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
for (Py_ssize_t i = 0; i < nargs; i++) {
if (torch::check_has_torch_function(args[i])) {
return true;
}
}
return false;
}

PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) {
bool result; // NOLINT(cppcoreguidelines-init-variables)
if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
// Fast path:
// If we know that we have a tuple or list, we can skip an INCREF and
// DECREF from PySequence_Fast. Core functions will always follow this
// convention (almost always tuples), and it shaves ~3.5% off the cost of
// the check.
result = sequence_has_torch_function(arg);
} else {
auto args = py::reinterpret_steal<py::object>(
PySequence_Fast(arg, "expected a sequence"));
result = sequence_has_torch_function(args.ptr());
}

if (result) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}

PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) {
// Special case `THPModule_has_torch_function` for the single arg case.
if (torch::check_has_torch_function(obj)) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}

PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) {
if (array_has_torch_function(args, nargs)) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
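
For context (not part of the commit), the three entry points above can be exercised from Python roughly as follows; per the comment in `THPModule_has_torch_function`, passing a tuple or list takes the fast path that skips `PySequence_Fast`. `Overriding` is an invented stand-in for any type that implements the `__torch_function__` protocol.

```python
# Hypothetical usage of the new bindings.
import torch
from torch._C import (
    _has_torch_function,
    _has_torch_function_unary,
    _has_torch_function_variadic,
)

class Overriding:
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return NotImplemented

t, o = torch.ones(2), Overriding()

print(_has_torch_function((t, o)))          # True: tuple/list hit the fast path
print(_has_torch_function_unary(t))         # False: exact Tensor is excluded
print(_has_torch_function_variadic(t, o))   # True: scans the C-level arg array
```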
6 changes: 5 additions & 1 deletion torch/csrc/utils/disable_torch_function.h
@@ -9,8 +9,12 @@ namespace torch {
bool torch_function_enabled();
PyObject* disabled_torch_function_impl();
void set_disabled_torch_function_impl(PyObject* value);
bool check_has_torch_function(PyObject* obj);
}

PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused);
PyObject* THPModule_DisableTorchFunctionType();
PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args);
PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg);
PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj);
PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs);
2 changes: 1 addition & 1 deletion torch/csrc/utils/python_arg_parser.cpp
@@ -862,7 +862,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs,
}

int i = 0;
- if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) {
+ if (self != nullptr && check_has_torch_function(self)) {
append_overloaded_arg(&this->overloaded_args, self);
}
for (auto& param : params) {
121 changes: 0 additions & 121 deletions torch/csrc/utils/python_arg_parser.h
@@ -673,127 +673,6 @@ inline PyObject* PythonArgs::pyobject(int i) {
return args[i];
}

/*
* Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
*
* Stripped down version of PyObject_GetAttrString,
* avoids lookups for None, tuple, and List objects,
* and doesn't create a PyErr since this code ignores it.
*
* This can be much faster than PyObject_GetAttrString where
* exceptions are not used by caller.
*
* 'obj' is the object to search for attribute.
*
* 'name' is the attribute to search for.
*
* Returns a py::object wrapping the return value. If the attribute lookup failed
* the value will be NULL.
*
*/

static py::object PyObject_FastGetAttrString(PyObject *obj, char *name)
{
PyTypeObject *tp = Py_TYPE(obj);
PyObject *res = (PyObject *)NULL;

/* Attribute referenced by (char *)name */
if (tp->tp_getattr != NULL) {
res = (*tp->tp_getattr)(obj, name);
if (res == NULL) {
PyErr_Clear();
}
}
/* Attribute referenced by (PyObject *)name */
else if (tp->tp_getattro != NULL) {
PyObject *w = THPUtils_internString(name);
if (w == NULL) {
return py::object();
}
res = (*tp->tp_getattro)(obj, w);
Py_DECREF(w);
if (res == NULL) {
PyErr_Clear();
}
}
return py::reinterpret_steal<py::object>(res);
}

// Makes sure that we don't check for __torch_function__ on basic Python types
static bool _is_basic_python_type(PyTypeObject *tp)
{
return (
/* Basic number types */
tp == &PyBool_Type ||

tp == &PyLong_Type ||
tp == &PyFloat_Type ||
tp == &PyComplex_Type ||

/* Basic sequence types */
tp == &PyList_Type ||
tp == &PyTuple_Type ||
tp == &PyDict_Type ||
tp == &PySet_Type ||
tp == &PyFrozenSet_Type ||
tp == &PyUnicode_Type ||
tp == &PyBytes_Type ||

/* other builtins */
tp == &PySlice_Type ||
tp == Py_TYPE(Py_None) ||
tp == Py_TYPE(Py_Ellipsis) ||
tp == Py_TYPE(Py_NotImplemented) ||

PyModule_Check(tp) ||
/* sentinel to swallow trailing || */
false
);
}

/*
* Lookup a special method, following the python approach of looking up
* on the type object, rather than on the instance itself.
*
* Assumes that the special method is a torch-specific one, so does not
* look at builtin types, nor does it look at a base Tensor.
*
* If no special method is found, return NULL, otherwise returns a new
* reference to the function object
*
* In future, could be made more like _Py_LookupSpecial
*/

static py::object PyTorch_LookupSpecial(PyObject *obj, char* name)
{
if (THPVariable_CheckExact(obj)) {
return py::object();
}
PyTypeObject *tp = Py_TYPE(obj);
if (_is_basic_python_type(tp)) {
return py::object();
}
return PyObject_FastGetAttrString((PyObject *)tp, name);
}

/*
* Checks if obj has a __torch_function__ implementation
*
* Returns true if an implementation is found and false otherwise
*
*/
static auto check_has_torch_function(PyObject* obj) -> bool
{
if (!torch_function_enabled()) {
return false;
}
py::object method = PyTorch_LookupSpecial(obj, "__torch_function__");
if(method.ptr() != nullptr && method.ptr() != disabled_torch_function_impl()){
return true;
}
return false;
}

/*
*
* Handle __torch_function__ overrides if we know that there are overloaded
27 changes: 27 additions & 0 deletions torch/csrc/utils/python_compat.h
@@ -2,6 +2,33 @@

#include <torch/csrc/python_headers.h>

#if PY_VERSION_HEX < 0x03070000
// METH_FASTCALL was introduced in Python 3.7, so we wrap _PyCFunctionFast
// signatures for earlier versions.

template <PyObject* (*f)(PyObject*, PyObject *const *, Py_ssize_t)>
PyObject* maybe_wrap_fastcall(PyObject *module, PyObject *args) {
return f(
module,

// _PyTuple_ITEMS
// Because this is only a compat shim for Python 3.6, we don't have
// to worry about the representation changing.
((PyTupleObject *)args)->ob_item,
PySequence_Fast_GET_SIZE(args)
);
}

#define MAYBE_METH_FASTCALL METH_VARARGS
#define MAYBE_WRAP_FASTCALL(f) maybe_wrap_fastcall<f>

#else

#define MAYBE_METH_FASTCALL METH_FASTCALL
#define MAYBE_WRAP_FASTCALL(f) (PyCFunction)(void(*)(void))f

#endif

// PyPy 3.6 does not yet have PySlice_Unpack
#if PY_VERSION_HEX < 0x03060100 || defined(PYPY_VERSION)

