AOTAutograd has high fixed overheads #122029
Comments
Used to generate numbers in #122029. Pull Request resolved: #122032. Approved by: https://github.com/yanboliang
Improves `benchmarks/dynamo/microbenchmarks/overheads.py` from 38.7us to 34.3us. See #122029. Pull Request resolved: #122033. Approved by: https://github.com/zou3519, https://github.com/soulitzer. ghstack dependencies: #122032
I plan to take a swing at this.
@jansel - quick question: how do I run the cProfile called out here (basic startup question ;-) )? I wanted to be able to replicate baseline numbers before touching anything. Thanks!
@tfiala you can do something like:

```diff
diff --git a/benchmarks/dynamo/microbenchmarks/overheads.py b/benchmarks/dynamo/microbenchmarks/overheads.py
index 687fe58cc79..9eb6d063d6f 100644
--- a/benchmarks/dynamo/microbenchmarks/overheads.py
+++ b/benchmarks/dynamo/microbenchmarks/overheads.py
@@ -5,6 +5,8 @@ import numpy as np
 
 import torch
+import cProfile, pstats, io
+from pstats import SortKey
 
 
 def add1(x):
     return x + 1
@@ -18,23 +20,40 @@ def bench(name, fn, requires_grad):
         fn(x)
     end = time.perf_counter()
 
-    results = timeit.repeat(lambda: fn(x), number=1000, repeat=1000)
-    print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)")
+    pr = cProfile.Profile()
+    pr.enable()
+    for _ in range(1000):
+        fn(x)
+    pr.disable()
+
+    # View the file below with `snakeviz overhead.prof`
+    pstats.Stats(pr).dump_stats("overhead.prof")
+
+    # Alternately format the stats to print to console
+    s = io.StringIO()
+    sortby = SortKey.CUMULATIVE
+    ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
+    ps.print_stats()
+    print(s.getvalue())
+
+    # results = timeit.repeat(lambda: fn(x), number=1000, repeat=1000)
+    # print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)")
 
 
 def main():
-    print("requires_grad=False")
-    bench("eager   ", add1, False)
-    bench("compiled", torch.compile(add1), False)
-    print()
-    print("requires_grad=True")
-    bench("eager   ", add1, True)
+    # print("requires_grad=False")
+    # bench("eager   ", add1, False)
+    # bench("compiled", torch.compile(add1), False)
+    # print()
+    # print("requires_grad=True")
+    # bench("eager   ", add1, True)
     bench("compiled", torch.compile(add1), True)
-    print()
-    print("inference_mode()")
-    with torch.inference_mode():
-        bench("eager   ", add1, False)
-        bench("compiled", torch.compile(add1), False)
+    # print()
+    # print("inference_mode()")
+    # with torch.inference_mode():
+    #     bench("eager   ", add1, False)
+    #     bench("compiled", torch.compile(add1), False)
+
 
 if __name__ == "__main__":
```
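As a side note, if you'd rather inspect the dumped `overhead.prof` from Python instead of snakeviz, a small standard-library `pstats` snippet works (a minimal sketch, assuming the patch above has been run and written the file):

```python
# Load the profile dumped by the patched benchmark and print the 15 most
# expensive entries by cumulative time.
import pstats
from pstats import SortKey

stats = pstats.Stats("overhead.prof")
stats.sort_stats(SortKey.CUMULATIVE).print_stats(15)
```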
Thank you, @jansel!
Initial Run - Baseline

Ran this:

```
python benchmarks/dynamo/microbenchmarks/overheads.py > microbenchmarks-start.txt
```

Got these results before commenting anything out:

```
$ head -n 15 microbenchmarks-start.txt
         81001 function calls (78001 primitive calls) in 0.035 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
2000/1000    0.005    0.000    0.035    0.000 torch/_dynamo/eval_frame.py:367(_fn)
     1000    0.000    0.000    0.028    0.000 benchmarks/dynamo/microbenchmarks/overheads.py:11(add1)
     1000    0.000    0.000    0.026    0.000 torch/_dynamo/external_utils.py:34(inner)
     1000    0.001    0.000    0.026    0.000 torch/_functorch/aot_autograd.py:913(forward)
2000/1000    0.001    0.000    0.025    0.000 torch/_functorch/_aot_autograd/utils.py:88(g)
     1000    0.002    0.000    0.024    0.000 torch/_functorch/_aot_autograd/runtime_wrappers.py:77(runtime_wrapper)
2000/1000    0.003    0.000    0.020    0.000 torch/_functorch/_aot_autograd/utils.py:105(call_func_at_runtime_with_args)
     1000    0.001    0.000    0.018    0.000 torch/autograd/function.py:590(apply)
     1000    0.002    0.000    0.015    0.000 {built-in method apply}
     1000    0.005    0.000    0.012    0.000 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:534(forward)
...
```

After commenting out the bulk of the other benchmark cases (per the patch above):

```
$ head -n 15 microbenchmarks-start-skip-compiled.txt
         59001 function calls (57001 primitive calls) in 0.025 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
2000/1000    0.005    0.000    0.025    0.000 torch/_dynamo/eval_frame.py:367(_fn)
     1000    0.000    0.000    0.018    0.000 benchmarks/dynamo/microbenchmarks/overheads.py:11(add1)
     1000    0.000    0.000    0.016    0.000 torch/_dynamo/external_utils.py:34(inner)
     1000    0.001    0.000    0.015    0.000 torch/_functorch/aot_autograd.py:913(forward)
2000/1000    0.001    0.000    0.015    0.000 torch/_functorch/_aot_autograd/utils.py:88(g)
     1000    0.002    0.000    0.014    0.000 torch/_functorch/_aot_autograd/runtime_wrappers.py:77(runtime_wrapper)
     1000    0.002    0.000    0.011    0.000 torch/_functorch/_aot_autograd/utils.py:105(call_func_at_runtime_with_args)
     1000    0.001    0.000    0.008    0.000 torch/autograd/function.py:590(apply)
     1000    0.002    0.000    0.005    0.000 {built-in method apply}
     1000    0.000    0.000    0.003    0.000 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:534(forward)
```

Since the time that this issue was initially posted, the delta I'm seeing due to `requires_grad=True` has shrunk. Given this setup, it looks like a zero-cost wrapper solution would net ~9-10us as a theoretical max per-execution runtime improvement of the compiled code. I'll start looking at the wrapper code now.

One idea, keeping it in Python, was to generate an FX graph for the wrapper and let the FX machinery optimize out the unused cases (which we could count on being fixed at the time we would run this). This has the potentially nice property of being easier to serialize, with minimal build-pipeline complexity vs. something like a generated C++/native wrapper. It sounds like plans are in the works to do more caching sometime soon-ish to improve warm-start times, so this seems like a good place to start.

I think this approach would mean running some of the Dynamo machinery on the generated wrapper, and then executing the FX graph. I would guess this means some kind of overhead for executing an FX graph -- either the process of turning the FX graph into runnable Python bytecode or some kind of FX graph interpreter. If that's the case, then this candidate solution's wins may be reduced by whatever that FX -> executable Python step incurs.

Any thoughts, comments or corrections on anything I wrote above, @jansel? Thanks!
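To make the FX idea above concrete, here is a minimal sketch (my own illustration, not AOTAutograd's actual wrapper; the `needs_clone`/`needs_detach` flag names are made up): trace a generic wrapper with `torch.fx` while the corner-case flags are known constants, so the specialized graph contains only the branches that apply.

```python
import torch
import torch.fx


def make_specialized_wrapper(compiled_fn, needs_clone: bool, needs_detach: bool):
    """Hypothetical wrapper specialization: the flags are fixed before tracing."""

    def wrapper(x):
        out = compiled_fn(x)
        if needs_clone:      # evaluated once at trace time, not on every call
            out = out.clone()
        if needs_detach:
            out = out.detach()
        return out

    # symbolic_trace resolves the Python `if`s against the captured constants,
    # so the resulting GraphModule only contains the code paths actually needed.
    return torch.fx.symbolic_trace(wrapper)


fast = make_specialized_wrapper(lambda x: x + 1, needs_clone=False, needs_detach=False)
print(fast.code)  # straight-line code: the unused corner cases are gone
```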
Related, I've been reading through the export docs here. The rationale section of that page walks through a number of comparisons that helped orient me. (I get that export is not what I'm looking for, but I found its discussion of the various ways of tracing to be useful background as I start hunting through the code.)
Some difference in performance might be due to #122033.

Also note that times under the profiler can be distorted by profiling costs, so when I report numbers it is not with the profiler enabled. If you run without the patch you might get different numbers.

I did some similar optimizations in this PR: https://github.com/pytorch/pytorch/pull/118070/files, where I moved stuff off the critical path by building a dynamic list of
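The "dynamic list" pattern being referred to looks roughly like the following sketch (my own illustration with invented names, not the code from that PR): decide once, at compile time, which post-processing steps apply, then just iterate the precomputed list on every call instead of re-checking flags.

```python
from typing import Callable, List

import torch


def build_epilogue(needs_clone: bool, needs_detach: bool) -> List[Callable]:
    # Built once, off the critical path.
    steps: List[Callable] = []
    if needs_clone:
        steps.append(lambda t: t.clone())
    if needs_detach:
        steps.append(lambda t: t.detach())
    return steps


def run(compiled_fn, epilogue: List[Callable], x: torch.Tensor) -> torch.Tensor:
    out = compiled_fn(x)
    for step in epilogue:  # empty in the common case, so near-zero per-call cost
        out = step(out)
    return out


epilogue = build_epilogue(needs_clone=False, needs_detach=False)
print(run(lambda t: t + 1, epilogue, torch.randn(1)))
```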
Addressing this comment re: baseline numbers:

As of commit e63e013, the following are the baseline perf numbers I'm getting (unpatched main branch, 1x H100):

```
$ time python benchmarks/dynamo/microbenchmarks/overheads.py
requires_grad=False
eager    2.9us (warmup=0.0s)
compiled 9.2us (warmup=4.7s)

requires_grad=True
eager    3.8us (warmup=0.0s)
compiled 15.6us (warmup=0.1s)

inference_mode()
eager    2.8us (warmup=0.0s)
compiled 6.1us (warmup=0.0s)
python benchmarks/dynamo/microbenchmarks/overheads.py  51.45s user 12.34s system 131% cpu 48.421 total
```

In this setup, the autograd-checking wrapper code takes 6.4us longer (15.6us - 9.2us), an increase of 70%. This is a lot better than the overhead as originally reported ((39.0us - 14.0us) / 14.0us = 180% increase). I will still have a look at it, but the potential gains available here appear to have shrunk significantly.
This is a CPU benchmark, so likely the difference is the devserver CPU versus my local i9-11900K. It could also be the Python version (3.11). Latest numbers on my machine:
The exact numbers don't matter. I don't expect them to match across different machines. |
Yep, just wanted to start from a known position :-)
That brings up a good question - which version of Python are you using? I've done 3.10 and 3.11 but am not sure what's standard. 3.12 wasn't working last time I tried, but I recall hearing that support is in the works.
I'm using 3.11. 3.12 support is in progress but not done yet.
Okay, updated to Python 3.11 (I was on 3.10). I stuck some trace statements into the code to better understand how it flowed during the benchmark; that generated quite a few lines of trace output ;-)

Updated baseline numbers on my setup:

```
requires_grad=False
eager    2.7us (warmup=0.0s)
compiled 6.9us (warmup=4.3s)

requires_grad=True
eager    3.3us (warmup=0.0s)
compiled 13.6us (warmup=0.1s)

inference_mode()
eager    2.5us (warmup=0.0s)
compiled 5.3us (warmup=0.0s)
```
Getting back to this... Looks like this is the call stack down to the relevant function:

```
  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 41, in <module>
    main()
  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 32, in main
    bench("compiled", torch.compile(add1), True)
  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 18, in bench
    fn(x)
  File "torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "benchmarks/dynamo/microbenchmarks/overheads.py", line 9, in add1
    def add1(x):
  File "torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 89, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/autograd/function.py", line 569, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 544, in forward
    traceback.print_stack()
```

(The last frame is a call-stack print mechanism I added to trace where we're at when we call that code.)

Just getting myself oriented. So somewhere up the call stack I'll want to find the highest place that could be setting up code that is applicable and needs to be executed, in a style similar to this -- somewhere that has enough context to figure out what applies to each call and can set up the relevant code fragments for execution.
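For reference, the trace statement described above is just the standard-library call dropped into whichever function you suspect is on the hot path (the placement here is illustrative):

```python
import traceback


def forward(*args, **kwargs):
    # Print the full Python call stack each time this function is reached,
    # to see exactly which wrappers sit between torch.compile and this point.
    traceback.print_stack()
    ...
```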
Hey @jansel! As I start messing with the code, is there anything in the overheads benchmark that validates (or could validate) that the results are correct? It would be awesome, as I start experimenting, to know that whatever I'm trying to do is a valid transformation.
@tfiala for a basic smoke-test you could run:

(or pick any of the many test files in that directory)

For more testing, if you open a draft PR with your changes, our CI system will run the entire test suite and also some full-model testing to validate the change. Doing this full run takes many hours, so it is usually easier to let CI run it in parallel. Then if anything fails you can run just the failing test locally.
A few months ago I looked at the fixed overheads of torch.compile using this microbenchmark. I wanted to give a quick update to that, since I realized I didn't test things with `requires_grad=True`, which triggers a bunch of extra wrapper code in AOTAutograd.

Here are updated numbers from the last post with latest main. Sadly, we have regressed 3.7us since last time.
As an aside, if I enable `torch.inference_mode()` things get a bit better.

Now on the other end, if I add `requires_grad=True` to this line of the microbenchmark, I get almost 2.8x worse results than compile without the need to track gradients. Note I am not calling backward, just triggering the extra wrapper code that deals with backward if it does get called.
So where is it coming from? Here is the top of a cProfile (run 1000 times):
It looks like the worst offender is coming from jit_compile_runtime_wrappers.py. As a quick experiment I wrote a patch that just skips that entire function:
This improves things to:
Obviously we can't land that, but it does show the maximum benefit that could be had by optimizing that function. More than half of the overhead added by AOTAutograd is coming from that one function.

Looking over that function's code, it seems like there are a lot of checks for various corner cases that will do nothing on most invocations. In addition, nearly all of those corner-case checks should be statically known at compile time. One way to speed that up would be to codegen a custom wrapper function that only includes the corner cases that are actually relevant to the function being called, though some more basic trimming and code optimization would also likely help.
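A minimal sketch of what such codegen could look like (my own illustration with invented flag names, not a proposed implementation): build the wrapper's source with only the branches this particular compiled function actually needs, compile it once, and reuse it on every call.

```python
def codegen_wrapper(needs_clone: bool, needs_detach: bool):
    # Emit only the epilogue lines that apply to this particular compiled function.
    lines = [
        "def wrapper(compiled_fn, x):",
        "    out = compiled_fn(x)",
    ]
    if needs_clone:
        lines.append("    out = out.clone()")
    if needs_detach:
        lines.append("    out = out.detach()")
    lines.append("    return out")

    namespace = {}
    exec(compile("\n".join(lines), "<generated wrapper>", "exec"), namespace)
    return namespace["wrapper"]


# Common case: no corner cases apply, so the generated wrapper is straight-line
# code with zero per-call flag checks.
wrapper = codegen_wrapper(needs_clone=False, needs_detach=False)
print(wrapper(lambda x: x + 1, 41))  # -> 42
```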
I don't plan to work on this, so this issue is up for grabs if someone is interested in some micro-optimizations.
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang