Difference in the implementation of rmsprop with Tensorflow #23796

Closed
meijieru opened this issue Aug 5, 2019 · 14 comments
Labels
module: optimizer (Related to torch.optim)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@meijieru
Contributor

meijieru commented Aug 5, 2019

🚀 Feature

Recently I wanted to reproduce a TensorFlow model in PyTorch, and I found a difference between the TensorFlow and PyTorch implementations of the RMSprop optimizer: TensorFlow adds epsilon inside the sqrt, while PyTorch adds it outside. This makes a difference when epsilon is large.

See chainer/chainer#4754 for reference. Maybe we could have the same option eps_inside_sqrt for controlling the behavior.
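To make the difference concrete, here is a minimal sketch of the two update rules (my own illustration, not the actual torch.optim or TF source; `square_avg` stands for the running average of squared gradients):

```python
import torch

grad = torch.randn(10)                 # current gradient
square_avg = torch.full((10,), 0.01)   # running average of squared gradients
lr, eps = 0.01, 1e-3

# PyTorch-style RMSprop: epsilon added outside the sqrt
step_outside = lr * grad / (square_avg.sqrt() + eps)

# TensorFlow-style RMSprop: epsilon added inside the sqrt
step_inside = lr * grad / (square_avg + eps).sqrt()
```

For large eps the two denominators diverge noticeably, which is why runs tuned for one framework can behave differently in the other.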

meijieru added a commit to meijieru/pytorch that referenced this issue Aug 5, 2019
@ailzhang ailzhang added the triaged and module: optimizer labels Aug 7, 2019
@ailzhang
Contributor

ailzhang commented Aug 7, 2019

Thanks for the PR @meijieru !

@vincentqb
Contributor

I'm reopening in case other users would like to give feedback on this.

@vincentqb vincentqb reopened this Sep 24, 2019
vincentqb added a commit that referenced this issue Sep 24, 2019
…g epsilon. #23796

ghstack-source-id: c6732741c0d21d9fea4c7b74ce5e81a02e972ecc
Pull Request resolved: #26735
vincentqb added a commit that referenced this issue Sep 24, 2019
…g epsilon. #23796

ghstack-source-id: 03f67e94e355a74c5cf1ce57cfed6af587495088
Pull Request resolved: #26735
vincentqb added a commit that referenced this issue Sep 25, 2019
…g epsilon. #23796

ghstack-source-id: 6c4dbd396edeb987c422ec69fa32b60840b3d108
Pull Request resolved: #26735
@1e100

1e100 commented Dec 15, 2019

@meijieru were you able to replicate the TF result in the end? No matter what I tried with RMSProp I could not get it to work well in PyTorch. It does work pretty well in TF though.

@meijieru
Contributor Author

Yeah.

@vincentqb vincentqb self-assigned this Dec 18, 2019
@vincentqb
Contributor

Closing since the documentation has been updated.

@bonlime

bonlime commented Jun 3, 2020

@vincentqb @zou3519
Hi,
May I ask whether there is any particular reason the current implementation of RMSprop adds epsilon after the sqrt?
In my experience, training with the default PyTorch RMSprop + FP16 (using apex) tends to produce NaNs during backprop. I'm sure @rwightman has also experienced such issues. Training EfficientNets, for example, is impossible.

By moving epsilon inside the sqrt, everything works like a charm. There are also mathematical proofs that adding eps after the sqrt is not enough to prevent overflow; check this paper, p. 6 bottom right. The topic of the paper is slightly different, but its conclusion also applies here.
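A quick numerical sketch (my own illustration, not the cited paper's example) of why the placement matters when the squared-gradient average is tiny:

```python
import torch

v = torch.tensor(1e-12)   # tiny running average of squared gradients
eps = 1e-8

print((v.sqrt() + eps).item())   # ~1.01e-06: step amplified by ~1e6
print((v + eps).sqrt().item())   # ~1.00e-04: step amplified by only ~1e4
```

With eps outside the sqrt, the denominator here is roughly 100x smaller, so the update is roughly 100x larger; in FP16, whose maximum value is about 65504, such amplified steps overflow far more easily.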

According to this issue, other people have also had problems with RMSprop's behavior in PyTorch. As far as I can see, you want the current implementation to be consistent with Adam / AdamW, but they also tend to produce NaNs!
It's a very tricky issue, and it's hard to show by tests that the proposed implementation is superior to the current one. But I know a number of people who have successfully used adjusted versions of Adam / RMSprop to avoid NaNs.

@vincentqb
Contributor

vincentqb commented Jun 3, 2020

The main reason the epsilon is where it is right now is backward compatibility (and consistency with other optimizers like Adam/AdamW), and we cannot change the default behavior, at least not without a lot of warnings and time. We also try to keep the optimizers as lightweight as possible so users can modify and experiment with them more easily.

That being said, this topic comes up often (e.g. #32545), so I'm open to having an alternative available, such as the one you mentioned and offered in #23807 (which would also need to be applied to Adam/AdamW). What other alternatives could we consider?

@bonlime

bonlime commented Jun 3, 2020

@vincentqb
It looks like, for such a mature framework, any change in default behavior is slow and difficult. The reason I even started this discussion is that in our Russian-speaking Slack (ODS), questions about NaNs in the loss with FP16 come up regularly, and moving eps inside the sqrt often helps.

I agree that adding eps_inside_sqrt is not the best option. I would suggest at least adding a short note about this issue to the optimizer's docstring.

@vincentqb
Contributor

> I agree that adding eps_inside_sqrt is not the best option. I would suggest at least adding a short note about this issue to the optimizer's docstring.

Is this what you mean: #26735?

@bonlime

bonlime commented Jun 4, 2020

I saw #26735; I don't think it's enough. Adding a sentence about possible overflow in FP16 training would be better.

@vincentqb
Contributor

If overflow happens in a systematic fashion, can you point me to the issue number in this case?

@Miffyli

Miffyli commented Jul 23, 2020

Another issue related to the TF vs. PyTorch RMSprop implementations: fixing epsilon alone did not work for us.

TensorFlow initializes the squared-grad accumulator to ones, while PyTorch initializes it to zeros. In our experiments (A2C), we found the PyTorch variant learned the task faster but never converged to the optimal policy, while the TF version learns steadily and reliably.

PyTorch used to initialize to ones, but that was changed years ago without much discussion (#485). I realize there is no gold standard for RMSprop and you might not want to change this, but I believe the TF version is more stable thanks to its smaller initial gradient updates.
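A small sketch (my own, with hypothetical variable names) of how the accumulator initialization changes the very first step:

```python
import torch

alpha, eps = 0.9, 1e-8
grad = torch.randn(1000)

for init, name in [(torch.zeros_like(grad), "zeros (PyTorch)"),
                   (torch.ones_like(grad), "ones (TF)")]:
    # one EMA update of the squared-gradient accumulator
    square_avg = init.mul(alpha).addcmul(grad, grad, value=1 - alpha)
    step = grad / (square_avg.sqrt() + eps)
    print(name, step.abs().mean().item())
# zeros: first step is ~grad / (|grad| * sqrt(0.1)) ≈ 3.16, regardless of grad scale
# ones:  the denominator starts near 1, so the first step is roughly grad itself
```

With zeros init, the first update has nearly constant magnitude no matter how small the gradients are, which matches the "faster but less stable" behavior described above.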

@rwightman

Since I'm cc'd on this thread: I've had great success with my variant of RMSProp that tries to stay true to the TF version. I've trained quite a number of models with excellent results, and so have quite a few others. Trying similar hparams with the PyTorch RMSProp results in unstable training and often immediate blow-ups; it's basically not usable in my trials, and I've never managed acceptable results.

There are 3 main differences:

1. epsilon is added inside the sqrt;
2. the squared-gradient accumulator is initialized to ones rather than zeros;
3. the LR is applied inside the momentum buffer rather than to the final update.

I also tried reordering a few other ops to match TF more closely, but I doubt that had any impact whatsoever.

@rwightman

The third one, the way the LR is applied to the update, is interesting but not often brought up. In steady state (of the LR), the TF and PyTorch implementations are equivalent; they are not when the LR changes, and the TF version smooths the transition. Interestingly, many LR schedules used with RMSprop by some Google research teams change the LR quite frequently: they often have per-step or per-epoch warmup ramps and then LR decay steps every 1-3 epochs. So this difference would have an impact.
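A sketch (my own naming, not either library's source) of the two placements of the LR under momentum:

```python
import torch

momentum, lr = 0.9, 0.01
rms_step = torch.randn(10)   # grad / denom from the RMSprop part of the update

# PyTorch-style: buffer accumulates unscaled steps, LR applied at the end
buf_pt = torch.zeros(10)
buf_pt = momentum * buf_pt + rms_step
update_pt = lr * buf_pt

# TF-style: LR folded into the buffer, so an LR change only affects new
# contributions and is blended in over the momentum horizon
buf_tf = torch.zeros(10)
buf_tf = momentum * buf_tf + lr * rms_step
update_tf = buf_tf
```

With a constant lr the two produce identical updates once the buffer reaches steady state; the moment lr changes, the PyTorch form rescales the entire accumulated buffer at once, while the TF form transitions gradually.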
