Implement AdamW optimizer #21250
Conversation
@soumith Fair enough if it's an inherently unanswerable question, but is there any rough ETA on when this might get triaged?
cc: @vincentqb, can you review this and get it to completion if it makes sense? Use the guidelines that I shared with you separately. @mjacar I want to apologize for never getting a review to completion, but we now have Vincent, who has a lot of bandwidth and is a math and optimization expert. He will help with the review.
Thanks @mjacar for the PR. It looks good to me, though I wonder: is there a reason why you would like to change the default value of beta_2 from .999 to .99? The paper Decoupled Weight Decay Regularization also uses .999, as does PyTorch's Adam.
@vincentqb Literally just because that is the default value in the fastai implementation (see here). Should I change it? I really don't have any strong feelings about it one way or the other. |
@mjacar Ok, thanks for clarifying. We'll go with PyTorch's defaults. I can do that change. |
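For reference, here is a minimal usage sketch with the defaults discussed above (the exact argument defaults in the released optimizer are an assumption here, not confirmed by this thread):

```python
import torch

model = torch.nn.Linear(10, 2)

# AdamW with PyTorch's Adam-style defaults; betas=(0.9, 0.999) is the value
# settled on in this thread. weight_decay is the decoupled decay rate.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, betas=(0.9, 0.999),
    eps=1e-8, weight_decay=1e-2,
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```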
@vincentqb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
This looks good to me.
@pytorchbot retest this please
@vincentqb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@vincentqb I leave the PR for you to land. Thanks
@vincentqb merged this pull request in a4b2f3e.
Summary:

# What is this?
This is an implementation of the AdamW optimizer as implemented in [the fastai library](https://github.com/fastai/fastai/blob/803894051bef32304ceea0c8ea5e04db64ff26b8/fastai/callback.py) and as initially introduced in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). It decouples the weight decay regularization step from the optimization step during training.

There have already been several abortive attempts to push this into pytorch in some form or fashion: pytorch#17468, pytorch#10866, pytorch#3740, pytorch#4429. Hopefully this one goes through.

# Why is this important?
Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with the concept of weight decay. However, the equivalence of L2 regularization and weight decay breaks down for more complex adaptive optimization schemes. It was shown in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) that this is the reason why models trained with SGD achieve better generalization than those trained with Adam. Weight decay is a very effective regularizer; L2 regularization, in and of itself, is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while also taking advantage of the quick convergence properties that adaptive optimization schemes have.

# How was this tested?
Test cases were added to `test_optim.py`, and I also ran a [little experiment](https://gist.github.com/mjacar/0c9809b96513daff84fe3d9938f08638) to validate that this implementation is equivalent to the fastai implementation.

Pull Request resolved: pytorch#21250
Differential Revision: D16060339
Pulled By: vincentqb
fbshipit-source-id: ded7cc9cfd3fde81f655b9ffb3e3d6b3543a4709
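To make the distinction concrete, here is a minimal sketch (not the code from this PR) contrasting the classic Adam-with-L2 step, where the decay term enters the gradient and gets rescaled by the adaptive denominator, with the decoupled AdamW step, where the weights are shrunk directly:

```python
import torch

def adam_l2_step(p, grad, exp_avg, exp_avg_sq, step,
                 lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """Classic Adam with L2 regularization: the decay is folded into the
    gradient, so it is rescaled by the adaptive denominator."""
    grad = grad + weight_decay * p
    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    bias1 = 1 - betas[0] ** step
    bias2 = 1 - betas[1] ** step
    denom = (exp_avg_sq / bias2).sqrt().add_(eps)
    return p - lr * (exp_avg / bias1) / denom

def adamw_step(p, grad, exp_avg, exp_avg_sq, step,
               lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """Decoupled weight decay: shrink the weights directly, then apply the
    plain Adam update to the unregularized gradient."""
    p = p * (1 - lr * weight_decay)
    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    bias1 = 1 - betas[0] ** step
    bias2 = 1 - betas[1] ** step
    denom = (exp_avg_sq / bias2).sqrt().add_(eps)
    return p - lr * (exp_avg / bias1) / denom
```

The only difference between the two functions is where the decay enters; everything downstream of the moment estimates is identical.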
The original AdamW code does factor the learning rate in and ignores bias correction, but they do some other interesting things. They end up with a regularization term that is then multiplied by the parameter vector θ_{t−1}.
In algorithm 2, the (initial) learning rate α is multiplied by the scheduler η_t, so the effective multiplier on the gradient is αη_t. Similarly, the effective multiplier on the weights is λη_t, where λ is the (initial) weight decay rate. In our case, since we only have a learning rate scheduler, group['lr'] = αη_t (let's double check this), so the effective multiplier on the weights is λη_t = λ (group['lr']/α). Thus, since we currently compute the effective multiplier on the weights (λη_t) as group['weight_decay'] * group['lr'], we currently have group['weight_decay'] = λ/α. I agree that this can be confusing; by analogy with group['lr'], group['weight_decay'] should mean λη_t, not λ/α. This also means that when we introduce a weight decay scheduler in #22343, we can naturally schedule group['weight_decay'] independently of group['lr'].
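As a rough numerical illustration of that bookkeeping (a minimal sketch assuming only a learning-rate scheduler is active, so group['lr'] = αη_t; all values are made up):

```python
# Hypothetical values, for illustration only.
alpha = 1e-3   # initial learning rate α
lam = 1e-2     # desired (initial) weight decay rate λ
eta_t = 0.5    # schedule multiplier η_t at step t

group = {
    "lr": alpha * eta_t,   # what a learning-rate scheduler actually stores
    "weight_decay": lam,   # as passed to the optimizer
}

# Effective multiplier applied to the weights each step in the decoupled update:
effective_decay = group["weight_decay"] * group["lr"]   # = λ * α * η_t

# To recover the paper's λ * η_t instead, weight_decay would have to be set to
# λ/α, which is the confusing reparameterization discussed above.
weight_decay_param = lam / alpha
effective_decay_paper = weight_decay_param * group["lr"]  # = λ * η_t
```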
Good question. I'd say no, because we are using group['lr'] to recover the implicit scheduler η_t here, which is not bias-corrected in algorithm 2. Thoughts?
Would you like to open a PR to add to the documentation? :)
Not really, because I don't feel like I understand the issue that well. If anything, I'd just link this PR in the documentation.
@vincentqb I want to disable bias correction in the AdamW optimizer. It is known that BERT uses AdamW with bias correction disabled, and HuggingFace also adopts this approach in their optimizer. I think there is no reason not to add a bias correction flag to the AdamW optimizer. Can I add a bias_correction parameter to the AdamW optimizer?
Wouldn't this be equivalent to having …? One thing that I would be ok with would be to make a special case of …
@vincentqb I don't think it is the special case you mentioned. Adam estimates the first and second moments, and it cannot estimate them well in the first few iterations. Therefore, it uses a bias correction equation based on the step t. If the step t is large, the effect of the bias correction term vanishes. However, in the BERT TF repo, the author removes the bias correction term, which is empirically stable and improves performance. It means that the bias correction doesn't work well there. Therefore, I think it would be great if a user could choose whether to use bias correction or not. Although this phenomenon has only been demonstrated in the BERT case, it could be valuable considering the impact of BERT. Estimation part: moment = moment * beta + cur_moment * (1 - beta)
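A minimal sketch of what an optional flag could look like; bias_correction is a hypothetical parameter proposed in this thread, not part of the merged optimizer:

```python
import torch

def estimate_moments(exp_avg, exp_avg_sq, grad, step,
                     betas=(0.9, 0.999), bias_correction=True):
    """Exponential moving averages of the gradient and its square, with the
    bias correction made optional (as in BERT-style AdamW)."""
    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    if bias_correction:
        # Early in training the EMAs are biased toward zero; dividing by
        # (1 - beta**t) corrects for that. BERT's original TF code skips it.
        m_hat = exp_avg / (1 - betas[0] ** step)
        v_hat = exp_avg_sq / (1 - betas[1] ** step)
    else:
        m_hat, v_hat = exp_avg, exp_avg_sq
    return m_hat, v_hat
```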