Skip to content

Commit

Permalink
Update on "move has_torch_function to C++, and make a special case ob…
Browse files Browse the repository at this point in the history
…ject_has_torch_function"


This PR pulls `__torch_function__` checking entirely into C++, and adds a special `object_has_torch_function` method for ops which only have one arg as this lets us skip tuple construction and unpacking. We can now also do away with the Python side fast bailout for `Tensor` (e.g. `if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors)`) because they're actually slower than checking with the Python C API.

Test plan: Existing unit tests. Benchmarks are in #48966

[ghstack-poisoned]
  • Loading branch information
Taylor Robie committed Dec 13, 2020
1 parent a39a934 commit cdb03d0
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torch/csrc/utils/disable_torch_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,10 @@ static bool is_basic_python_type(PyTypeObject *tp)
}

inline bool has_torch_function_attr(PyObject* obj) {
auto attr = PyObject_GetAttrString(obj, "__torch_function__");
auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
return (
// NOLINTNEXTLINE(modernize-use-nullptr)
attr != NULL &&
attr != torch::disabled_torch_function);
attr.ptr() != nullptr &&
attr.ptr() != torch::disabled_torch_function);
}

inline bool sequence_has_torch_function(PyObject* args) {
Expand Down

0 comments on commit cdb03d0

Please sign in to comment.