Add capturable single-tensor RAdam, Adamax #118697
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118697
Note: Links to docs will display an error until the docs builds have been completed.
❌ 9 New Failures, 1 Unrelated Failure
As of commit ec1c622 with merge base 1adedc3:
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@MarouaneMaatouk I think you actually did Adamax (which was also needed!), but can you add tests for the cudagraphs support here: https://github.com/pytorch/pytorch/pull/117912/files#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450 ? cc @janeyx99 for the status on capturable testing with optiminfos. I think we might need to wait for #118326 to go in to make sure we're testing capturable for Adamax properly.
No need to wait for me, but do make the Adamax-related changes in common_optimizer.py in the linked PR. I don’t want my change to block anything!
@mlazos Adamax tests pass, but RAdam fails and I wasn't able to make sense of the error. Maybe I am missing something.
Remove this code: https://github.com/pytorch/pytorch/blob/923a7c757205a327e8f8a6d4f9dbda036ae531d3/torch/_dynamo/eval_frame.py#L1571C4-L1572C10
Please seek CI approval before scheduling CIFlow labels
Thanks, done in the last commit. However, I am still having some issues due to this condition (https://github.com/pytorch/pytorch/pull/118697/files#diff-4e7620901810b83e6a28709cbb678170338937eae1e3949b1c1e295d803cca68R368); using "if" directly wasn't working.
Oh, if you look at the multitensor version at the bottom of the same file, I use
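For context on why a plain "if" tends to fail in a capturable path: the step count lives in a tensor (possibly on the GPU), so branching in Python on a value derived from it forces a host sync, which is not allowed during CUDA graph capture. The sketch below assumes the condition under discussion is RAdam's rho_t > 5 check (the linked diff line is not reproduced in this thread, so that is a guess) and shows one common branch-free alternative using torch.where. The helper name `rect_capturable` and its signature are hypothetical, and this is not the code in torch/optim/radam.py.

```python
import torch


def rect_capturable(step: torch.Tensor, beta2: float) -> torch.Tensor:
    """Hypothetical helper: RAdam rectification term without a Python branch.

    `step` is a 0-dim float tensor holding the current step count (>= 1).
    """
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step                      # tensor, stays on step's device
    bias_correction2 = 1.0 - beta2_t
    rho_t = rho_inf - 2.0 * step * beta2_t / bias_correction2

    # Evaluate the rectified value unconditionally. The clamp is an addition
    # of this sketch: it keeps sqrt away from negative inputs when rho_t <= 5.
    numer = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf).clamp(min=0.0)
    denom = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    rect = (numer / denom).sqrt()

    # Select per element on the device instead of `if rho_t > 5:` on the host.
    return torch.where(rho_t > 5.0, rect, torch.ones_like(rect))


early = rect_capturable(torch.tensor(1.0), beta2=0.999)
late = rect_capturable(torch.tensor(10000.0), beta2=0.999)
```

Here a return value of 1 stands in for "no rectification"; the real optimizer takes a different update in that case, so treat the selected value as a placeholder for whichever branch result is needed.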
Will look more closely sometime after my meetings today/tomorrow, but noticed two things from a brief glance:
  optim_error_inputs_func=optim_error_inputs_func_adamax,
  supported_impls=("foreach", "differentiable"),
- only_supports_capturable_on_foreach=True,  # Remove this line when #117836 is done!
+ only_supports_capturable_on_foreach=False,
Please delete the whole line!
And also remove the skipped tests for Adamax.
@MarouaneMaatouk It looks like there is failing CI currently. I am also noticing that the CUDA graph tests (the ones that test the point of capturable) are missing. Add in the single tensor variants here: https://github.com/pytorch/pytorch/blob/main/test/test_cuda.py#L2695-L2766
Let us know if you need any help!
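As a rough illustration of what such a CUDA graph test exercises, the sketch below captures one single-tensor Adamax step in a graph and replays it. It follows the warm-up-then-capture pattern from the PyTorch CUDA graphs notes and assumes a PyTorch build in which Adamax accepts capturable=True with foreach=False (that is, after this work landed). The function name `capture_adamax_step` is made up for this example, and it is not the test in test/test_cuda.py.

```python
import torch


def capture_adamax_step():
    """Return (params_before, params_after) around one graph replay.

    Returns None when CUDA is unavailable, since capture needs a GPU.
    """
    if not torch.cuda.is_available():
        return None

    param = torch.ones(4, device="cuda", requires_grad=True)
    # Assumption: this PyTorch version supports capturable on the
    # single-tensor (foreach=False) Adamax path.
    opt = torch.optim.Adamax([param], lr=0.1, foreach=False, capturable=True)

    def forward_backward_step():
        loss = (param ** 2).sum()
        loss.backward()
        opt.step()

    # Warm up on a side stream so optimizer state is initialized before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            opt.zero_grad(set_to_none=True)
            forward_backward_step()
    torch.cuda.current_stream().wait_stream(side)

    # Capture one step; grads are allocated inside the graph's memory pool.
    graph = torch.cuda.CUDAGraph()
    opt.zero_grad(set_to_none=True)
    with torch.cuda.graph(graph):
        forward_backward_step()

    before = param.detach().clone()
    graph.replay()
    torch.cuda.synchronize()
    return before, param.detach().clone()


result = capture_adamax_step()
```

A fuller test would also run the same number of steps eagerly with capturable=False and compare the resulting parameters against the replayed ones.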
Finishes the work started in #118697. Thanks MarouaneMaatouk for the attempt, but due to inactivity I have opened this PR for Adamax. Note that the new capturable implementation is much simpler and I've modified the foreach capturable impl--it now calls fewer kernels and is more easily comparable to forloop.
Next steps:
* This PR discovered two bugs: #121178 and #121238.
* Move the now hefty graph optim tests in test_cuda to use OptimInfo.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Implementation thanks to MarouaneMaatouk in #118697. Added tests and the cudagraph health check. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Finishes the work started in #118697. Thanks @MarouaneMaatouk for the attempt, but due to inactivity I have opened this PR for Adamax. Note that the new capturable implementation is much simpler and I've modified the foreach capturable impl--it now calls fewer kernels and is more easily comparable to forloop.
Next steps:
* This PR discovered two bugs: #121178 and #121238.
* Move the now hefty graph optim tests in test_cuda to use OptimInfo.
Pull Request resolved: #121183
Approved by: https://github.com/albanD
Implementation thanks to @MarouaneMaatouk in #118697, though I've since cleaned it up a lot to save perf on the rect < 5 eager case. It also just looks better now :) Added tests and the cudagraph health check.
Pull Request resolved: #121260
Approved by: https://github.com/mlazos
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
This has been done, closing.
Fixes #118230, #117836
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang