Skip to content

Conversation

@davidberard98
Copy link
Contributor

@davidberard98 davidberard98 commented Dec 13, 2023

Stack from ghstack (oldest at bottom):

Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change:

  • always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
  • if config.log_compilation_metrics, then call back into the original log_compilation_event function

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

Motivation: for testing in dynamo, it would be nice to have an equivalent to torch._inductor.metrics. We already have log_compilation_event, but that only prints logs or dumps to a database (in fbcode); these aren't appropriate for unit tests.

This change: adds torch._dynamo.metrics, and then changes convert_frame so that it calls torch._dynamo.metrics.record_compilation_metrics instead. To prevent a super large list from forming, use a deque with a max length; then if the max length is exceeded, the oldest records will be dropped.

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115788

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit a3a2b1e with merge base 87ea6fb (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…ics"

Motivation: for testing in dynamo, it would be nice to have an equivalent to torch._inductor.metrics. We already have log_compilation_event, but that only prints logs or dumps to a database (in fbcode); these aren't appropriate for unit tests.

This change: adds torch._dynamo.metrics, and then changes convert_frame so that it calls torch._dynamo.metrics.record_compilation_metrics instead. To prevent a super large list from forming, use a deque with a max length; then if the max length is exceeded, the oldest records will be dropped.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
…ics"


Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change: 
* always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
* if config.log_compilation_metrics, then call back into the original log_compilation_event function

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
@davidberard98 davidberard98 changed the title [dynamo] Log all compilation metrics in torch._dynamo.metrics [dynamo] Store CompilationEvents in a buffer in torch._dynamo.utils Dec 14, 2023
@davidberard98 davidberard98 marked this pull request as ready for review December 14, 2023 03:13
@davidberard98
Copy link
Contributor Author

@yanboliang should we test compile time on this to make sure nothing regresses or do you think it should be safe?

Copy link
Contributor

@yanboliang yanboliang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to run benchmark against this change, though I think it's low risk.

…amo.utils"


Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change: 
* always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
* if config.log_compilation_metrics, then call back into the original log_compilation_event function

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Dec 14, 2023
Motivation: for testing in dynamo, it would be nice to have an equivalent to torch._inductor.metrics. We already have log_compilation_event, but that only prints logs or dumps to a database (in fbcode); these aren't appropriate for unit tests.

This change: adds torch._dynamo.metrics, and then changes convert_frame so that it calls torch._dynamo.metrics.record_compilation_metrics instead. To prevent a super large list from forming, use a deque with a max length; then if the max length is exceeded, the oldest records will be dropped.

ghstack-source-id: ea76a5b
Pull Request resolved: #115788
…amo.utils"


Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change: 
* always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
* if config.log_compilation_metrics, then call back into the original log_compilation_event function

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Dec 16, 2023
Motivation: for testing in dynamo, it would be nice to have an equivalent to torch._inductor.metrics. We already have log_compilation_event, but that only prints logs or dumps to a database (in fbcode); these aren't appropriate for unit tests.

This change: adds torch._dynamo.metrics, and then changes convert_frame so that it calls torch._dynamo.metrics.record_compilation_metrics instead. To prevent a super large list from forming, use a deque with a max length; then if the max length is exceeded, the oldest records will be dropped.

ghstack-source-id: d124c19
Pull Request resolved: #115788
@davidberard98
Copy link
Contributor Author

@davidberard98 davidberard98 added the topic: not user facing topic category label Dec 16, 2023
…amo.utils"


Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change: 
* always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
* if config.log_compilation_metrics, then call back into the original log_compilation_event function

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
@davidberard98
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 18, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Dec 20, 2023
Summary:
Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change:
* always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
* if config.log_compilation_metrics, then call back into the original log_compilation_event function

X-link: pytorch/pytorch#115788
Approved by: https://github.com/yanboliang

Reviewed By: jeanschmidt

Differential Revision: D52298053

Pulled By: davidberard98

fbshipit-source-id: ef291255d6148c0479f3000b4fb21a4ed72cadcb
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
…ytorch#115788)

Motivation: it would be nice to be able to test using the metrics in log_compilation_event; currently dumps logs (or logs to a database in fbcode) - these are hard to use in unit tests.

This change:
* always record the information in torch._dynamo.utils.record_compilation_metrics; here, log into a limited-size deque to prevent the list of metrics from getting too long
* if config.log_compilation_metrics, then call back into the original log_compilation_event function

Pull Request resolved: pytorch#115788
Approved by: https://github.com/yanboliang
@facebook-github-bot facebook-github-bot deleted the gh/davidberard98/253/head branch December 22, 2023 15:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants