Conversation

davidberard98 (Contributor) commented Sep 1, 2023

Stack from ghstack (oldest at bottom):

**Background**: "TorchDynamo Cache Lookup" events appear in traces to indicate a dynamo cache lookup; it's useful to check when cache lookups are taking a long time. To add a profiler event, one can use the `torch.profiler.record_function` context manager or the C++ equivalent. Previously, the Python version was used: when the profiler was enabled, callbacks for `record_function_enter` and `record_function_exit` were registered, and these were called before and after every cache lookup.
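
As a brief illustration (a sketch, not code from this PR; the event name "my_custom_event" is made up), this is how `torch.profiler.record_function` adds a named event to a trace from Python:

```python
import torch

def f(x, y):
    # Everything inside this context shows up in the trace as a
    # "my_custom_event" profiler event, just as dynamo wraps each
    # cache lookup in a "TorchDynamo Cache Lookup" event.
    with torch.profiler.record_function("my_custom_event"):
        return (x * y).relu()

with torch.profiler.profile() as prof:
    f(torch.rand(4, 4), torch.rand(4, 4))

# The named event appears alongside the builtin ones.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```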

**This PR**: Instead of calling the Python bindings for `torch.profiler.record_function`, directly call the C++ implementation. This simplifies much of the C/C++ binding code. It also improves performance: previously, there was substantial overhead in the "TorchDynamo Cache Lookup" event itself, making the event appear artificially long. After this change the events appear shorter because there is less overhead in starting and stopping them; in other words, the profiler no longer distorts the results as much.
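
A rough way to observe the Python-side overhead being removed (a minimal probe under the assumption that enter/exit bookkeeping dominates an empty region; this is not the PR's benchmark):

```python
import time
import torch

# Time the Python-level record_function enter/exit itself while a
# profiler is active; the body is empty, so the measured time is
# mostly event start/stop bookkeeping.
with torch.profiler.profile():
    start = time.perf_counter()
    for _ in range(1000):
        with torch.profiler.record_function("overhead_probe"):
            pass
    per_call_us = (time.perf_counter() - start) / 1000 * 1e6

print(f"~{per_call_us:.1f} us per enter/exit from Python")
```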

**Performance results**:
I ran the script below on a CPU-only 1.6 GHz machine. I report the median time (from 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. I think it is reasonable to attribute the difference to reduced overhead.

<details>

<summary>Benchmarking script</summary>

```python
import torch

def fn(x, y):
    return (x * y).relu()

a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)]

opt_fn = torch.compile(fn)

# Warm up: the first call compiles; subsequent calls hit the dynamo cache.
opt_fn(a, b)
opt_fn(a, b)

with torch.profiler.profile() as prof:
    opt_fn(a, b)
```

</details>

Median before PR: 198-228 us (median of 100, measured 5 times)
Median after PR: 27 us
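
To pull the per-event measurements out of a trace like the one above, one can filter the profile's events by name (a sketch assuming `prof` is the profile from the benchmarking script; not necessarily the analysis code used for the numbers above):

```python
import statistics

# Durations (in microseconds) of every cache-lookup event in the profile.
durations_us = [
    evt.cpu_time_total
    for evt in prof.events()
    if evt.name == "TorchDynamo Cache Lookup"
]
if durations_us:
    print(f"median lookup time: {statistics.median(durations_us):.1f} us")
```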

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov

pytorch-bot bot commented Sep 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108436

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1e296f3 with merge base db63bf3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

davidberard98 added a commit that referenced this pull request Sep 1, 2023

WIP

ghstack-source-id: 1880fc8
Pull Request resolved: #108436
davidberard98 changed the title from "[dynamo] TorchDynamo Cache Lookup - use C++ api" to "[dynamo] "TorchDynamo Cache Lookup" event: use C++ api" Sep 1, 2023
davidberard98 marked this pull request as ready for review September 1, 2023 23:33
davidberard98 requested a review from jansel September 1, 2023 23:43
Comment on lines 7 to 10:

```c
typedef struct _PytorchRecordFunctionState {
  void* guard;
} _PytorchRecordFunctionState;
```

anijain2305 (Contributor):
Can you just forward-declare the struct here and keep the `guard` member in cpp_shim.cpp?

```c
struct _PytorchRecordFunctionState;
```

davidberard98 (Contributor, Author):
@anijain2305 I had to change the return type to a `_PytorchRecordFunctionState*`. Could you do a sanity check?

anijain2305 (Contributor) left a comment:
LGTM. Just a minor comment on the shim wrapper, to see if the `void* guard` member can be avoided in the header file.

davidberard98 added a commit that referenced this pull request Sep 2, 2023
ghstack-source-id: 4656e9d
Pull Request resolved: #108436
davidberard98 (Contributor, Author) commented:

@pytorchbot merge

pytorch-bot bot added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label Sep 4, 2023
pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR needs a `release notes:` label.
If your changes are user facing and intended to be part of the release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example:
`@pytorchbot label "topic: not user facing"`

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


davidberard98 (Contributor, Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

facebook-github-bot deleted the gh/davidberard98/228/head branch September 7, 2023 14:25