
Implement AdamW optimizer #21250

Closed
wants to merge 6 commits into from

Conversation

@mjacar commented Jun 1, 2019

What is this?

This is an implementation of the AdamW optimizer as implemented in the fastai library and as initially introduced in the paper Decoupled Weight Decay Regularization. It decouples the weight decay regularization step from the optimization step during training.

There have already been several abortive attempts to push this into pytorch in some form or fashion: #17468, #10866, #3740, #4429. Hopefully this one goes through.

Why is this important?

Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with the concept of weight decay. However, it can be shown that the equivalence of L2 regularization and weight decay breaks down for more complex adaptive optimization schemes. It was shown in the paper Decoupled Weight Decay Regularization that this is the reason why models trained with SGD achieve better generalization than those trained with Adam. Weight decay is a very effective regularizer. L2 regularization, in and of itself, is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while also taking advantage of the quick convergence properties that adaptive optimization schemes have.
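
To make the distinction concrete, here is a minimal sketch (not code from this PR's diff) contrasting the two update styles; `adam_direction` is a hypothetical stand-in for Adam's bias-corrected, adaptively scaled step:

```python
import torch

def adam_direction(grad):
    # Hypothetical stand-in for Adam's bias-corrected, adaptively scaled step.
    return grad

def adam_l2_step(param, grad, lr, weight_decay):
    # L2 regularization: the decay term is folded into the gradient, so it is
    # rescaled by Adam's per-parameter adaptive denominator like any other gradient.
    grad = grad + weight_decay * param
    param.add_(-lr * adam_direction(grad))

def adamw_step(param, grad, lr, weight_decay):
    # Decoupled weight decay: shrink the weights directly, then take the
    # adaptive gradient step; the decay never passes through Adam's scaling.
    param.mul_(1 - lr * weight_decay)
    param.add_(-lr * adam_direction(grad))

p = torch.ones(3)
adamw_step(p, torch.zeros(3), lr=0.1, weight_decay=0.1)
print(p)  # each weight scaled by (1 - lr * weight_decay) = 0.99
```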

How was this tested?

Test cases were added to `test_optim.py`, and I also ran a little experiment to validate that this implementation is equivalent to the fastai implementation.

@pytorchbot added the module: optimizer label Jun 1, 2019
@pytorchbot added the module: docs label Jun 1, 2019
@mjacar commented Jun 18, 2019

@soumith Fair enough if it's an inherently unanswerable question, but is there any rough ETA on when this might get triaged?

@soumith commented Jun 20, 2019

cc @vincentqb: can you review this and get it to completion if it makes sense? Use the guidelines that I shared with you separately.

@mjacar I want to apologize for never getting a review to completion, but we now have Vincent, who has a lot of bandwidth and is a math and optimization expert. He will help with the review.

@vincentqb

Thanks @mjacar for the PR. It looks good to me, though I wonder: is there a reason why you would like to change the default value of beta_2 from .999 to .99? The paper Decoupled Weight Decay Regularization also uses .999, as does PyTorch's Adam.

@mjacar commented Jun 28, 2019

@vincentqb Literally just because that is the default value in the fastai implementation (see here).

Should I change it? I really don't have any strong feelings about it one way or the other.

@vincentqb self-assigned this Jun 28, 2019
@vincentqb self-requested a review Jun 28, 2019
@vincentqb commented Jun 28, 2019

@mjacar Ok, thanks for clarifying. We'll go with PyTorch's defaults. I can do that change.

@facebook-github-bot left a comment

@vincentqb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@vincentqb

@pytorchbot retest this please

@vincentqb left a comment

This looks good to me.

@vincentqb

@pytorchbot retest this please

@facebook-github-bot left a comment

@vincentqb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zhangguanheng66

@vincentqb I'll leave the PR for you to land. Thanks.

@vincentqb commented Jul 2, 2019

Thanks for the AdamW implementation, and thanks also to #22163, #3740, and #4429. This implements one of the algorithms requested in #3790 and discussed here.

@facebook-github-bot

@vincentqb merged this pull request in a4b2f3e.

@mjacar deleted the adamw branch Jul 2, 2019
xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
Summary:
# What is this?
This is an implementation of the AdamW optimizer as implemented in [the fastai library](https://github.com/fastai/fastai/blob/803894051bef32304ceea0c8ea5e04db64ff26b8/fastai/callback.py) and as initially introduced in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). It decouples the weight decay regularization step from the optimization step during training.

There have already been several abortive attempts to push this into pytorch in some form or fashion: pytorch#17468, pytorch#10866, pytorch#3740, pytorch#4429. Hopefully this one goes through.
# Why is this important?
Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with the concept of weight decay. However, it can be shown that the equivalence of L2 regularization and weight decay breaks down for more complex adaptive optimization schemes. It was shown in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) that this is the reason why models trained with SGD achieve better generalization than those trained with Adam. Weight decay is a very effective regularizer. L2 regularization, in and of itself, is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while also taking advantage of the quick convergence properties that adaptive optimization schemes have.
# How was this tested?
There were test cases added to `test_optim.py` and I also ran a [little experiment](https://gist.github.com/mjacar/0c9809b96513daff84fe3d9938f08638) to validate that this implementation is equivalent to the fastai implementation.
Pull Request resolved: pytorch#21250

Differential Revision: D16060339

Pulled By: vincentqb

fbshipit-source-id: ded7cc9cfd3fde81f655b9ffb3e3d6b3543a4709
@PetrochukM commented Aug 12, 2019

Hey!

Are we sure this AdamW implementation is correct?

In the implementation above, we have:

`p.data.mul_(1 - group['lr'] * group['weight_decay'])`

In the paper, α (the learning rate) is never multiplied by the weight_decay:
[screenshot of the algorithm from the paper omitted]

Furthermore, if it is multiplied by the learning rate, should that be bias-corrected?

@PetrochukM

The original AdamW code does factor the learning rate in and ignores bias correction, but it does some other interesting things:
https://github.com/loshchil/AdamW-and-SGDW/blob/34fc69c20dfef618ad9770b92921c928440e32f7/train.lua#L76

It ends up with a regularization term of:

`weightDecaycurrent = (self.optimState.learningRate / self.opt.LR) * (self.optimState.weightDecay / ( math.pow(trainSize * self.opt.Te, 0.5) ))`

This is then multiplied by the parameter vector θ_{t−1}.
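
A rough Python transcription of that Lua expression, assuming `trainSize` and `Te` correspond to the paper's normalization quantities (the names below just mirror the Lua fields):

```python
import math

def normalized_weight_decay(learning_rate, initial_lr, weight_decay, train_size, t_e):
    # (learning_rate / initial_lr) recovers the schedule multiplier eta_t;
    # dividing by sqrt(train_size * t_e) is the "normalized" weight decay idea.
    return (learning_rate / initial_lr) * (weight_decay / math.sqrt(train_size * t_e))

# The parameter vector theta_{t-1} is then multiplied by
# (1 - normalized_weight_decay(...)) before the Adam step.
```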

@vincentqb commented Aug 12, 2019

> In the implementation above, we have:
> `p.data.mul_(1 - group['lr'] * group['weight_decay'])`
>
> In the paper, α (the learning rate) is never multiplied by the weight_decay:
> [screenshot of the algorithm from the paper omitted]

In Algorithm 2, the (initial) learning rate α is multiplied by the schedule multiplier η_t, so the effective multiplier on the gradient is αη_t. Similarly, the effective multiplier on the weights is λη_t, where λ is the (initial) weight decay rate.

In our case, since we only have a learning rate scheduler, group['lr'] = αη_t (let's double check this), so the effective multiplier on the weights is λη_t = λ (group['lr']/α). Thus, since we currently compute the effective multiplier on the weights (λη_t) as group['weight_decay'] * group['lr'], we currently have group['weight_decay'] = λ/α.

I agree that this can be confusing: just as group['lr'] means the effective (scheduled) learning rate αη_t, group['weight_decay'] should arguably mean λη_t and not λ/α. This also means that when we introduce a weight decay scheduler in #22343, we can naturally schedule group['weight_decay'] independently of group['lr'].

> Furthermore, if it is multiplied by the learning rate, should that be bias-corrected?

Good question. I'd say no, because we are using group['lr'] to recover the implicit schedule η_t here, which is not bias-corrected in Algorithm 2.

Thoughts?
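
A small numeric sketch of that bookkeeping, with made-up values for α, η_t, and λ, just to show why `group['weight_decay']` currently plays the role of λ/α:

```python
alpha = 1e-3   # initial learning rate (alpha)
eta_t = 0.5    # schedule multiplier at step t (eta_t)
lam = 1e-2     # desired effective weight decay coefficient (lambda)

# With only an LR scheduler, group['lr'] tracks alpha * eta_t:
group = {'lr': alpha * eta_t, 'weight_decay': lam / alpha}

# p.data.mul_(1 - group['lr'] * group['weight_decay']) therefore applies an
# effective decay of lambda * eta_t, matching the multiplier described above:
effective_decay = group['lr'] * group['weight_decay']
assert abs(effective_decay - lam * eta_t) < 1e-12
```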

@PetrochukM

Thanks for clarifying this for me. I agree with you.

I think it's interesting that, to the best of my understanding, the AdamW in the paper is different from the AdamW in PyTorch and in the original code base. I believe there are three different implementations of AdamW.

The documentation for AdamW doesn't mention some of these discrepancies; I think it might be worthwhile to mention them:
[screenshot of the AdamW documentation omitted]

I'm happy that we have them documented in this PR though.

@vincentqb commented Nov 25, 2019

Would you like to open a PR to add to the documentation? :)

@PetrochukM

Not really because I don't feel like I understand the issue that well. If anything, I'd just link this PR in the documentation.

@sh0416 commented May 22, 2020

@vincentqb I want to disable bias correction in the AdamW optimizer. It is known that BERT uses AdamW with bias correction disabled. HuggingFace also adopts this approach in their optimizer. I think there is no reason not to have a bias correction flag in the AdamW optimizer. Can I add a bias_correction parameter to the AdamW optimizer?

@vincentqb

> @vincentqb I want to disable bias correction in the AdamW optimizer. It is known that BERT uses AdamW with bias correction disabled. HuggingFace also adopts this approach in their optimizer. I think there is no reason not to have a bias correction flag in the AdamW optimizer. Can I add a bias_correction parameter to the AdamW optimizer?

Wouldn't this be equivalent to having beta1 = beta2 = 0? If so, you can already do so.

One thing I would be OK with is special-casing beta1 = 0 and beta2 = 0, but the only advantage would be to potentially save on computation.

@sh0416 commented May 29, 2020

@vincentqb I don't think it is the special case you mentioned. Adam estimates the first and second moments, and it cannot estimate them well in the first few iterations. Therefore, it uses a bias correction equation involving the step t; as t grows, the effect of the bias correction term vanishes. However, in the BERT TensorFlow repo, the authors removed the bias correction term, which is empirically stable and improves performance. This suggests that the bias correction doesn't always work well. Therefore, I think it would be great if a user could choose whether to use bias correction or not. Although this phenomenon has only been demonstrated in the BERT case, it could be valuable considering the impact of BERT.
Anyway, thanks for the reply.

Estimation part: `moment = moment * beta + cur_moment * (1 - beta)`
Bias correction part: `moment_hat = moment / (1 - beta^t)`
Simply setting beta to zero disables the momentum term, which is not what I want.
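
A tiny sketch of the distinction being drawn here; the `bias_correction` flag below is hypothetical, not an existing AdamW argument:

```python
def update_moment(moment, grad_stat, beta, step, bias_correction=True):
    # Estimation part: exponential moving average of the (first or second) moment.
    moment = beta * moment + (1 - beta) * grad_stat
    if bias_correction:
        # Correction part: compensates for the zero-initialized average during
        # the first few steps; its effect vanishes as beta**step goes to 0.
        return moment, moment / (1 - beta ** step)
    # BERT-style variant: keep the moving average but skip the correction.
    return moment, moment

# Setting beta = 0 is different: it removes the moving average entirely,
# not just the correction term.
```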
