Added option to update parameters using state_dict in AveragedModel #65495
Conversation
@prabhat00155 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Thanks for the PR @prabhat00155.
I agree that this is very important for models whose layers contain buffers, such as BatchNorm. Not averaging all of the parameters leads to lower performance.
Due to this missing feature, the community has built multiple implementations to get around the problem [1, 2, 3]. So this is definitely worth fixing.
Below I left a comment for your attention; please let me know what you think. We probably also need to add a unit test that covers the new code.
@prabhat00155 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
LGTM, thanks @prabhat00155.
I left one comment for your consideration but it's not blocking.
@prabhat00155 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Is there a design discussion for this?
There can be objects that are not Tensors in the state_dict so I am not convinced this is going to work well.
@@ -96,12 +98,15 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn
        self.use_state_dict = mode == 'state_dict'
That looks very brittle if the user passes another string as input.
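For illustration, a minimal sketch of the guard being asked for here, assuming the argument stays a plain string (the validation that actually landed is in #65921):

```python
# Hypothetical sketch: fail loudly on unknown strings instead of silently
# falling back to the 'parameters' behaviour.
if mode not in ('parameters', 'state_dict'):
    raise ValueError(f"Invalid mode: {mode!r}; only 'parameters' and 'state_dict' are supported.")
self.use_state_dict = mode == 'state_dict'
```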
@@ -26,6 +26,8 @@ class AveragedModel(Module):
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter and the number of models already averaged; if None,
            equally weighted average is used (default: None)
        mode (str, optional): whether to use parameters or state_dict for update
This should document the valid values for this argument.
`parameters` and `state_dict` are the valid values here, as stated in line 29. However, anything passed other than `state_dict` defaults to `parameters`.
Oh yes, what I meant is that the docstring is missing the " characters to make it clear that these are strings.
It's a nit though.
@albanD Thanks for the feedback. The PR contains links to the original issues/discussions at Vision, but let me summarize the details here and make the post self-contained.

**Previous implementation shortcomings**

PyTorch's `update_parameters()` (pytorch/torch/optim/swa_utils.py, lines 87 to 89 in 27135f8) iterates only over the model's `parameters()`, so buffers such as BatchNorm's running statistics are never averaged.

**Duplicate implementations**

Due to the above, an overwhelming number of libraries in the PyTorch ecosystem (both FB-backed and community led) are currently not using PyTorch's `AveragedModel` and maintain their own variants instead.

I strongly believe that improving the current implementation of `AveragedModel` is worth it.

**Different options**

There are two generally accepted ways to handle the buffers:

1. average them along with the parameters, or
2. leave them out of the averaging and copy/reset their values instead.

There are pros/cons to both, but the vast majority of the above implementations follow option 1. Implementing option 2 is also possible, though less popular.

**The approach of this PR**

@prabhat00155's PR introduces a new `mode` option that makes `update_parameters()` iterate over the model's `state_dict()` (parameters and buffers) instead of `parameters()` only.

Note that after merging this PR, the majority of the above implementations are now able to switch to using PyTorch core's `AveragedModel`.

**Is it safe to use Module's state_dict()?**

I believe it is safe. The vast majority of the linked implementations already use `state_dict()`. Moreover, note the declared return type of the `state_dict()` method (pytorch/torch/nn/modules/module.py, lines 1278 to 1287 in b80bdcc).

**Improving the new implementation**

I think your comments on the code of this implementation are fair and worth addressing. We will assess the options carefully and follow up with another PR. Shall we add you as a reviewer to ensure all is well prior to merging?

Finally, please note that the TorchVision team plans to contribute more to Core in the future and we are happy to adjust to the processes that you use. Could you please clarify who is the POC for these utils? I ask because when we opened the PR a week ago, we tried git blaming but it didn't show an obvious person to include from your side.

Let me know what you think. Thanks!
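For readers following along, here is a simplified sketch of what the new `mode` option enables in `update_parameters()`. It is a paraphrase of the approach described above, not the verbatim swa_utils.py code:

```python
# Simplified sketch of update_parameters() with the new option: in
# 'state_dict' mode the averaging also covers buffers (e.g. BatchNorm
# running statistics); in 'parameters' mode it behaves as before.
def update_parameters(self, model):
    self_param = (self.module.state_dict().values() if self.use_state_dict
                  else self.parameters())
    model_param = (model.state_dict().values() if self.use_state_dict
                   else model.parameters())
    for p_swa, p_model in zip(self_param, model_param):
        device = p_swa.device
        p_model_ = p_model.detach().to(device)
        if self.n_averaged == 0:
            # first model seen: copy its values directly
            p_swa.detach().copy_(p_model_)
        else:
            # subsequent models: combine via the user-supplied avg_fn
            p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                             self.n_averaged.to(device)))
    self.n_averaged += 1
```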
Thanks for the detailed comment @datumbox! I wasn't trying to say that the change doesn't make sense; I was just wondering if there was an issue associated with this on the pytorch/pytorch side.
That's a great point. cc @jbschlosser: should we change this signature now that we have

I am still wondering if we wouldn't prefer to always use parameters and optionally buffers. That would avoid this issue altogether. And that would make the arg in the API a boolean, which is easier to handle (no magic string matching).
For these utility functions in optim, I am not sure we really want to grow them long term. These were added as utilities for the swa optimizer but not really designed to be used in a standalone way.
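To make the suggested alternative concrete, here is a hypothetical sketch with a boolean `use_buffers` flag; the flag name and wiring are assumptions for illustration, not something this PR adds:

```python
import itertools

# Hypothetical alternative: always average parameters, optionally buffers,
# controlled by a boolean flag instead of a magic string.
def update_parameters(self, model):
    if self.use_buffers:  # hypothetical flag set in __init__
        self_param = itertools.chain(self.module.parameters(), self.module.buffers())
        model_param = itertools.chain(model.parameters(), model.buffers())
    else:
        self_param = self.parameters()
        model_param = model.parameters()
    # ...same averaging loop over zip(self_param, model_param) as above...
```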
@albanD No worries at all. That gave us the opportunity to document things properly on the Core side. Sorry for bringing this PR without the necessary documentation in the first place. :)
If indeed the
This will work only if we are certain we will never add option 2 listed above. I'm not saying that our team wants to add it now, just that the string arg gives us the option to do it in the future.
We can do a quick iteration and fix things if need be. We will just need to cherry-pick them and put them in the release, so we kind of need to do it fast. If you can help with the reviews and clarify the design questions that popped up, we can do the coding part quickly.
Makes sense, let's cross that bridge when we get there. For this round of review I'll have you as the POC, as you suggested.
Not sure if this would be something completely out of the ordinary, but can we use an enum here?
Option 2 seems orthogonal, no? One is whether or not you consider buffers, and the other is whether you average or reset the values.
Yes, option 1 and option 2 are not combined together, but you could support both strategies in the same class.
At TorchVision we started showing some love for enums. :) I haven't seen this pattern used much on core though.

Concerning next steps, we could add validation of the mode argument in a quick follow-up PR and keep discussing whether a boolean or a string/enum argument is the better long-term API.

Is there anything else we would like to do for this release?
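Purely for illustration, a sketch of what the enum idea could look like; the names `UpdateMode` and `_parse_mode` are hypothetical and nothing like this was merged here:

```python
from enum import Enum

# Hypothetical enum for the update mode; accepts either the member itself
# or its string value, and raises ValueError on anything else.
class UpdateMode(Enum):
    PARAMETERS = 'parameters'
    STATE_DICT = 'state_dict'

def _parse_mode(mode):
    return mode if isinstance(mode, UpdateMode) else UpdateMode(mode)
```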
That sounds good to me! Thanks for taking the time to do it.
Yes, but it is not clear to me if it should be two boolean arguments or a single string/enum argument.
Summary: Discussion: #65495 (comment) Pull Request resolved: #65921 Reviewed By: albanD Differential Revision: D31310105 Pulled By: prabhat00155 fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
Added option to update parameters using state_dict in AveragedModel (pytorch#65495)

Summary: While implementing [EMA](pytorch/vision#4381) (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in the [EMA class](pytorch/vision#4406). This PR aims to handle this scenario, removing the need for this custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)

Pull Request resolved: pytorch#65495
Reviewed By: datumbox
Differential Revision: D31176742
Pulled By: prabhat00155
fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
(cherry picked from commit 2ea724b)
Summary: Discussion: pytorch#65495 (comment) Pull Request resolved: pytorch#65921 Reviewed By: albanD Differential Revision: D31310105 Pulled By: prabhat00155 fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3 (cherry picked from commit c7748fc)
Added option to update parameters using state_dict in AveragedModel (#65495) (#65755)

* Added option to update parameters using state_dict in AveragedModel (#65495)

  Summary: While implementing [EMA](pytorch/vision#4381) (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in the [EMA class](pytorch/vision#4406). This PR aims to handle this scenario, removing the need for this custom update_parameters() implementation.

  Discussion: pytorch/vision#4406 (review)

  Pull Request resolved: #65495
  Reviewed By: datumbox
  Differential Revision: D31176742
  Pulled By: prabhat00155
  fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
  (cherry picked from commit 2ea724b)

* Added validation of mode parameter in AveragedModel (#65921)

  Summary: Discussion: #65495 (comment)

  Pull Request resolved: #65921
  Reviewed By: albanD
  Differential Revision: D31310105
  Pulled By: prabhat00155
  fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
  (cherry picked from commit c7748fc)
While implementing EMA (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in the EMA class. This PR aims to handle this scenario, removing the need for this custom update_parameters() implementation.
Discussion: pytorch/vision#4406 (review)
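For context, with the new option an EMA wrapper along these lines should no longer need its own update_parameters(). This is a hedged sketch modelled on the torchvision class linked above (class name, `decay` argument, and `ema_avg` helper are illustrative), not the exact code:

```python
from torch.optim.swa_utils import AveragedModel

# Sketch of an EMA model built on top of AveragedModel using the new mode,
# so buffers are averaged too and no custom update_parameters() is needed.
class ExponentialMovingAverage(AveragedModel):
    def __init__(self, model, decay, device='cpu'):
        def ema_avg(avg_param, param, num_averaged):
            # exponential moving average of the tracked model's values
            return decay * avg_param + (1.0 - decay) * param
        super().__init__(model, device=device, avg_fn=ema_avg, mode='state_dict')
```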