
Added option to update parameters using state_dict in AveragedModel #65495

Conversation

prabhat00155 (Contributor)
While implementing EMA (which extends AveragedModel) in torchvision, AveragedModel's update_parameters() could not be used because it did not handle the state_dict(), so a custom update_parameters() had to be defined in the EMA class. This PR handles that scenario, removing the need for the custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)
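For context, here is a hedged sketch of the behaviour this PR enables (simplified and illustrative; names mirror torch.optim.swa_utils.AveragedModel, but this is not the exact merged diff):

import torch

# Illustrative sketch of update_parameters() with the new mode handling.
def update_parameters(self, model):
    if self.use_state_dict:
        # New: walk the full state_dict so buffers (e.g. BatchNorm running
        # statistics) are averaged alongside the parameters.
        self_param = self.module.state_dict().values()
        model_param = model.state_dict().values()
    else:
        # Original behaviour: average parameters only.
        self_param = self.module.parameters()
        model_param = model.parameters()
    for p_swa, p_model in zip(self_param, model_param):
        device = p_swa.device
        p_model_ = p_model.detach().to(device)
        if self.n_averaged == 0:
            p_swa.detach().copy_(p_model_)
        else:
            p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                             self.n_averaged.to(device)))
    self.n_averaged += 1

With this in place, an EMA wrapper only needs a custom avg_fn rather than its own update_parameters(); the decay value below is illustrative:

from torch.optim.swa_utils import AveragedModel

def ema_avg_fn(averaged_param, model_param, num_averaged, decay=0.999):
    return decay * averaged_param + (1.0 - decay) * model_param

ema_model = AveragedModel(torch.nn.Linear(10, 2), avg_fn=ema_avg_fn)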

@facebook-github-bot (Contributor) commented Sep 22, 2021

💊 CI failures summary and remediations

As of commit 2448b61 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



1 job timed out:

  • pytorch_linux_xenial_py3_clang7_asan_test2

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

This comment was automatically generated by Dr. CI.

@facebook-github-bot (Contributor)

@prabhat00155 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@datumbox (Contributor) left a comment

Thanks for the PR @prabhat00155.

I agree that this is very important for models with layers that contain buffers, such as BatchNorm. Not averaging all the parameters leads to lower performance.

Due to this missing feature, the community has built multiple implementations to work around the problem [1, 2, 3], so this is definitely worth fixing.

Below I left a comment for your attention; please let me know what you think. We should also probably add a unit test that covers the new code.

torch/optim/swa_utils.py (outdated, resolved)
@facebook-github-bot (Contributor)

@prabhat00155 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@datumbox (Contributor) left a comment

LGTM, thanks @prabhat00155.

I left one comment for your consideration but it's not blocking.

test/test_optim.py (resolved)
torch/optim/swa_utils.py (outdated, resolved)
@facebook-github-bot (Contributor)

@prabhat00155 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@prabhat00155 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@albanD (Collaborator) left a comment

Is there a design discussion for this?
There can be objects that are not Tensors in the state_dict, so I am not convinced this is going to work well.

@@ -96,12 +98,15 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn
        self.use_state_dict = mode == 'state_dict'
Collaborator:

That looks very brittle if the user passes another string as input.

@@ -26,6 +26,8 @@ class AveragedModel(Module):
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter and the number of models already averaged; if None,
            equally weighted average is used (default: None)
        mode (str, optional): whether to use parameters or state_dict for update
Collaborator:

This should document the valid values for this argument.

Contributor Author:

parameters and state_dict are valid values here, as stated in line 29. However, anything passed other than state_dict defaults to parameters.

Collaborator:

Ah yes, what I meant is that this is missing quotation marks (") to make it clear that these are strings.
It's a nit though.

@datumbox (Contributor) commented Sep 29, 2021

@albanD Thanks for the feedback. The PR contains links to the original issues/discussions in torchvision, but let me summarize the details here to make this post self-contained.

Previous implementation shortcomings

PyTorch's AveragedModel class didn't properly handle buffers during model updates. This caused issues with many normalization layers (such as BatchNorm2d) and led to a significant decrease in accuracy (happy to provide additional numerical evidence; this is a very well-known fact in the research community). Another unfortunate detail of the previous implementation was that the initial values of the buffers were copied by the deepcopy call but were not subsequently updated by the update_parameters() method, leading to discrepancies:

def __init__(self, model, device=None, avg_fn=None):
    super(AveragedModel, self).__init__()
    self.module = deepcopy(model)
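To make the discrepancy concrete, a small hedged illustration (assuming the pre-PR behaviour):

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.BatchNorm1d(4)
swa_model = AveragedModel(model)     # buffers are deepcopied once here

model.train()
model(torch.randn(8, 4))             # forward pass updates running_mean/var
swa_model.update_parameters(model)   # pre-PR: averages parameters only

# Pre-PR, swa_model.module.running_mean still holds its initial (zero) value
# while model.running_mean has moved: the two have silently diverged.
print(model.running_mean)
print(swa_model.module.running_mean)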

Duplicate implementations

Due to the above, an overwhelming number of libraries in the PyTorch ecosystem (both FB-backed and community-led) currently do not use PyTorch's AveragedModel class. Instead, they reimplement the functionality on their own. Here are a few libraries that do this, along with links to their implementations:

I strongly believe that improving the current implementation of AveragedModel will bring significant benefits to PyTorch, the downstream libraries and the community. PyTorch's class will no longer be ignored, the downstream libraries won't have to reinvent the wheel, and the community will get a utility that works off the shelf and can significantly boost accuracy.

Different options

There are two generally accepted ways to handle the buffers:

  1. Average their values across checkpoints (like we do for parameters)
  2. Set their values directly to the latest checkpoint

There are pros and cons to both, but the vast majority of the above implementations follow option 1. Implementing option 2 is also possible, though less popular. Both are sketched below.
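A hedged sketch of the two strategies (hypothetical helper names; integer buffers such as BatchNorm's num_batches_tracked may need special casing):

import torch

@torch.no_grad()
def average_buffers(avg_model, model, num_averaged):
    # Option 1: running average of buffer values across checkpoints,
    # mirroring what is done for the parameters.
    for b_avg, b in zip(avg_model.buffers(), model.buffers()):
        b_avg.copy_(b_avg + (b.to(b_avg.device) - b_avg) / (num_averaged + 1))

@torch.no_grad()
def copy_latest_buffers(avg_model, model):
    # Option 2: overwrite buffers with the latest checkpoint's values.
    for b_avg, b in zip(avg_model.buffers(), model.buffers()):
        b_avg.copy_(b.to(b_avg.device))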

The approach of this PR

@prabhat00155's PR introduces a new mode parameter in the API that will allow us to support multiple options in the future. Though many people consider the current behaviour of the class a bug, changing the original behaviour now would be too drastic and would cause backwards-compatibility (BC) issues. This is why we decided to set the default value of mode in a way that maintains the original functionality.

Note that after merging this PR, the majority of the above implementations are now able to switch to using PyTorch core's AveragedModel. TorchVision plans to make the switch ASAP, once this change appears in the nightlies.

Is it safe to use Module's state_dict()?

I believe it is safe. The vast majority of the linked implementations use state_dict() for this type of averaging, so this approach is very well tested across a plethora of models and ML tasks.

Moreover, note that the declared return type of Module.state_dict() in the current PyTorch implementation indicates that it should return a dictionary of strings and Tensors:

T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor])

@overload
def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
    ...

# TODO: Remove string escape once Python-3.6 no longer supported
# See https://github.com/python/mypy/issues/6904#issuecomment-496207426
@overload
def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> 'OrderedDict[str, Tensor]':
    ...

Improving the new implementation

I think your comments on the code of this implementation are fair and worth addressing. We will assess the options carefully and follow up with another PR. Shall we add you as a reviewer to ensure all is well prior to merging?

Finally, please note that the TorchVision team plans to contribute more to Core in the future, and we are happy to adjust to the processes you use. Could you please clarify who the POC for these utils is? I ask because when we opened the PR a week ago, we tried git blame but it didn't surface an obvious person to include from your side.

Let me know what you think. Thanks!

@albanD (Collaborator) commented Sep 29, 2021

Thanks for the detailed comment @datumbox !

I wasn't trying to say that the change doesn't make sense; I was just wondering if there was an issue associated with this on the pytorch/pytorch side.

Moreover note that the declared return type of the Module.state_dict()

That's a great point. cc @jbschlosser should we change this signature now that we have *_extra_state() functions that can return anything?
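For reference, a minimal hedged sketch of why that matters: a module's get_extra_state() may return an arbitrary picklable object, so state_dict() values are not necessarily Tensors:

import torch

class ModuleWithExtraState(torch.nn.Module):
    def get_extra_state(self):
        # Arbitrary picklable object: stored in state_dict() under the
        # '_extra_state' key and is not a Tensor.
        return {'threshold': 0.5}

    def set_extra_state(self, state):
        self.threshold = state['threshold']

m = ModuleWithExtraState()
print(m.state_dict())  # includes a plain dict value, not only Tensors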

I am still wondering whether we wouldn't prefer to always handle parameters and optionally buffers. That would avoid this issue altogether, and it would make the argument in the API a boolean, which is easier to handle (no magic string matching).
But I guess that's too late.

Could you please clarify who is the POC for these utils?

torch.optim has gone back to being maintained by the core team as a whole, I'm afraid, as our previous POC is no longer working on this. You can have me as a POC if you need one, and I can route reviews internally.

As for these utility functions in optim, I am not sure we really want to grow them long term. They were added as utilities for the SWA optimizer and were not really designed to be used in a standalone way.
If you plan to invest in some of the things here, we most likely want to move them somewhere with better ownership. What do you think?

@datumbox (Contributor)

@albanD No worries at all. This gave us the opportunity to document things properly on the Core side. Sorry for bringing in this PR without the necessary documentation in the first place. :)

I am still wondering if we wouldn't prefer to do always parameters and optionally buffers.

If state_dict() is indeed no longer guaranteed to return Tensors, we could make an internal change to handle params and buffers independently, as you suggest. We would just need to come up with a better name for the mode value.

make the arg in the API a boolean

This will work only if we are certain we will never add option 2 listed above. I'm not saying that our team wants to add it now, just that the string argument gives us the option to do so in the future.

But I guess that's too late.

We can do a quick iteration and fix things if need be. We will just need to cherry-pick the changes into the release, so we need to move fast. If you can help with the reviews and clarify the design questions that popped up, we can do the coding part quickly.

If you plan to invest of some of the things here, we most likely want to move it somewhere else where there is better ownership.

Makes sense; let's cross that bridge when we get there. For this round of review I'll have you as POC, as you suggested.

@malfet (Contributor) commented Sep 29, 2021

make the arg in the API a boolean

This will work only if we are certain we will never add option 2 listed above. I'm not saying that our team wants to add it now, just that the string arg gives us the option to do it on the future.

Not sure if this would be something completely out of the ordinary, but can we use an enum here?
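For illustration, a hedged sketch of what an enum-based mode could look like (hypothetical names; not what this PR merged):

from enum import Enum

class UpdateMode(Enum):
    PARAMETERS = 'parameters'  # average parameters only (original behaviour)
    STATE_DICT = 'state_dict'  # average everything in state_dict()

# The Enum constructor doubles as validation:
# UpdateMode('state_dict') -> UpdateMode.STATE_DICT
# UpdateMode('bogus')      -> ValueError: 'bogus' is not a valid UpdateMode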

@albanD (Collaborator) commented Sep 29, 2021

This will work only if we are certain we will never add option 2 listed above

Option 2 seems orthogonal, no? One axis is whether or not you consider buffers; the other is whether you average or reset the values.
But yes, if you need more than two options, I personally feel enums are fine, though they are not broadly used here. So checking that you get a valid string would be enough if you prefer to go with a string.
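A hedged sketch of that string check (hypothetical helper, roughly in the spirit of the follow-up validation in #65921):

def _check_mode(mode: str) -> str:
    # Hypothetical helper: reject anything other than the two known modes
    # instead of silently falling back to 'parameters'.
    valid = ('parameters', 'state_dict')
    if mode not in valid:
        raise ValueError(f"Invalid mode: expected one of {valid}, got '{mode}'")
    return mode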

@datumbox (Contributor)

Option 2 seems orthogonal no?

Yes, options 1 and 2 are not combined together, but you could support both strategies in the same class.

Not sure if this would be something completely out of ordinary, but can we use enum here?

At TorchVision we have started showing some love for enums. :) I haven't seen this pattern used much in core, though.

Concerning next steps, we could:

  1. Do better validation of mode and throw the right errors to the user (#65495 (comment))
  2. Better document the available options (#65495 (comment))

Is there anything else we would like to do for this release?

@albanD (Collaborator) commented Sep 29, 2021

That sounds good to me! Thanks for taking the time to do it.

But you could support both strategies on the same class.

Yes, but it is not clear to me whether it should be two boolean arguments or a single string/enum argument.

facebook-github-bot pushed a commit that referenced this pull request Oct 3, 2021
Summary:
Discussion: #65495 (comment)

Pull Request resolved: #65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
prabhat00155 added a commit to prabhat00155/pytorch that referenced this pull request Oct 5, 2021
Added option to update parameters using state_dict in AveragedModel (pytorch#65495)

Summary:
While implementing EMA (pytorch/vision#4381, which extends AveragedModel) in torchvision, AveragedModel's update_parameters() could not be used because it did not handle the state_dict(), so a custom update_parameters() had to be defined in the EMA class (pytorch/vision#4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)

Pull Request resolved: pytorch#65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
(cherry picked from commit 2ea724b)
prabhat00155 added a commit to prabhat00155/pytorch that referenced this pull request Oct 5, 2021
Summary:
Discussion: pytorch#65495 (comment)

Pull Request resolved: pytorch#65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
(cherry picked from commit c7748fc)
malfet pushed a commit that referenced this pull request Oct 6, 2021
Added option to update parameters using state_dict in AveragedModel (#65495) (#65755)

* Added option to update parameters using state_dict in AveragedModel (#65495)

Summary:
While implementing EMA (pytorch/vision#4381, which extends AveragedModel) in torchvision, AveragedModel's update_parameters() could not be used because it did not handle the state_dict(), so a custom update_parameters() had to be defined in the EMA class (pytorch/vision#4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)

Pull Request resolved: #65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
(cherry picked from commit 2ea724b)

* Added validation of mode parameter in AveragedModel (#65921)

Summary:
Discussion: #65495 (comment)

Pull Request resolved: #65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
(cherry picked from commit c7748fc)
prabhat00155 added a commit that referenced this pull request Oct 8, 2021
malfet pushed a commit that referenced this pull request Oct 8, 2021