Update on "move has_torch_function to C++, and make a special case object_has_torch_function"


This PR pulls `__torch_function__` checking entirely into C++ and adds a special `object_has_torch_function` method for ops that take only one argument, which lets us skip tuple construction and unpacking. We can also do away with the Python-side fast bailouts for `Tensor` (e.g. `if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors)`) because they are actually slower than checking with the Python C API.
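The dispatch pattern described above can be illustrated with a minimal pure-Python sketch. Note this is only an illustration of the protocol shape: `FakeTensor` and `LoggingTensor` are made-up stand-ins, and the real `has_torch_function_unary` fast path lives in C++ inside `torch._C`, not in Python.

```python
class FakeTensor:
    """Illustrative stand-in for torch.Tensor (not the real class)."""

def has_torch_function_unary(obj):
    # Single-argument fast path sketch: an exact Tensor never needs
    # dispatch; anything else dispatches if its type defines
    # __torch_function__. No tuple construction or unpacking required.
    return type(obj) is not FakeTensor and hasattr(type(obj), "__torch_function__")

class LoggingTensor(FakeTensor):
    """A Tensor-like subclass that intercepts ops via __torch_function__."""
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return f"dispatched {func.__name__}"

def log_softmax(input):
    # Mirrors the guard pattern used throughout torch/nn/functional.py.
    if has_torch_function_unary(input):
        return type(input).__torch_function__(log_softmax, (type(input),), (input,))
    return "computed log_softmax"
```

A plain `FakeTensor` takes the computation path directly, while a `LoggingTensor` is routed through `__torch_function__`, which is the behavior the unary check exists to preserve while staying cheap for the common case.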

Test plan: Existing unit tests. Benchmarks are in #48966

Differential Revision: [D25590732](https://our.internmc.facebook.com/intern/diff/D25590732)

[ghstack-poisoned]
Taylor Robie committed Jan 10, 2021
1 parent ab684f6 commit 3e9b809
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch/nn/functional.py
@@ -1711,7 +1711,7 @@ def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3,
     is performed. This is useful for preventing data type overflows. Default: None.
     """
     if not torch.jit.is_scripting():
-        if has_torch_function_unary(logits):
+        if has_torch_function_unary(input):
             return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
     if dim is None:
         dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
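The one-line fix in this diff corrects what would otherwise be a runtime `NameError`: `logits` is not a parameter of `log_softmax` (it was presumably carried over from a sibling function whose parameter is named `logits`), so the guard would raise the moment a non-`Tensor` argument reached it. A minimal reproduction of that failure mode, using a hypothetical stub rather than the real PyTorch functions:

```python
def has_torch_function_unary(obj):
    # Stub so the undefined name, not dispatch, is what fails.
    return False

def log_softmax_buggy(input):
    # Bug reproduced from the pre-fix line: `logits` was never defined
    # in this scope, so evaluating it raises NameError at call time.
    if has_torch_function_unary(logits):
        return "dispatched"
    return "computed"

try:
    log_softmax_buggy(object())
    caught = ""
except NameError as e:
    caught = str(e)
```

Because the broken branch only executed outside TorchScript and before the `Tensor` fast path, existing unit tests exercising the `__torch_function__` override path are what surface this, consistent with the "Existing unit tests" test plan above.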
