
Improve use_state_dict in AveragedModel #66686

Closed
prabhat00155 opened this issue Oct 15, 2021 · 7 comments
Assignees: prabhat00155
Labels: module: optimizer (Related to torch.optim), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

prabhat00155 (Contributor) commented Oct 15, 2021:

use_state_dict was added to AveragedModel in #65495 and #65921. It needs to be improved further, as described here; this issue tracks that work item.

cc @vincentqb @jbschlosser @albanD @datumbox

prabhat00155 self-assigned this Oct 15, 2021
H-Huang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Oct 15, 2021
albanD added the module: optimizer label (Related to torch.optim) Oct 22, 2021
prabhat00155 (Contributor, Author) commented:

@gchanan suggested not using state_dict as a mode. Would something like with_buffers or include_buffers make more sense?
Also, @albanD mentioned that there could be objects in state_dict that are not tensors. Could we handle this case by ignoring non-tensor values?
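For concreteness, a rough sketch of what an include_buffers flag could look like inside update_parameters; the function name and structure here are illustrative only, not the actual torch.optim.swa_utils code:

import itertools

def update_parameters_sketch(averaged_model, model, n_averaged, avg_fn, include_buffers=False):
    # Illustrative only: average parameters, and optionally buffers as well.
    if include_buffers:
        avg_tensors = itertools.chain(averaged_model.parameters(), averaged_model.buffers())
        new_tensors = itertools.chain(model.parameters(), model.buffers())
    else:
        avg_tensors = averaged_model.parameters()
        new_tensors = model.parameters()

    for t_avg, t_new in zip(avg_tensors, new_tensors):
        t_new_ = t_new.detach().to(t_avg.device)
        if n_averaged == 0:
            t_avg.detach().copy_(t_new_)
        else:
            t_avg.detach().copy_(avg_fn(t_avg.detach(), t_new_, n_averaged))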

datumbox (Contributor) commented:

> something like with_buffers or include_buffers

A boolean flag will work if we plan to support only buffer averaging. As discussed at #65495 (comment) (section "Different options"), there are alternative approaches for handling buffers. One of them is simply to set the buffers to their latest values. I'm not suggesting we implement this now, just noting that if we want to keep that option open, we probably need to maintain some kind of mode attribute (string/enum).
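To illustrate the string/enum idea, a mode attribute could look roughly like this; this is purely a sketch with hypothetical names, not a proposed API:

from enum import Enum

class BufferMode(Enum):
    IGNORE = "ignore"    # average parameters only (current default behaviour)
    AVERAGE = "average"  # average buffers along with parameters
    LATEST = "latest"    # copy the source model's latest buffer values

def apply_buffer_mode(averaged_model, model, mode, avg_fn, n_averaged):
    # Hypothetical helper showing how the different buffer strategies diverge.
    if mode is BufferMode.IGNORE:
        return
    for b_avg, b_new in zip(averaged_model.buffers(), model.buffers()):
        b_new_ = b_new.detach().to(b_avg.device)
        if mode is BufferMode.LATEST or n_averaged == 0:
            b_avg.detach().copy_(b_new_)
        else:  # BufferMode.AVERAGE
            b_avg.detach().copy_(avg_fn(b_avg.detach(), b_new_, n_averaged))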

gchanan (Contributor) commented Oct 26, 2021:

I think there's some lack of clarity here, and it would be good if we discussed some specific, simple examples.

For example, the description in #65495 begs the solution:

> While implementing EMA (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in the EMA class. This PR aims to handle this scenario, removing the need for this custom update_parameters() implementation.

But EMA operates on a model -- what's a model that doesn't work? Will all models work if we make this change?

Somewhere (I can't find it now) @datumbox mentioned that normalization doesn't work, which makes intuitive sense since statistics are tracked separately from the parameters. But the code in question already has special handling for (at least) batch norm:

@torch.no_grad()
def update_bn(loader, model, device=None):
    r"""Updates BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.
    """

How does that code fit into this story?

datumbox (Contributor) commented Oct 27, 2021:

@gchanan Thanks for your comments.

> What's a model that doesn't work?

Models with layers that contain buffers (like BN) have reduced accuracy. See this reference.

> Will all models work if we make this change?

Yes, all models will work if we make the change. I have provided some evidence of how the proposed update works in other libraries at #65495 (comment). Please let me know if you need additional evidence and, if so, what form it should take (metrics? code snippets for corner cases? something else?).

> But the code in question already has special handling for (at least) batch norm

You are right to say that there is a method provided for BN. So why do we want to take additional steps?

It's because:

  1. BNs are not the only layers with buffers, so model averaging should be able to handle those other layers as well (a small illustration follows this list).
  2. update_bn() offers a third, alternative approach for fixing the issue for BN layers. It requires a two-step training process, with the extra update happening after training. Note that it only updates the running means/variances without adjusting the actual weights of the convolutions, which means it can cause a drift in the performance of the model and thus can't be applied in all cases.
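As a small illustration of point 1 (hypothetical usage, not taken from the issue): with a plain AveragedModel, only parameters are averaged, so the EMA copy's BN buffers keep the values captured at construction time.

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))

# Exponential moving average of the parameters; avg_fn only ever sees
# parameters, so running_mean/running_var in ema_model are never touched.
ema_avg = lambda avg_p, p, n: 0.9 * avg_p + 0.1 * p
ema_model = AveragedModel(model, avg_fn=ema_avg)

model.train()
for _ in range(10):
    model(torch.randn(4, 3, 16, 16))      # BN stats in `model` keep updating
    ema_model.update_parameters(model)    # only parameters are averaged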

gchanan (Contributor) commented Oct 27, 2021:

I think what's missing is:

  1. A summary of this discussion in the docs. As a user coming in, I can't really make sense of what's going on currently. For example:
  • explain for normalization layers the tradeoff between using this and update_bn
  • acknowledge that what's in the buffers/state_dict can be basically anything; I think the argument you are making is that averaging works well empirically for normalization layers, where the statistics are kept in the buffers. But the buffers can represent any quantity, including ones to which it might not make sense to apply averaging. Does it make sense to write a sentence in the docs along these lines: "Empirical evidence has shown that applying the averaging to statistics in normalization layers increases accuracy (link to evidence). This may apply to other types of layers/buffers, but you may wish to empirically test."
  2. Resolve whether you actually want to look into the buffers or the state_dict. Because the reasoning for averaging the buffers isn't clear, it's also not clear for the state_dict. For example, the state_dict might in the future contain a version number for BC reasons. Does it make sense to average this? (A possible guard is sketched below.)
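One possible guard, sketched under the assumption that only floating-point buffers should be averaged while everything else (counters or any future version-like entries) is copied verbatim; the helper name is hypothetical:

import torch

def average_buffers_sketch(averaged_model, model, avg_fn, n_averaged):
    # Hypothetical helper: average floating-point buffers (e.g. BN running
    # stats) and copy non-floating-point ones (e.g. num_batches_tracked)
    # verbatim instead of averaging them.
    for b_avg, b_new in zip(averaged_model.buffers(), model.buffers()):
        b_new_ = b_new.detach().to(b_avg.device)
        if torch.is_floating_point(b_avg) and n_averaged > 0:
            b_avg.detach().copy_(avg_fn(b_avg.detach(), b_new_, n_averaged))
        else:
            b_avg.detach().copy_(b_new_)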

datumbox (Contributor) commented:

> But the buffers can represent any quantity, including ones to which it might not make sense to apply averaging.

Do you have an example of such a buffer within nn.Module, or are you concerned about future usages? The state_dict averaging approach is currently used by a large number of downstream libraries and seems to work across different tasks and models. I believe the common assumption is that buffers store statistics about the model/layers which are not updated with gradients. If that's not true, then it's worth clarifying it in the documentation and perhaps reaching out to the downstream libraries with a recommendation on how to update their implementations.
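That assumption is easy to check with a toy module (illustrative only): buffers appear in state_dict() but not in parameters(), so they never receive gradient updates.

import torch

class WithBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4))
        # Registered buffers are saved in state_dict() but are not returned
        # by parameters(), so optimizers never update them with gradients.
        self.register_buffer("running_stat", torch.zeros(4))

m = WithBuffer()
print([name for name, _ in m.named_parameters()])  # ['weight']
print(sorted(m.state_dict().keys()))               # ['running_stat', 'weight']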

> Resolve whether you actually want to look into the buffers or the state_dict.

I think you have a good understanding of the use case we want to cover and why we want it. Do you have a recommendation on how to achieve it? Or would you prefer to roll back the feature and handle this on a case-by-case basis?

facebook-github-bot pushed a commit that referenced this issue Jan 26, 2022
Summary:
Fixes #66686

Pull Request resolved: #71763

Reviewed By: anjali411

Differential Revision: D33770907

Pulled By: prabhat00155

fbshipit-source-id: ee32f2cb8475c9add4e1a9a5d3d784ef95825efc
pytorchmergebot pushed a commit that referenced this issue Jan 26, 2022 (same summary as above; cherry picked from commit a15898b)
prabhat00155 (Contributor, Author) commented:

Fixed by #71763.

cyyever pushed commits to cyyever/pytorch_private that referenced this issue on Feb 3 and Feb 9, 2022 (same summary as above; cherry picked from commit a15898b)