Improve use_state_dict in AveragedModel #66686
@gchanan suggested not to use
A boolean flag will work if we plan to support only buffer averaging. As discussed at #65495 (comment) (section "Different options"), there are alternative approaches to handling buffers. One of them is simply to set them to their latest values. I'm not suggesting that we implement this now, only noting that if we want to keep this option open, we probably need to maintain some kind of
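To make the trade-off concrete, here is a minimal pure-Python sketch of one EMA update step with three possible buffer policies. `BufferPolicy` and `ema_update` are hypothetical names, not part of `torch.optim.swa_utils`; plain dicts of floats stand in for the tensors of a PyTorch `state_dict`. Once "ignore", "average" and "set to latest" all need to be expressible, a single boolean no longer covers the choice.

```python
from enum import Enum


class BufferPolicy(Enum):
    # Hypothetical options for buffers such as BN running statistics.
    IGNORE = "ignore"    # average parameters only; buffers keep their old values
    AVERAGE = "average"  # treat buffers exactly like parameters
    LATEST = "latest"    # copy the live model's most recent buffer values


def ema_update(avg_params, avg_buffers, params, buffers, decay, policy):
    """Update the averaged state in place from the live model's state.

    All state arguments are dicts mapping names to floats (stand-ins for
    tensors). The EMA rule is: averaged = decay * averaged + (1 - decay) * current.
    """
    for name, value in params.items():
        avg_params[name] = decay * avg_params[name] + (1.0 - decay) * value
    if policy is BufferPolicy.AVERAGE:
        for name, value in buffers.items():
            avg_buffers[name] = decay * avg_buffers[name] + (1.0 - decay) * value
    elif policy is BufferPolicy.LATEST:
        # With real tensors this would be an in-place copy_().
        avg_buffers.update(buffers)
    # BufferPolicy.IGNORE: avg_buffers is left untouched.


params, buffers = {"w": 0.0}, {"running_mean": 0.0}
ema_update(params, buffers, {"w": 1.0}, {"running_mean": 4.0}, 0.5, BufferPolicy.AVERAGE)
print(params, buffers)  # {'w': 0.5} {'running_mean': 2.0}
```

An enum (or a string option) leaves room for further policies later without another breaking change to the constructor signature.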
I think there's some lack of clarity here, and it would be good to discuss a few specific, simple examples. For example, the description in #65495 begs the solution:
But EMA operates on a model, so what's a model that doesn't work? Will all models work if we make this change? Somewhere (I can't find it now) @datumbox mentioned that normalization doesn't work, which makes intuitive sense since the statistics are tracked separately from the parameters. But the code in question already has special handling for (at least) batch norm: see pytorch/torch/optim/swa_utils.py, lines 115 to 140 at commit 27135f8.
How does that code fit into this story?
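For context on the batch-norm handling: PyTorch ships `torch.optim.swa_utils.update_bn(loader, model)`, which resets the BN running statistics of the averaged model and re-estimates them with a forward pass over the data, temporarily setting each BN layer's momentum to `None` so that a cumulative average is used. The pure-Python sketch below mimics only the running-mean part of that estimate; `recompute_running_mean` is a hypothetical helper, and the lists of floats stand in for the activations the averaged model would produce.

```python
def recompute_running_mean(batches):
    """Re-estimate a BN-style running mean from scratch.

    Simplified analogue of a BatchNorm layer with momentum=None: after n
    batches, the running mean equals the mean of the n per-batch means.
    """
    running_mean = 0.0
    for n, batch in enumerate(batches, start=1):
        batch_mean = sum(batch) / len(batch)
        # Cumulative moving average: m_n = m_{n-1} + (x_n - m_{n-1}) / n
        running_mean += (batch_mean - running_mean) / n
    return running_mean


print(recompute_running_mean([[1.0, 3.0], [5.0, 7.0]]))  # 4.0
```

This recomputation is specific to BN statistics; other kinds of buffers get no equivalent treatment from this helper.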
@gchanan Thanks for your comments.
Models with layers that contain buffers (like BN) have reduced accuracy when those buffers are not averaged along with the parameters. See this reference.
Yes, all models will work if we make the change. I have provided some evidence of how the proposed update works in other libraries at #65495 (comment). Please let me know if you need additional evidence and, if so, what form it should take (metrics? code snippets for corner cases? something else?).
You are right to say that there is a method provided for BN. So why do we want to take additional steps? It's because:
I think what's missing is:
Do you have an example of such a buffer within
I think you have a good understanding of the use case we want to cover and why. Do you have a recommendation on how to achieve it? Or would you prefer to roll back the feature and handle this on a case-by-case basis?
Fixed by #71763.
Summary: Fixes pytorch/pytorch#66686
Pull Request resolved: pytorch/pytorch#71763
Reviewed By: anjali411
Differential Revision: D33770907
Pulled By: prabhat00155
fbshipit-source-id: ee32f2cb8475c9add4e1a9a5d3d784ef95825efc
(cherry picked from commit a15898b)
use_state_dict was added in AveragedModel in #65495 and #65921.
This needs to be further improved as described here. This issue has been created to keep track of this work item.
cc @vincentqb @jbschlosser @albanD @datumbox