
Extend pytorch multi tensor rmsprop optimizer to support lr_in_momentum and decoupled_decay. #46118

Closed
wants to merge 1 commit

Conversation

@lly-zero-one (Contributor) commented Oct 9, 2020

Summary: Switch to the multi-tensor version of the RMSprop optimizer in the Classy Vision flow and make the numerics match the old implementation

Test Plan: Flow canary: f223946625

Differential Revision: D24102016

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D24102016

@dr-ci bot commented Oct 9, 2020

💊 CI failures summary and remediations

As of commit 82c0de1 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 5 fixed upstream failures:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch:

If your commit is newer than viable/strict, you can try rebasing onto an older, stable commit:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)

If your commit is older than viable/strict:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.


This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 28 times.


"""

def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, lr_in_momentum=False):
Contributor

This is not standard in PyTorch, and doesn't seem like a standard type of argument for optimizers. What are the alternatives?

Is the goal only to improve performance? In that case it would be good to have a benchmark.
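
For context, a minimal sketch of what the proposed lr_in_momentum flag would change, assuming the TF-style semantics discussed later in this thread (illustrative only, not the PR's actual code):

import torch

def momentum_update(param, grad, square_avg, buf, lr, alpha, eps, momentum, lr_in_momentum):
    # Running average of squared gradients, as in torch.optim.RMSprop.
    square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
    avg = square_avg.sqrt().add_(eps)
    if lr_in_momentum:
        # TF-style: the learning rate is folded into the momentum buffer update.
        buf.mul_(momentum).addcdiv_(grad, avg, value=lr)
        param.add_(buf, alpha=-1)
    else:
        # Stock torch.optim.RMSprop: lr is applied when the parameter is updated.
        buf.mul_(momentum).addcdiv_(grad, avg)
        param.add_(buf, alpha=-lr)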

Contributor Author

Thanks, that makes sense. I think I need to change the title for OSS. Basically, we switched to the multi-tensor version for an internal flow and tried to make the accuracy match, but that flow-related change was not pulled out.

Contributor Author

@vincentqb could you help to review the extension and also the numeric change?

Contributor

For the API change, I do not see a reason to add this to the API of RMSProp. No other optimizers have that.

If there is a runtime speed improvement when running the step, I don't believe it would justify adding this API, since this is still a change to the RMSProp algorithm we have.

As I mentioned in my comment, if there is a discrepancy between RMSProp with and without multi-tensor, this needs to be investigated.

Contributor

The change relates to #23796 and the difference between PyTorch and TensorFlow, not the difference between the original implementation and the multi-tensor one.
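
For reference, the PyTorch/TensorFlow difference tracked in #23796 is the placement of eps relative to the square root; a rough illustration (not code from this PR):

import torch

grad, square_avg, eps = torch.randn(10), torch.rand(10), 1e-8

denom_pytorch = square_avg.sqrt().add(eps)  # torch.optim.RMSprop: eps added after the sqrt
denom_tf = square_avg.add(eps).sqrt()       # TF-style RMSProp: eps added inside the sqrt

update_pytorch = grad / denom_pytorch
update_tf = grad / denom_tf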

@@ -384,7 +384,7 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
  // Check for errors and throw appropriate exception.
  checkAndThrowException();
  std::this_thread::sleep_for(
-     std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
+     std::chrono::microseconds(kSynchronizeBusyWaitMicros));
Contributor

We hope to get rid of this busy wait in #45236, but I guess it's fine to land this change if it is something urgent for Classy Vision.

Contributor Author

Sorry, I was not aware this file was pulled into this PR. I will remove it.

Contributor

No worries, although I was wondering if we were planning on making these changes as part of a separate PR?

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D24102016

@dzhulgakov (Collaborator)

Can you update the PR description and maybe include standalone benchmark results? (torch.utils.benchmark might be helpful)
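
A standalone benchmark could be sketched roughly as below with torch.utils.benchmark; torch.optim._multi_tensor.RMSprop is an assumption here, referring to the private namespace the multi-tensor optimizers lived in at the time:

import torch
import torch.utils.benchmark as benchmark

params = [torch.randn(1024, 1024, requires_grad=True) for _ in range(50)]
for p in params:
    p.grad = torch.randn_like(p)

optimizers = {
    "single-tensor": torch.optim.RMSprop(params, lr=1e-2, momentum=0.9),
    "multi-tensor": torch.optim._multi_tensor.RMSprop(params, lr=1e-2, momentum=0.9),
}

for label, opt in optimizers.items():
    # Time one optimizer step for each variant over the same parameter set.
    timer = benchmark.Timer(stmt="opt.step()", globals={"opt": opt})
    print(label, timer.blocked_autorange())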

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D24102016

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D24102016

1 similar comment
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D24102016

…torch#46118)

Summary:
Pull Request resolved: pytorch#46118

Switch to the multi-tensor version of the RMSprop optimizer in the Classy Vision flow and make the numerics match the old implementation

Test Plan: Flow canary: f223946625

Differential Revision: D24102016

fbshipit-source-id: 362d525bc3e1728ee736e9806a49c5931e8b1cd5
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D24102016

@lly-zero-one lly-zero-one changed the title Use pytorch multi tensor rmsprop optimizer for better performance Extend pytorch multi tensor rmsprop optimizer to support lr_in_momentum and decoupled_decay. Oct 22, 2020
@vincentqb (Contributor)

> Summary: Switch to the multi-tensor version of the RMSprop optimizer in the Classy Vision flow and make the numerics match the old implementation

@lly-zero-one -- I would like more clarity about the switch to multi-tensor and "make the numerics match the old implementation". Are you saying the implementation with multi-tensor doesn't match the implementation without? If so, then we need to add a test that can confirm this, though that is hard to do in open source without the internal models. We can sync up offline on this.

@ngimel (Collaborator) commented Oct 29, 2020

@vincentqb it looks like there's sufficient demand for a TF-style RMSProp optimizer and we are doing PyTorch users a disservice by not offering it. I understand that for BC reasons we may be reluctant to change the default RMSProp behavior, but then it makes sense to have an RMSProp_tf optimizer (the name can be anything) that does what's requested in #23796 and what @rwightman's linked optimizer does.

@lly-zero-one (Contributor Author)

I could add an option to support the two versions.

@lly-zero-one (Contributor Author)

> @vincentqb it looks like there's sufficient demand for a TF-style RMSProp optimizer and we are doing PyTorch users a disservice by not offering it. I understand that for BC reasons we may be reluctant to change the default RMSProp behavior, but then it makes sense to have an RMSProp_tf optimizer (the name can be anything) that does what's requested in #23796 and what @rwightman's linked optimizer does.

Some comments from the internal team:

It seems that PyTorch RMSProp requires a much smaller learning rate. Currently the typical lr for tf_rms is ~0.1-0.2, while using such a learning rate in pytorch_rms would lead to divergence. I don't have extensive experiments on pytorch_rms.

We have reproduced many SOTA models with tf_rms, but for pytorch_rms some effort is required to tune the hyper-parameters.

@facebook-github-bot (Contributor)

Hi @lly-zero-one!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@ngimel (Collaborator) commented Nov 26, 2020

@vincentqb what's required to move this PR forward and provide an RMSProp version consistent with TF behavior? It seems to be requested both internally and externally.

@ngimel ngimel added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Nov 26, 2020
@vincentqb vincentqb added the module: optimizer (Related to torch.optim) label and removed the oncall: distributed (Add this issue/PR to distributed oncall triage queue) label Dec 18, 2020
@vincentqb (Contributor)

Thanks for the pull request @lly-zero-one. As far as I can see, there are mainly three changes requested in this PR.

@lly-zero-one, I will close this pull request for now, but please feel free to open an issue for those points that are still needed.

@vincentqb vincentqb closed this Dec 18, 2020
@@ -78,7 +102,7 @@ def step(self, closure=None):
  raise RuntimeError('RMSprop does not support sparse gradients')

  grads.append(p.grad)
- params_with_grad.append(p)
+ params_with_grad.append(p.data)
Contributor

We recently cleaned up the use of .data inside optimizers. Are you aware of its interaction with multi-tensor?
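
The cleanup referred to here replaced .data access with in-place updates performed under torch.no_grad(); a rough sketch of that pattern (not this PR's code):

import torch

@torch.no_grad()
def apply_updates(params_with_grad, updates, lr):
    # Mutate the Parameters directly; no .data indirection is needed
    # because gradient tracking is suspended by the decorator.
    for p, u in zip(params_with_grad, updates):
        p.add_(u, alpha=-lr)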

@rwightman commented Dec 18, 2020

@vincentqb decoupled decay works perfectly well with RMSProp and some other optimizers besides the AdamW/SGDW of that paper. I was just working on some JAX code fixing up the Flax RMSProp, and it's common to see optimizers there (in various JAX libs) where any weight_decay applied within the optimizer is (usually) decoupled decay, and the L2 penalty (equivalent to weight_decay here) is applied outside of the optimizer.

Since it doesn't look like these changes are going anywhere fast, I think I'll tackle the multi-tensor variant in timm ... I've been using it a lot, so it's probably worthwhile. It should be noted that this changeset was still missing one important difference for reproducing Google papers that use RMSProp + TF: the rms state init.
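
For reference, a sketch of the two behaviors mentioned above, loosely following the RMSpropTF optimizer in timm (decoupled weight decay and the TF-style rms state init); the helper names here are illustrative, not this PR's code:

import torch

def init_rms_state(p):
    # TF initializes the running square average to ones rather than zeros,
    # which noticeably changes the first few update steps.
    return {"square_avg": torch.ones_like(p)}

@torch.no_grad()
def decoupled_decay(p, lr, weight_decay):
    # Decoupled decay: shrink the weights directly instead of adding
    # weight_decay * p to the gradient (the L2-penalty formulation).
    p.mul_(1 - lr * weight_decay)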

@arikanev
Are there any plans to revive this, or do we just have to use, for example, @rwightman's methodology?

Labels
cla signed, fb-exported, module: optimizer (Related to torch.optim), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
9 participants