
AOTAutograd has high fixed overheads #122029

Open · jansel opened this issue Mar 16, 2024 · 17 comments
Assignees
tfiala
Labels
module: aotdispatch (umbrella label for AOTAutograd issues), module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op), oncall: pt2, triaged (this issue has been looked at by a team member and triaged and prioritized into an appropriate module)

Comments

jansel (Contributor) commented Mar 16, 2024

A few months ago I looked at the fixed overheads of torch.compile using this microbenchmark.

I wanted to give a quick update, since I realized I didn't test with requires_grad=True, which triggers a bunch of extra wrapper code in AOTAutograd.

Here are updated numbers from the last post, on the latest main. Sadly, we have regressed by 3.7us since last time.

  • eager 4.6us
  • compiled 14.0us

As an aside, if I enable torch.inference_mode() things get a bit better:

  • eager 4.1us
  • compiled 11.1us

Now on the other end, if I add requires_grad=True to this line of the microbenchmark, I get:

  • eager 5.6us
  • compiled 39.0us

That is almost 2.8x worse than compiled mode without gradient tracking. Note that I am not calling backward; this just triggers the extra wrapper code that handles backward in case it does get called.
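For reference, the measurement is roughly the following (a simplified sketch of benchmarks/dynamo/microbenchmarks/overheads.py; the input tensor and warmup count are placeholders of mine, and the real script also covers the requires_grad=False and inference_mode variants):

import timeit

import numpy as np
import torch


def add1(x):
    return x + 1


def bench(name, fn, requires_grad):
    # Placeholder input; the real script constructs its own tensor.
    x = torch.randn(1, requires_grad=requires_grad)
    for _ in range(100):  # warm up so compilation happens outside the timed region
        fn(x)
    results = timeit.repeat(lambda: fn(x), number=1000, repeat=1000)
    print(f"{name} {np.median(results) * 1000:.1f}us")  # median time per call


bench("eager   ", add1, requires_grad=True)
bench("compiled", torch.compile(add1), requires_grad=True)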

So where is it coming from? Here is the top of a cProfile (run 1000 times):

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.010    0.000    0.026    0.000 jit_compile_runtime_wrappers.py:505(forward)
2000/1000    0.007    0.000    0.076    0.000 eval_frame.py:427(_fn)
    20000    0.007    0.000    0.007    0.000 function.py:346(__getattribute__)
     1000    0.005    0.000    0.033    0.000 {built-in method apply}
2000/1000    0.005    0.000    0.043    0.000 utils.py:105(call_func_at_runtime_with_args)
     1000    0.005    0.000    0.005    0.000 {built-in method torch.tensor}
     1000    0.005    0.000    0.055    0.000 runtime_wrappers.py:77(runtime_wrapper)
     1000    0.002    0.000    0.039    0.000 function.py:582(apply)
     2000    0.002    0.000    0.002    0.000 eval_frame.py:140(change)
     1000    0.001    0.000    0.004    0.000 <string>:2(guard)
     1000    0.001    0.000    0.003    0.000 clph2nks5ejst7jdpwldntojs4uhnqr57fddd5w6jpqaqposu7td.py:46(call)

It looks like the worst offender is coming from jit_compile_runtime_wrappers.py. As a quick experiment, I wrote a patch that just skips that entire function:

diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index ce087d44e25..e4328df32a9 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -504,6 +504,8 @@ def aot_dispatch_autograd(
 
         @staticmethod
         def forward(ctx, *deduped_flat_tensor_args):
+            return CompiledFunction.compiled_fw([*deduped_flat_tensor_args])[:1]
+
             args = deduped_flat_tensor_args
             if backward_state_indices:
                 bw_state = args[backward_state_indices[0]]

This improves things to:

  • compiled 26.3us

Obviously we can't land that, but it does show the maximum benefit that could be had by optimizing that function: more than half of the overhead added by AOTAutograd comes from that one function. Looking over its code, there are a lot of checks for corner cases that do nothing on most invocations, and nearly all of those checks should be statically known at compile time. One way to speed this up would be to codegen a custom wrapper function that only includes the corner cases actually relevant to the function being called, though some more basic trimming and code optimization would also likely help.
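As a rough illustration of the codegen idea (a hypothetical sketch, not AOTAutograd's actual wrapper; the flags and the sync_rng_state helper below are made up):

def build_runtime_wrapper(compiled_fw, *, needs_bw_state=False, bw_state_index=0,
                          needs_rng_sync=False):
    # All the "does this corner case apply?" decisions happen once, here, at
    # compile time; the generated wrapper contains only the branches it needs.
    src = ["def wrapper(args):"]
    if needs_bw_state:
        src.append(f"    bw_state = args[{bw_state_index}]  # hypothetical bookkeeping")
    if needs_rng_sync:
        src.append("    sync_rng_state()  # hypothetical helper, stubbed below")
    src.append("    return compiled_fw(args)")

    scope = {"compiled_fw": compiled_fw, "sync_rng_state": lambda: None}
    exec("\n".join(src), scope)
    return scope["wrapper"]


# In the common case the generated wrapper is just a direct call:
wrapper = build_runtime_wrapper(lambda args: [a + 1 for a in args])
print(wrapper([1, 2]))  # [2, 3]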

I don't plan to work on this, so this issue is up for grabs if someone is interested in some micro-optimizations.

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang

jansel added the module: aotdispatch and oncall: pt2 labels Mar 16, 2024
jansel added commits referencing this issue Mar 16-17, 2024 (via ghstack), with two messages:

Used to generate numbers in #122029

Improves `benchmarks/dynamo/microbenchmarks/overheads.py` from 38.7us to 34.3us. See #122029
zou3519 added the module: pt2-dispatcher label Mar 18, 2024
anijain2305 added the triaged label Mar 18, 2024
pytorchmergebot pushed a commit that referenced this issue Mar 18, 2024
Used to generate numbers in #122029

Pull Request resolved: #122032
Approved by: https://github.com/yanboliang
pytorchmergebot pushed a commit that referenced this issue Mar 18, 2024
Improves `benchmarks/dynamo/microbenchmarks/overheads.py` from 38.7us to
34.3us.

See #122029
Pull Request resolved: #122033
Approved by: https://github.com/zou3519, https://github.com/soulitzer
ghstack dependencies: #122032
tfiala self-assigned this Mar 22, 2024
tfiala (Contributor) commented Mar 22, 2024

I plan to take a swing at this.

tfiala (Contributor) commented Mar 24, 2024

@jansel, quick question (a basic startup question ;-) ): how do I run the cProfile called out here?

So where is it coming from? Here is the top of a cProfile (run 1000 times):

I wanted to be able to replicate baseline numbers before touching anything. Thanks!

jansel (Contributor, Author) commented Mar 25, 2024

@tfiala you can do something like:

diff --git a/benchmarks/dynamo/microbenchmarks/overheads.py b/benchmarks/dynamo/microbenchmarks/overheads.py
index 687fe58cc79..9eb6d063d6f 100644
--- a/benchmarks/dynamo/microbenchmarks/overheads.py
+++ b/benchmarks/dynamo/microbenchmarks/overheads.py
@@ -5,6 +5,8 @@ import numpy as np
 
 import torch
 
+import cProfile, pstats, io
+from pstats import SortKey
 
 def add1(x):
     return x + 1
@@ -18,23 +20,40 @@ def bench(name, fn, requires_grad):
         fn(x)
     end = time.perf_counter()
 
-    results = timeit.repeat(lambda: fn(x), number=1000, repeat=1000)
-    print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)")
+    pr = cProfile.Profile()
+    pr.enable()
+    for _ in range(1000):
+        fn(x)
+    pr.disable()
+
+    # View the file below with `snakeviz overhead.prof`
+    pstats.Stats(pr).dump_stats("overhead.prof")
+
+    # Alternately format the stats to print to console
+    s = io.StringIO()
+    sortby = SortKey.CUMULATIVE
+    ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
+    ps.print_stats()
+    print(s.getvalue())
+
+    # results = timeit.repeat(lambda: fn(x), number=1000, repeat=1000)
+    # print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)")
 
 
 def main():
-    print("requires_grad=False")
-    bench("eager   ", add1, False)
-    bench("compiled", torch.compile(add1), False)
-    print()
-    print("requires_grad=True")
-    bench("eager   ", add1, True)
+    # print("requires_grad=False")
+    # bench("eager   ", add1, False)
+    # bench("compiled", torch.compile(add1), False)
+    # print()
+    # print("requires_grad=True")
+    # bench("eager   ", add1, True)
     bench("compiled", torch.compile(add1), True)
-    print()
-    print("inference_mode()")
-    with torch.inference_mode():
-        bench("eager   ", add1, False)
-        bench("compiled", torch.compile(add1), False)
+    # print()
+    # print("inference_mode()")
+    # with torch.inference_mode():
+    #     bench("eager   ", add1, False)
+    #     bench("compiled", torch.compile(add1), False)
+
 
 
 if __name__ == "__main__":

tfiala (Contributor) commented Mar 27, 2024

Thank you, @jansel!

tfiala (Contributor) commented Mar 27, 2024

Initial Run - Baseline

Ran this:

python benchmarks/dynamo/microbenchmarks/overheads.py > microbenchmarks-start.txt

using @jansel's patch here.

Got these results before commenting anything out:

head -n 15 microbenchmarks-start.txt
         81001 function calls (78001 primitive calls) in 0.035 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
2000/1000    0.005    0.000    0.035    0.000 torch/_dynamo/eval_frame.py:367(_fn)
     1000    0.000    0.000    0.028    0.000 benchmarks/dynamo/microbenchmarks/overheads.py:11(add1)
     1000    0.000    0.000    0.026    0.000 torch/_dynamo/external_utils.py:34(inner)
     1000    0.001    0.000    0.026    0.000 torch/_functorch/aot_autograd.py:913(forward)
2000/1000    0.001    0.000    0.025    0.000 torch/_functorch/_aot_autograd/utils.py:88(g)
     1000    0.002    0.000    0.024    0.000 torch/_functorch/_aot_autograd/runtime_wrappers.py:77(runtime_wrapper)
2000/1000    0.003    0.000    0.020    0.000 torch/_functorch/_aot_autograd/utils.py:105(call_func_at_runtime_with_args)
     1000    0.001    0.000    0.018    0.000 torch/autograd/function.py:590(apply)
     1000    0.002    0.000    0.015    0.000 {built-in method apply}
     1000    0.005    0.000    0.012    0.000 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:534(forward)
...

After commenting out the bulk of jit_compile_runtime_wrappers.py as described here (i.e. eliminating most of the runtime wrapper logic), I am getting the following results:

head -n 15 microbenchmarks-start-skip-compiled.txt
         59001 function calls (57001 primitive calls) in 0.025 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
2000/1000    0.005    0.000    0.025    0.000 torch/_dynamo/eval_frame.py:367(_fn)
     1000    0.000    0.000    0.018    0.000 benchmarks/dynamo/microbenchmarks/overheads.py:11(add1)
     1000    0.000    0.000    0.016    0.000 _dynamo/external_utils.py:34(inner)
     1000    0.001    0.000    0.015    0.000 torch/_functorch/aot_autograd.py:913(forward)
2000/1000    0.001    0.000    0.015    0.000 torch/_functorch/_aot_autograd/utils.py:88(g)
     1000    0.002    0.000    0.014    0.000 torch/_functorch/_aot_autograd/runtime_wrappers.py:77(runtime_wrapper)
     1000    0.002    0.000    0.011    0.000 torch/_functorch/_aot_autograd/utils.py:105(call_func_at_runtime_with_args)
     1000    0.001    0.000    0.008    0.000 torch/autograd/function.py:590(apply)
     1000    0.002    0.000    0.005    0.000 {built-in method apply}
     1000    0.000    0.000    0.003    0.000 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:534(forward)

Since this issue was initially posted, the delta due to jit_compile_runtime_wrappers.py is now ~10us (35.0us - 25.0us total time) after short-circuiting the bulk of the wrapper functionality as was done in the issue description. @jansel's delta was a little larger, 12.7us (39.0us - 26.3us), which could be due to some combination of related work done in the last couple of weeks and slightly different hardware. I'm running these on an H100 box with a 96-core AMD CPU reporting 4792.78 BogoMIPS.

Given this setup, it looks like a zero-cost wrapper would net a theoretical maximum of ~9-10us per-execution improvement for the compiled code.

I'll start looking at the wrapper code now. One idea, keeping it in Python, is to generate an FX graph for the wrapper and let the FX machinery optimize out the unused cases (which we can count on being fixed by the time we run this). This has the potentially nice property of being easier to serialize, with minimal build-pipeline complexity compared to something like a generated C++/native wrapper. It sounds like plans are in the works to do more caching soon-ish to improve warm-start times, so this seems like a good place to start.

I think this approach would mean running some of the Dynamo machinery on the generated wrapper and then executing the FX graph. I would guess executing an FX graph carries some overhead of its own, either from turning the FX graph into runnable Python bytecode or from some kind of FX graph interpreter. If that's the case, the wins from this candidate solution would be reduced by whatever that FX-to-executable-Python step costs.
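To make the FX idea concrete, here is a minimal sketch (nothing AOTAutograd does today; compiled_fw is a hypothetical stand-in for the compiled forward), assuming the wrapper collapses to "call the compiled forward, index the output" once inapplicable cases are dropped:

import operator

import torch
import torch.fx as fx


def compiled_fw(args):
    # Hypothetical stand-in for the compiled forward callable.
    return [args[0] + 1]


graph = fx.Graph()
args_node = graph.placeholder("args")
outs = graph.call_function(compiled_fw, (args_node,))
graph.output(graph.call_function(operator.getitem, (outs, 0)))

gm = fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)              # the specialized Python that FX codegens for this graph
print(gm([torch.ones(2)]))  # tensor([2., 2.])

The gm.code step here is roughly the "FX graph -> runnable Python" cost I'm wondering about.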

Any thoughts, comments or corrections on anything I wrote above, @jansel?

Thanks!

tfiala (Contributor) commented Mar 27, 2024

Related: I've been reading through the export docs here. The rationale section of that page walks through a number of comparisons that help orient me. (I understand that export is not what I'm looking for, but I found its discussion of the various ways of tracing to be useful background as I start hunting through the code.)

jansel (Contributor, Author) commented Mar 27, 2024

Some of the difference in performance might be due to #122033.

Also note that times under the profiler can be distorted due to profiling costs. So when I report numbers it is not with the profiler enabled. If you run without the patch you might get different numbers.

I did some similar optimizations in this PR: https://github.com/pytorch/pytorch/pull/118070/files

There I moved stuff off the critical path by building a dynamic list of enter_exit_hooks. Something like that might be easier to implement than going FX-based, though FX could work as well.
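For reference, the pattern is roughly the following (a hedged sketch of the idea, not the actual code from that PR; the grad_disabled flag and the hook are made up):

import torch


def make_call_wrapper(fn, *, grad_disabled=False):
    # Decide once, at setup time, which hooks apply; the per-call path only
    # iterates over the (often empty) list instead of re-checking every flag.
    enter_exit_hooks = []

    if grad_disabled:
        def no_grad_hook():
            # "enter" work; returns a callable that does the matching "exit".
            prior = torch.is_grad_enabled()
            torch.set_grad_enabled(False)
            return lambda: torch.set_grad_enabled(prior)

        enter_exit_hooks.append(no_grad_hook)

    def call(*args, **kwargs):
        cleanups = [hook() for hook in enter_exit_hooks]
        try:
            return fn(*args, **kwargs)
        finally:
            for cleanup in reversed(cleanups):
                cleanup()

    return call


add1_nograd = make_call_wrapper(lambda x: x + 1, grad_disabled=True)
print(add1_nograd(torch.ones(2, requires_grad=True)).requires_grad)  # False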

tfiala (Contributor) commented Mar 28, 2024

Thanks, @jansel!

I'll have a look at what you did there.

So when I report numbers it is not with the profiler enabled.

Ah okay. I thought the patch was needed to enable profiling. I'll look closer at that and get non-profiler-tainted numbers. Thanks for the call-out!

tfiala (Contributor) commented Mar 29, 2024

Addressing this comment re: baseline numbers:

Also note that times under the profiler can be distorted due to profiling costs. So when I report numbers it is not with the profiler enabled.

As of commit e63e013, the following are the baseline perf numbers I'm getting (unpatched main branch, 1x H100):

$ time python benchmarks/dynamo/microbenchmarks/overheads.py

requires_grad=False
eager    2.9us (warmup=0.0s)
compiled 9.2us (warmup=4.7s)

requires_grad=True
eager    3.8us (warmup=0.0s)
compiled 15.6us (warmup=0.1s)

inference_mode()
eager    2.8us (warmup=0.0s)
compiled 6.1us (warmup=0.0s)

python benchmarks/dynamo/microbenchmarks/overheads.py  51.45s user 12.34s system 131% cpu 48.421 total

In this setup, the autograd-checking wrapper code takes 6.4us longer (15.6us - 9.2us), an increase of 70%. This is a lot better than the overhead as originally reported ((39.0us - 14.0us) / 14.0us = 180% increase). I will still take a look at it, but the potential gains available here appear to have shrunk significantly.

jansel (Contributor, Author) commented Mar 30, 2024

This is a CPU benchmark, so the difference is likely the devserver CPU versus my local i9-11900K. It could also be the Python version (3.11).

Latest numbers on my machine:

requires_grad=False
eager    4.8us (warmup=0.0s)
compiled 11.0us (warmup=6.6s)

requires_grad=True
eager    5.6us (warmup=0.0s)
compiled 29.0us (warmup=0.1s)

inference_mode()
eager    4.3us (warmup=0.0s)
compiled 8.4us (warmup=0.0s)

The exact numbers don't matter. I don't expect them to match across different machines.

tfiala (Contributor) commented Apr 1, 2024

Yep, just wanted to start from a known position :-)

tfiala (Contributor) commented Apr 1, 2024

It could also be the Python version (3.11).

That brings up a good question: which version of Python are you using? I've tried 3.10 and 3.11, but I'm not sure what's standard. 3.12 wasn't working last time I tried, but I recall hearing that support is in the works.

jansel (Contributor, Author) commented Apr 1, 2024

I'm using 3.11. 3.12 support is in progress but not done yet.

tfiala (Contributor) commented Apr 2, 2024

Okay, updated to Python 3.11 (I was on 3.10).

I stuck some trace statements into the code to better understand how it flows during the benchmark. That generated quite a few lines of trace output ;-)

Updated baseline numbers on my setup:

requires_grad=False
eager    2.7us (warmup=0.0s)
compiled 6.9us (warmup=4.3s)

requires_grad=True
eager    3.3us (warmup=0.0s)
compiled 13.6us (warmup=0.1s)

inference_mode()
eager    2.5us (warmup=0.0s)
compiled 5.3us (warmup=0.0s)

tfiala (Contributor) commented Apr 5, 2024

Getting back to this...

Looks like this is the call stack into the relevant function:

  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 41, in <module>
    main()
  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 32, in main
    bench("compiled", torch.compile(add1), True)
  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 18, in bench
    fn(x)
  File "torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 9, in add1
    def add1(x):
  File "torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 89, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/autograd/function.py", line 569, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 544, in forward
    traceback.print_stack()

(The last frame is a call stack print I added to trace where we are when that code runs.) Just getting myself oriented.

So somewhere up the call stack I'll want to find the highest place that could decide which code applies and needs to be executed, in a style similar to this: somewhere with enough context to figure out what applies to each call and to set up the relevant code fragments for execution.

tfiala (Contributor) commented Apr 8, 2024

Hey @jansel! As I start messing with the code, is there anything in the overheads benchmark that validates (or could validate) that the results are correct? It would be awesome, as I start experimenting, to know that whatever I'm trying is a valid transformation.

jansel (Contributor, Author) commented Apr 9, 2024

@tfiala for a basic smoke-test you could run:

pytest test/inductor/test_cpu_repro.py

(or pick any of the many test files in that directory)

For more testing, if you open a draft PR with your changes, our CI system will run the entire test suite plus some full-model testing to validate the change. The full run takes many hours, so it is usually easier to let CI run it in parallel; if anything fails, you can then run just the failing test locally.
