-
Notifications
You must be signed in to change notification settings - Fork 25.4k
[dynamo] "TorchDynamo Cache Lookup" event: use C++ api #108436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[dynamo] "TorchDynamo Cache Lookup" event: use C++ api #108436
Conversation
WIP [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108436
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 1e296f3 with merge base db63bf3 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
WIP cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov [ghstack-poisoned]
**Background**: "TorchDynamo Cache Lookup" events appear in traces to indicate a dynamo cache lookup; it's useful to check when cache lookups are taking a long time. To add a profiler event, one can use the `torch.profiler.record_function` context manager, or the C++ equivalent. Previously, the python version was used; first, when the profiler was enabled, callbacks for record_function_enter and record_function_exit were registered; then those would be called before and after every cache lookup. **This PR**: Instead of calling the python bindings for `torch.profiler.record_function`, directly call the C++ implementation. This simplifies a lot of the code for binding C/C++. It also improves performance; previously there was a lot of overhead in the "TorchDynamo Cache Lookup" event, making the event artificially take a long time. After this change the events now appear shorter, because there's less overhead in starting/stopping the event: in other words, the profiler no longer distorts the results as much. **Performance results**: I ran using the script below on a cpu-only 1.6GHz machine. I report the median time (from 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. I think it is reasonable to consider the difference to be due to a reduction in overhead. <details> <summary>Benchmarking script</summary> ```python def fn(x, y): return (x * y).relu() a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)] opt_fn = torch.compile(fn) opt_fn(a, b) opt_fn(a, b) with torch.profiler.profile() as prof: opt_fn(a, b) ``` </details> Median before PR: 198-228 us (median of 100, measured 5 times) Median after PR: 27us cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov [ghstack-poisoned]
torch/csrc/dynamo/cpp_shim.h
Outdated
typedef struct _PytorchRecordFunctionState { | ||
void* guard; | ||
} _PytorchRecordFunctionState; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you just declare the struct here and forward declare the guard
part in cpp_shim.cpp?
struct _PytorchRecordFunctionState;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anijain2305 I had to change the return type to a _PytorchRecordFunctionState*
. Could you do a sanity check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just a minor comment on the shim wrapper to see if void* guard
can be avoid in header file.
**Background**: "TorchDynamo Cache Lookup" events appear in traces to indicate a dynamo cache lookup; it's useful to check when cache lookups are taking a long time. To add a profiler event, one can use the `torch.profiler.record_function` context manager, or the C++ equivalent. Previously, the python version was used; first, when the profiler was enabled, callbacks for record_function_enter and record_function_exit were registered; then those would be called before and after every cache lookup. **This PR**: Instead of calling the python bindings for `torch.profiler.record_function`, directly call the C++ implementation. This simplifies a lot of the code for binding C/C++. It also improves performance; previously there was a lot of overhead in the "TorchDynamo Cache Lookup" event, making the event artificially take a long time. After this change the events now appear shorter, because there's less overhead in starting/stopping the event: in other words, the profiler no longer distorts the results as much. **Performance results**: I ran using the script below on a cpu-only 1.6GHz machine. I report the median time (from 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. I think it is reasonable to consider the difference to be due to a reduction in overhead. <details> <summary>Benchmarking script</summary> ```python def fn(x, y): return (x * y).relu() a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)] opt_fn = torch.compile(fn) opt_fn(a, b) opt_fn(a, b) with torch.profiler.profile() as prof: opt_fn(a, b) ``` </details> Median before PR: 198-228 us (median of 100, measured 5 times) Median after PR: 27us cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov [ghstack-poisoned]
@pytorchbot merge |
Merge failedReason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
Background: "TorchDynamo Cache Lookup" events appear in traces to indicate a dynamo cache lookup; it's useful to check when cache lookups are taking a long time. To add a profiler event, one can use the
torch.profiler.record_function
context manager, or the C++ equivalent. Previously, the python version was used; first, when the profiler was enabled, callbacks for record_function_enter and record_function_exit were registered; then those would be called before and after every cache lookup.This PR: Instead of calling the python bindings for
torch.profiler.record_function
, directly call the C++ implementation. This simplifies a lot of the code for binding C/C++. It also improves performance; previously there was a lot of overhead in the "TorchDynamo Cache Lookup" event, making the event artificially take a long time. After this change the events now appear shorter, because there's less overhead in starting/stopping the event: in other words, the profiler no longer distorts the results as much.Performance results:
I ran using the script below on a cpu-only 1.6GHz machine. I report the median time (from 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. I think it is reasonable to consider the difference to be due to a reduction in overhead.
Benchmarking script
Median before PR: 198-228 us (median of 100, measured 5 times)
Median after PR: 27us
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov