Skip to content

Conversation

anijain2305
Copy link
Contributor

@anijain2305 anijain2305 commented Jul 10, 2024

Stack from ghstack (oldest at bottom):

Reduces the guard overhead from 2.1k units to 1k units. Compared to no-inlining (0.4k units), this reduces the slowdown from 5x to 2.5x.

This introduces unsoundness, but only for hooks for inbuilt nn modules (user defined nn module hooks are fine).

Each builtin nn module adds 4 empty ordered dict checks in the check_fn. This blows up for models with large numbers of builtin nn modules. With this PR, we skip those guards. There is no other easy way I can think of right now to control the guard overhead.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130420

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 25a934c with merge base b0a597f (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Jul 10, 2024
ghstack-source-id: c10b5e6
Pull Request resolved: #130420
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Jul 11, 2024
ghstack-source-id: c06b7c6
Pull Request resolved: #130420
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@anijain2305 anijain2305 changed the title [dynamo] Skip guards on inbuilt nn module hooks [dynamo][unsoundness but very controlled] Skip guards on inbuilt nn module hooks Jul 12, 2024
…nbuilt nn module hooks"


Reduces the guard overhead from 2.1k units to 1k units. Compared to no-inlining, this reduces the slowdown from 5x to 2.5x.

This introduces unsoundness, but only for hooks for inbuilt nn modules (user defined nn module hooks are fine). 

Each builtin nn module adds 4 empty ordered dict checks in the check_fn. This blows up for models with large numbers of builtin nn modules. With this PR, we skip those guards. There is no other easy way I can think of right now.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Jul 15, 2024
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…odule hooks (pytorch#130420)

Reduces the guard overhead from 2.1k units to 1k units. Compared to no-inlining (0.4k units), this reduces the slowdown from 5x to 2.5x.

This introduces unsoundness, but only for hooks for inbuilt nn modules (user defined nn module hooks are fine).

Each builtin nn module adds 4 empty ordered dict checks in the check_fn. This blows up for models with large numbers of builtin nn modules. With this PR, we skip those guards. There is no other easy way I can think of right now to control the guard overhead.

Pull Request resolved: pytorch#130420
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#130654
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
@github-actions github-actions bot deleted the gh/anijain2305/421/head branch August 15, 2024 01:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants