BatchNorm training=True in some Timm classes #1338
I wonder if it would make sense to set `training` to False after we obtain the model?
Yeah, I think someone brought up this point before: unless we anticipate use cases for training ops, the onnx-rewriter could potentially just set all the training attributes it sees to False. But that assumes the user will always run the onnx-rewriter afterwards.
I wonder why `training=True` is there in the first place? Was PyTorch trying to preserve some kind of training behavior even when `eval()` was called on the model?
This can be integrated into the exporter since it operates on standard ONNX domains. In ONNX, training mode only affects whether BatchNormalization emits extra outputs. cc @gramalingam, I think that was our conclusion last time. So if those outputs are unused, it is safe to remove them and flip the flag on the node.
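A minimal sketch of what that fix could look like on a plain ONNX graph (not the actual exporter or onnx-rewriter pass, just an illustration under the assumption that the extra BatchNormalization outputs are not consumed anywhere):

```python
import onnx


def force_bn_inference(model: onnx.ModelProto) -> None:
    """Flip training_mode to 0 on BatchNormalization nodes whose extra outputs are unused."""
    graph = model.graph
    # Tensor names consumed by other nodes or exposed as graph outputs.
    consumed = {name for node in graph.node for name in node.input}
    consumed |= {out.name for out in graph.output}

    for node in graph.node:
        if node.op_type != "BatchNormalization":
            continue
        extra_outputs = list(node.output[1:])  # running_mean / running_var in training mode
        if any(name and name in consumed for name in extra_outputs):
            continue  # the training outputs are actually used; leave the node alone
        # Drop the unused outputs and force inference mode.
        del node.output[1:]
        for attr in node.attribute:
            if attr.name == "training_mode":
                attr.i = 0
                break
        else:
            node.attribute.append(onnx.helper.make_attribute("training_mode", 0))
```

The pass would run after export, e.g. `model = onnx.load("model.onnx"); force_bn_inference(model); onnx.save(model, "model_fixed.onnx")`.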
Previously we resolved the issue (#1262) where instance norm in PyTorch was being decomposed into batch norm with training mode set to True.

I've also encountered some issues with BN set to True in certain timm classes. If you look at this file https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/std_conv.py, all calls to `F.batch_norm()` have `training=True`. Any thoughts on what would be the best solution here? I'm wondering if there's a way to always set training to False during tracing.
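For context, the calls in question look roughly like the sketch below (paraphrased from timm's std_conv.py, not copied verbatim; `StdConv2dSketch` is a made-up name). `F.batch_norm` is applied to the reshaped convolution weight with `training=True` purely to standardize the weights, so the flag has nothing to do with the module's train/eval state and the exporter always sees a training-mode batch norm:

```python
import torch.nn as nn
import torch.nn.functional as F


class StdConv2dSketch(nn.Conv2d):
    """Weight-standardized conv, approximating timm's StdConv2d for illustration."""

    def __init__(self, in_ch, out_ch, kernel_size, eps=1e-6, **kwargs):
        super().__init__(in_ch, out_ch, kernel_size, **kwargs)
        self.eps = eps

    def forward(self, x):
        # training=True forces batch statistics to be computed over the weight
        # tensor itself; no running stats are tracked (running_mean/var are None).
        w = F.batch_norm(
            self.weight.reshape(1, self.out_channels, -1),
            None, None,
            training=True, momentum=0.0, eps=self.eps,
        ).reshape_as(self.weight)
        return self._conv_forward(x, w, self.bias)
```

Since the `training=True` here is intentional and independent of `model.eval()`, forcing it to False during tracing could change the computed weights, which is why fixing up the exported BatchNormalization nodes (when their extra outputs are unused) seems like the safer route.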