Optim foreach cleanup for NAdam #70229
Conversation
torch/optim/nadam.py (Outdated)

```python
mus = [beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay))) for step in state_steps]
mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
            for step in state_steps]
mu_products = [mu * mu_product for mu, mu_product in zip(mus, mu_products)]
```
The multi-tensor NAdam class updates mu_product before the call to nadam and then operates on mu_products directly, whereas the single-tensor class updates mu_product after the call to nadam and computes the new mu_products within the nadam function. In this combined class I preserved the single-tensor behavior, so I added line 245 here to ensure the same behavior.
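A minimal sketch of the two orderings being described; the helper and constants below are illustrative, not the PR's exact code:

```python
def mu_at(step, beta1=0.9, momentum_decay=0.004):
    # NAdam's per-step momentum coefficient, as in the snippet above.
    return beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))

# Single-tensor ordering: the functional computes mu * mu_product itself,
# and step() repeats the same update on state['mu_product'] after the call.
def single_tensor_step(state, step):
    mu = mu_at(step)
    used_in_update = state['mu_product'] * mu       # recomputed inside the functional
    state['mu_product'] = state['mu_product'] * mu  # duplicated outside it
    return used_in_update

# Multi-tensor ordering: step() updates state['mu_product'] first, and the
# functional consumes the already-updated value, so nothing is computed twice.
def multi_tensor_step(state, step):
    state['mu_product'] *= mu_at(step)
    return state['mu_product']
```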
This could be done with a foreach op, no?
Also, the existing code modifies mu_products in place; don't we want to preserve that?
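For reference, a hedged sketch of what a foreach version might look like if the running products were held as tensors rather than floats (the tensor representation here is an assumption, not the PR's code):

```python
import torch

# Hypothetical: mu_products stored as 0-dim tensors instead of Python floats.
mu_products = [torch.tensor(1.0), torch.tensor(1.0)]
mus = [0.81, 0.85]  # made-up per-parameter mu values

# One fused, in-place multiply over the whole list, replacing the
# per-element list comprehension in the snippet above.
torch._foreach_mul_(mu_products, mus)
```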
Hm, I didn't do this in place because each mu_product is a float rather than a tensor, so even if it's done in place I don't think it updates the underlying state['mu_product']. But there is an update step within the NAdam.step function on line 146 that updates state['mu_product'], which is preserved from the single-tensor version. Am I thinking about this correctly?
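The point about floats can be checked with a standalone example (illustrative values):

```python
state = {'mu_product': 1.0}

mu_product = state['mu_product']   # copies the float's value; floats are immutable
mu_product = mu_product * 0.9      # rebinds the local name only
assert state['mu_product'] == 1.0  # the stored state never changed

# so step() must write the new value back explicitly:
state['mu_product'] = mu_product
```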
Aren't you computing the same thing twice here then?
I am a bit confused now about where this value is updated in each case.
In general, I think we want to stay as close to the original code as possible, even if we have to fold some of the state update into the functional form.
Yes, I agree that the same thing is being computed twice, but that is what the original torch/optim/nadam.py was doing. The single-tensor version updates mu_product after calling F.nadam (and the same computation is done inside the function on lines 224 and 226), whereas the multi-tensor version does it before calling F.nadam, so the logic isn't repeated in the function. So I think we need to change one of the two functional forms: for example, if I preserve the initial multi-tensor version and drop this change to _multi_tensor_nadam, I think I would have to remove line 226 of _single_tensor_nadam (mu_product = mu_product * mu). Does that make sense?
The two options I see are:
- We keep the exact same behavior as before, so we fold the second update for the single-tensor implementation into its functional version, and fold the multi-tensor computation into its functional version.
- We remove the duplicate code and modify one of the two to match the other. In that case, I think we should keep the version that does not do the computation twice, as it is better. Also, the functional version should perform the full step, so I would argue it is a bug that part of it is done in the optimizer outside of the functional call. A sketch of this option follows below.
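A sketch of the second option under the simplest assumption, where the functional owns the full mu_product update and returns it so the caller writes it back exactly once (the signature below is illustrative, not the PR's final API):

```python
def nadam_mu_update(mu_product, step, beta1=0.9, momentum_decay=0.004):
    # The functional performs the whole update; step() no longer repeats it.
    mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
    return mu_product * mu

# In step(), the returned value is stored back once:
# state['mu_product'] = nadam_mu_update(state['mu_product'], step)
```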
Hm, which option do you think I should go with? It seems that ASGD and SGD have the same bug, where part of the state is updated outside the functional form in the single-tensor versions as well. I think this was done (for ASGD and NAdam) because the arguments were passed as floats rather than singleton tensors, so the state couldn't be updated within the function.
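The singleton-tensor remark suggests one way out: if state['mu_product'] were stored as a 0-dim tensor, the functional could update it in place and the write would be visible through the state dict (a sketch of that idea, not necessarily what the PR does):

```python
import torch

state = {'mu_product': torch.tensor(1.0)}  # 0-dim tensor instead of a float

def functional_update(mu_product, mu):
    # In-place multiply mutates the tensor the state dict already holds.
    mu_product.mul_(mu)

functional_update(state['mu_product'], 0.9)
assert torch.isclose(state['mu_product'], torch.tensor(0.9))
```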
Much cleaner!
@mikaylagawarecki has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Differential Revision: [D33767873](https://our.internmc.facebook.com/intern/diff/D33767873)
Summary: Pull Request resolved: pytorch/pytorch#70229
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D33767873
Pulled By: mikaylagawarecki
fbshipit-source-id: 833ead14c1d1659351ebfbeb41045a3c7eb96dad
(cherry picked from commit 9415df6b5c9620c9d53036c28fe3f297c6d4906c)
Stack from ghstack:
Add foreach flag to NAdam optimizer + cleanup
Differential Revision: D33767873