Reduce overhead in CUDAGraph Trees #98529
Conversation
Is it still slower than the previous cudagraphs implementation was? Do you have an idea why?
@ngimel It's not really slower on any real model (same speed for HuggingFace, etc.), but for ResNet in torchbench, graph.replay() is so quick that just reconstructing the tensors and storages is a noticeable overhead. We also do not have a well-tuned batch size for that model. It's not really about graph breaks: even with no graph breaks, there is still substantial memory overhead in the existing implementation because it does not share memory between forward and backward. For inference, so few tensors are output compared to training that there should be practically no difference. A further training optimization would be to reuse the same TensorImpls for tensors saved for backward, since we know they are live if the backward hasn't run and they do not escape the backward.
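To illustrate the per-replay bookkeeping being discussed: before reusing a prior replay's memory, CUDAGraph Trees has to know whether the storages of that replay's outputs are still alive. A minimal conceptual sketch of such a liveness check, using plain Python weak references (a stand-in, not the actual PyTorch internals, which use C++ StorageWeakRef objects):

```python
# Conceptual sketch: tracking storage liveness via weak references.
# FakeStorage is a hypothetical stand-in for a CUDA storage.
import weakref


class FakeStorage:
    def __init__(self, nbytes):
        self.nbytes = nbytes


outputs = [FakeStorage(1024) for _ in range(8)]
weakrefs = [weakref.ref(s) for s in outputs]

# The caller drops some outputs between replays; their storages die.
del outputs[:4]

# Liveness check before the next replay: which storages survived?
alive = [r for r in weakrefs if r() is not None]
print(len(alive))  # 4
```

Doing a check like this tensor-by-tensor on every replay is exactly the kind of per-call overhead that dominates when graph.replay() itself is nearly free.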
@pytorchbot merge
Pull Request resolved: #98529. Approved by: https://github.com/jansel, https://github.com/ngimel
Stack from ghstack (oldest at bottom):
Significantly reduces the overhead of constructing Tensors and Storages and of checking Storage liveness. Removes the regression for the HF models I tested, and removes 75% of the overhead of the extremely overhead-bound resnet50 training we have in torchbench (0.91x base commit, 1.02x torchinductor default, 1.16x this PR, 1.25x previous cudagraphs impl).

This PR takes care of all of the lower-hanging fruit:
- Computes storage aliasing at record time instead of at runtime. We no longer need a runtime storage cache; instead we index directly into the existing alias if there is one, or construct a new Storage.
- Moves the heavyweight C++ calls into a batch: getting storage weakrefs and constructing tensors.
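The record-time aliasing idea above can be sketched as follows. This is an illustrative toy, not the PR's actual code; the function names and the use of data pointers as alias keys are assumptions for the example. Aliasing is computed once while recording; each replay then just indexes a precomputed table instead of hashing into a runtime storage cache:

```python
# Conceptual sketch of record-time alias computation (hypothetical names).

def compute_alias_indices(record_data_ptrs):
    """At record time: for each output, remember the index of an earlier
    output sharing its storage, or None if it owns a fresh storage."""
    first_seen = {}
    aliases = []
    for i, ptr in enumerate(record_data_ptrs):
        if ptr in first_seen:
            aliases.append(first_seen[ptr])  # aliases an earlier output
        else:
            first_seen[ptr] = i
            aliases.append(None)             # owns a new storage
    return aliases


def reconstruct_storages(alias_indices, make_storage):
    """At replay time: no cache lookups, just index into the table."""
    storages = []
    for idx in alias_indices:
        storages.append(storages[idx] if idx is not None else make_storage())
    return storages


# Outputs 0 and 2 share a storage; 1 and 3 each own their own.
aliases = compute_alias_indices([0x100, 0x200, 0x100, 0x300])
print(aliases)                     # [None, None, 0, None]
storages = reconstruct_storages(aliases, make_storage=object)
print(storages[0] is storages[2])  # True
```

The point of the split is that the dictionary work happens once per recording, while the hot per-replay path is a plain list walk, which is also what makes it easy to batch the remaining C++ calls.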
cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire