
Calling torch.ops.aten.add_ is ludicrously slow #74943

Closed
Chillee opened this issue Mar 30, 2022 · 10 comments

@Chillee (Contributor) commented Mar 30, 2022

🐛 Describe the bug

Calling torch.ops.aten.add_, whether on CPU or CUDA, is orders of magnitude slower than calling Tensor.add_.

import torch
import time

inps = [torch.randn(64, 64, 50, 50, device='cuda') for _ in range(2)]

# Time the Tensor method directly
torch.cuda.synchronize()
begin = time.time()
for _ in range(5):
    inps[0].add_(inps[1])
torch.cuda.synchronize()
print(time.time() - begin)

# Time the same op through the torch.ops.aten binding
torch.cuda.synchronize()
begin = time.time()
for _ in range(5):
    torch.ops.aten.add_(*inps)
torch.cuda.synchronize()
print(time.time() - begin)

results in

0.005440473556518555
0.7287805080413818
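
For reference, a steadier way to take measurements like this is torch.utils.benchmark, which handles warmup and CUDA synchronization; a minimal sketch, not part of the original report:

# Sketch: the same comparison via torch.utils.benchmark.Timer,
# which warms up and synchronizes the device for us.
import torch
from torch.utils import benchmark

inps = [torch.randn(64, 64, 50, 50, device='cuda') for _ in range(2)]

for stmt in ["inps[0].add_(inps[1])", "torch.ops.aten.add_(*inps)"]:
    timer = benchmark.Timer(stmt=stmt, globals={'inps': inps})
    print(timer.timeit(100))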

Profiling reveals that the tensor appears to be getting printed somehow:

import torch
from torch.profiler import profile, ProfilerActivity

inps = [torch.randn(64, 64, 50, 50, device='cpu') for _ in range(2)]

with profile(activities=[ProfilerActivity.CPU],
             profile_memory=True, record_shapes=True) as prof:
    torch.ops.aten.add_(*inps)
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=30))

[Screenshot: profiler table output]

cc: @jansel, @anijain2305, @albanD, @Anja

Versions

N/A

cc @albanD @zou3519

@Chillee (Contributor, Author) commented Mar 30, 2022

From @albanD:

When I looked at the JIT code with Anjali, the way it finds the overload is by doing something like:

for overload in all_overloads:
    try:
        inputs = overload.parse_args(inp)
    except Exception:
        continue  # parsing failed, try the next overload
    else:
        break
call_function(inputs)

Hmm... perhaps the issue here is that when we try to parse_args for the inputs, it throws an error whose message prints out the tensor, and the error then gets caught?
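
A rough illustration of why that would be so costly: stringifying a large CUDA tensor by itself takes far longer than the add_ does (a sketch, not from the original report):

# Sketch of the hypothesis: building the string form of a large tensor
# forces summarization/slicing and a device sync, so doing it on every
# failed parse attempt would dwarf the op itself.
import time
import torch

t = torch.randn(64, 64, 50, 50, device='cuda')
torch.cuda.synchronize()
begin = time.time()
_ = str(t)  # roughly what constructing the error message does
torch.cuda.synchronize()
print(time.time() - begin)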

@ngimel added the module: __torch_dispatch__ and triaged labels Mar 30, 2022
@zou3519 (Contributor) commented Mar 30, 2022

Are the overloaded variants (torch.ops.aten.add_.Tensor) faster? Does it matter that torch.ops.aten.add_ is slow? (AOTAutograd uses the overloaded variants, right?)

If the performance does matter, then we should consider overhauling the overload resolution mechanism (or adding new Python bindings into torch.ops.aten), because this is not the first time it has caused us problems (remember TorchScript registering extra overloads that caused incorrect behavior?).
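
To check the first question, one can time the fully qualified overload directly, since naming the overload skips the try-each-overload path (a sketch, not from the original report):

# Sketch: timing torch.ops.aten.add_.Tensor, which names the overload
# explicitly and so skips the packet-level overload resolution.
import time
import torch

inps = [torch.randn(64, 64, 50, 50, device='cuda') for _ in range(2)]

torch.cuda.synchronize()
begin = time.time()
for _ in range(5):
    torch.ops.aten.add_.Tensor(*inps)
torch.cuda.synchronize()
print(time.time() - begin)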

@gchanan (Contributor) commented Mar 31, 2022

I had the same question as @zou3519 -- do we actually need to use these bindings for any non-TorchScript use case?

@albanD (Collaborator) commented Mar 31, 2022

From @ngimel's comments offline:

In two places in pybind utils we are getting a string representation of an unlucky tensor that cannot be shoehorned into the schema, with all the slicing that's required for that. This diff kinda fixes it:

diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index 8317f31d9b..59f7085c25 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -302,7 +302,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
         return static_cast<c10::complex<double>>(c_obj);
       } else {
         throw py::cast_error(
-            c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
+            c10::str("Cannot cast arg to ", type->repr_str()));
       }
     }
     case TypeKind::RRefType: {
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 942cf67618..1c946d981e 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -641,8 +641,8 @@ inline IValue argumentToIValue(
         schema.formatTypeMismatchMsg(
             argument,
             friendlyTypeName(object),
-            argumentPosition,
-            py::repr(object)),
+            argumentPosition),//
+//            py::repr(object)),
         "\nCast error details: ",
         error.what()));
   } catch (const py::error_already_set& error) {

@albanD (Collaborator) commented Mar 31, 2022

do we actually need to use these bindings for any non-TorchScript use case?

I don't think we need these anymore now that torch_dispatch moved to using the overload-specific functions.
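
For context, under __torch_dispatch__ the func argument arrives as an already-resolved overload rather than the overload-less packet, so the slow resolution path above is never exercised. A minimal sketch using TorchDispatchMode (assuming a PyTorch version that ships it; not from the original thread):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PrintOverloads(TorchDispatchMode):
    # __torch_dispatch__ receives the resolved OpOverload,
    # e.g. aten.add_.Tensor, not torch.ops.aten.add_ itself.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))

a, b = torch.randn(4), torch.randn(4)
with PrintOverloads():
    a.add_(b)  # prints aten.add_.Tensor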

@Chillee (Contributor, Author) commented Mar 31, 2022

@gchanan Yeah... it's possible we can just always use the overloaded versions.

It is a little finicky right now, though, since we can't TorchScript the overloads.

Also, if we go down this path, we should probably just make the non-overloaded version not callable?

@ngimel (Collaborator) commented Mar 31, 2022

Making the non-overloaded version non-callable would be very BC-breaking? But luckily not many users are using torch.ops.

@albanD added the oncall: jit label and removed the module: __torch_dispatch__ label Apr 25, 2022
The github-actions bot added this to Need triage in JIT Triage Apr 25, 2022
@albanD (Collaborator) commented Apr 25, 2022

torch_dispatch is now using the direct overloads and so doesn't have this issue.

@Chillee (Contributor, Author) commented Apr 16, 2024

Resolved now - if anything, the torch.ops.aten.add_ call is now faster.

@Chillee closed this as completed Apr 16, 2024
JIT Triage automation moved this from Need triage to Done Apr 16, 2024
@albanD (Collaborator) commented Apr 16, 2024

@Chillee oh, is it? How?
