-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[dynamo][unsoundness but very controlled] Skip guards on inbuilt nn module hooks #130420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[ghstack-poisoned]
This was referenced Jul 10, 2024
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130420
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 25a934c with merge base b0a597f ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This was referenced Jul 10, 2024
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
…nbuilt nn module hooks" Reduces the guard overhead from 2.1k units to 1k units. Compared to no-inlining, this reduces the slowdown from 5x to 2.5x. This introduces unsoundness, but only for hooks for inbuilt nn modules (user defined nn module hooks are fine). Each builtin nn module adds 4 empty ordered dict checks in the check_fn. This blows up for models with large numbers of builtin nn modules. With this PR, we skip those guards. There is no other easy way I can think of right now. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
This was referenced Jul 13, 2024
jansel
approved these changes
Jul 15, 2024
pytorchmergebot
pushed a commit
that referenced
this pull request
Jul 15, 2024
Pull Request resolved: #130681 Approved by: https://github.com/zou3519, https://github.com/jansel ghstack dependencies: #130654, #130420
This was referenced Jul 15, 2024
Closed
xuhancn
pushed a commit
to xuhancn/pytorch
that referenced
this pull request
Jul 25, 2024
…odule hooks (pytorch#130420) Reduces the guard overhead from 2.1k units to 1k units. Compared to no-inlining (0.4k units), this reduces the slowdown from 5x to 2.5x. This introduces unsoundness, but only for hooks for inbuilt nn modules (user defined nn module hooks are fine). Each builtin nn module adds 4 empty ordered dict checks in the check_fn. This blows up for models with large numbers of builtin nn modules. With this PR, we skip those guards. There is no other easy way I can think of right now to control the guard overhead. Pull Request resolved: pytorch#130420 Approved by: https://github.com/jansel ghstack dependencies: pytorch#130654
xuhancn
pushed a commit
to xuhancn/pytorch
that referenced
this pull request
Jul 25, 2024
Pull Request resolved: pytorch#130681 Approved by: https://github.com/zou3519, https://github.com/jansel ghstack dependencies: pytorch#130654, pytorch#130420
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Stack from ghstack (oldest at bottom):
Reduces the guard overhead from 2.1k units to 1k units. Compared to no-inlining (0.4k units), this reduces the slowdown from 5x to 2.5x.
This introduces unsoundness, but only for hooks for inbuilt nn modules (user defined nn module hooks are fine).
Each builtin nn module adds 4 empty ordered dict checks in the check_fn. This blows up for models with large numbers of builtin nn modules. With this PR, we skip those guards. There is no other easy way I can think of right now to control the guard overhead.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames