
Batchnorms force set to training mode on torch.onnx.export when running stats are None #75252

Closed
8scarlet8 opened this issue Apr 5, 2022 · 17 comments
Labels
module: onnx (Related to torch.onnx) · onnx-triaged (triaged by ONNX team) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@8scarlet8

🐛 Describe the bug

When converting a PyTorch model to .onnx, the exporter assumes that batchnorm layers are in training mode if track_running_stats=False, even though the layers clearly have their training attribute set to False. This can be reproduced by setting module.running_var = None and module.running_mean = None, or by creating a new model with nn.BatchNorm2d(channels[0], track_running_stats=False). Below is a basic conversion example with resnet50 where I forcibly set the running stats to None. If needed, I can provide an example where the model has batchnorms initialized with track_running_stats=False.

This causes the converter to wrongly assume that the layers are in training mode, which prevents further loading with the OpenVINO backend. The same thing happens when the model is first converted to TorchScript via tracing or scripting. This happened to me with PyTorch 1.11.0 in my local testing environment and with PyTorch 1.10.0 on Google Colab. If needed, I can reproduce it on Colab with PyTorch 1.11.0.

Is this intentional, i.e. should batchnorms always have running stats, or is this a bug?

Here is a conversion example where track_running_stats=True; conversion goes smoothly and the model loads with the OpenVINO backend.

import torch
import torch.nn as nn
import torchvision.models as models
from openvino.inference_engine import IECore

model = models.resnet50(pretrained=True)
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
input_names = ["actual_input"]
output_names = ["output"]
torch.onnx.export(
    model,
    dummy_input,
    "resnet50.onnx",
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    export_params=True,
)

ie = IECore()
net = ie.read_network('resnet50.onnx')
model = ie.load_network(network=net, device_name="CPU")

Here is a conversion example where the batchnorms have their running stats set to None (as happens with track_running_stats=False).

import torch
import torch.nn as nn
import torchvision.models as models
from openvino.inference_engine import IECore

model = models.resnet50(pretrained=True)
model.eval()

# Drop the running stats, mimicking track_running_stats=False.
for module in model.modules():
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        module.track_running_stats = False
        module.running_var = None
        module.running_mean = None

dummy_input = torch.randn(1, 3, 224, 224)
input_names = ["actual_input"]
output_names = ["output"]
torch.onnx.export(
    model,
    dummy_input,
    "resnet50.onnx",
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    export_params=True,
)

ie = IECore()
net = ie.read_network('resnet50.onnx')
model = ie.load_network(network=net, device_name="CPU")

During conversion it gives the following warning:

/usr/local/lib/python3.7/dist-packages/torch/onnx/symbolic_helper.py:773: UserWarning: ONNX export mode is set to inference mode, but operator batch_norm is set to training  mode. The operators will be exported in training , as specified by the functional operator.
  op_mode + ", as specified by the functional operator.")

And it throws an error when attempting to load with the OpenVINO backend:

RuntimeError: Check '(node.get_outputs_size() == 1)' failed at frontends/onnx/frontend/src/op/batch_norm.cpp:67:
While validating ONNX node '<Node(BatchNormalization): BatchNormalization_9>':
Training mode of BatchNormalization is not supported.

Versions

PyTorch version: 1.10.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26

Python version: 3.7.13 (default, Mar 16 2022, 17:37:17) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.10.0+cu111
[pip3] torchaudio==0.10.0+cu111
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.11.0
[pip3] torchvision==0.11.1+cu111
[pip3] openvino==2022.1.0
[conda] Could not collect

@bdhirsh bdhirsh added the module: onnx (Related to torch.onnx) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Apr 5, 2022
@BowenBao
Collaborator

BowenBao commented Apr 19, 2022

Thanks @8scarlet8; unfortunately I couldn't think of an easy fix. It turns out PyTorch considers batchnorm to be in training mode when both running stats are None:

# From torch/nn/modules/batchnorm.py, _BatchNorm.forward:
bn_training = (self.running_mean is None) and (self.running_var is None)

return F.batch_norm(
    input,
    # If buffers are not to be tracked, ensure that they won't be updated
    self.running_mean
    if not self.training or self.track_running_stats
    else None,
    self.running_var if not self.training or self.track_running_stats else None,
    self.weight,
    self.bias,
    bn_training,
    exponential_average_factor,
    self.eps,
)

So at the kernel level, what the ONNX exporter observes is a batch_norm call with training=True. There is no extra information that would differentiate it from real training:
- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
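
A quick way to see this from the user side is to trace a stat-less, eval-mode batchnorm and inspect the graph (a minimal sketch):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
bn.eval()               # module-level training flag is False
bn.running_mean = None  # drop the running stats, as in the repro above
bn.running_var = None

# The traced graph calls aten::batch_norm with a constant training=True,
# because bn_training = (running_mean is None) and (running_var is None).
traced = torch.jit.trace(bn, torch.randn(2, 4, 8, 8))
print(traced.graph)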

To quickly unblock the issue, one possible solution is to post-process the exported ONNX graph to remove the unused outputs from BatchNorm.
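
A minimal sketch of that post-processing with the onnx package (file names are placeholders; this assumes the extra BatchNorm outputs are not consumed elsewhere in the graph):

import onnx

model = onnx.load("resnet50.onnx")
for node in model.graph.node:
    if node.op_type == "BatchNormalization":
        # Keep only the first output (Y); the extra running-stat outputs are
        # what make backends such as OpenVINO treat the node as training-mode.
        del node.output[1:]
        # Drop the training_mode attribute if present (opset >= 14).
        for i, attr in enumerate(node.attribute):
            if attr.name == "training_mode":
                del node.attribute[i]
                break

onnx.checker.check_model(model)
onnx.save(model, "resnet50_inference.onnx")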

@8scarlet8
Author

Thank you for your reply, @BowenBao.

I've managed to solve the problem by initializing the model with all batchnorms in track_running_stats=True mode, then loading the model weights and converting to ONNX.
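
Roughly, it looks like this (a sketch; the checkpoint path is a placeholder, and strict=False tolerates checkpoints saved without the running-stats buffers):

import torch
import torchvision.models as models

# BatchNorm2d defaults to track_running_stats=True, so the
# running_mean/running_var buffers exist at export time.
model = models.resnet50()
state_dict = torch.load("checkpoint.pth")
model.load_state_dict(state_dict, strict=False)
model.eval()

torch.onnx.export(model, torch.randn(1, 3, 224, 224), "resnet50.onnx")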

Your solution seems plausible as well, for example in case we need to fix an already-converted model.

What bothers me is that this behaviour seems counter-intuitive. Please correct me if I'm wrong, but the existence of running-stats tensors has nothing to do with batchnorms being in training mode. The problem was also hard to debug, which may cause trouble for other PyTorch users. So if this behaviour is not intended, I suggest either of two possible solutions:

  1. Inform users that batchnorms are converted in training mode due to the absence of running-stats tensors when they try to convert in eval mode. A more informative warning could be thrown in addition to the one about conversion in training mode.
  2. If possible, fix the issue by initializing dummy running-stats tensors when converting in eval mode and such tensors are missing from the batchnorms (a rough sketch follows below). Or even fix the core issue of why the converter assumes batchnorm is in training mode.
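
For reference, a rough sketch of suggestion 2 (resnet50 stands in for the stat-less model; note the numerics change, since the exported graph will then normalize with mean 0 / variance 1 instead of the per-batch statistics a stat-less batchnorm uses):

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet50(pretrained=True)

# Simulate the problem: drop the running stats, as in the repro above.
for m in model.modules():
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.running_mean = None
        m.running_var = None

# Suggestion 2: re-attach dummy stats before an eval-mode export.
for m in model.modules():
    if isinstance(m, nn.modules.batchnorm._BatchNorm) and m.running_mean is None:
        m.track_running_stats = True
        m.running_mean = torch.zeros(m.num_features)  # plain assignment updates
        m.running_var = torch.ones(m.num_features)    # the registered buffers

model.eval()
torch.onnx.export(model, torch.randn(1, 3, 224, 224), "resnet50.onnx")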

@garymm garymm added the onnx-triaged triaged by ONNX team label May 4, 2022
@aweinmann

The same issue also applies to InstanceNorm.

@akiyamasho

+1 on @aweinmann's comment; it also happens with InstanceNorm:

torch/onnx/symbolic_helper.py:773: UserWarning: ONNX export mode is set to inference mode, but operator instance_norm is set to training  mode. The operators will be exported in training , as specified by the functional operator.

@youjin-c

youjin-c commented Aug 16, 2022

+1
I have the same issue, and I guess this is the reason my ONNX runtime (Lens Studio) cannot load the model.

UserWarning: ONNX export mode is set to inference mode, but operator instance_norm is set to training  mode. The operators will be exported in training , as specified by the functional operator.
  + ", as specified by the functional operator."

@titaiwangms
Collaborator

We’ve gone ahead and closed this issue because it is stale.
If you still believe this issue is relevant, please feel free to reopen the issue and we will triage it as necessary. Please specify in a comment any updated information you may have so that we can address it effectively. We encourage you to try the latest pytorch-preview (nightly) version to see if it has resolved the issue.

Thanks,
ONNX Converter team

@titaiwangms titaiwangms closed this as not planned Won't fix, can't repro, duplicate, stale Oct 24, 2022
@gedeon1310

I managed to remove this warning using this (inelegant) method:

Loop over the network modules and explicitly set training to False before exporting to ONNX:

for m in net.modules():
    if 'instancenorm' in m.__class__.__name__.lower():
        m.train(False)

This removed the warning when exporting to ONNX.

Hope this helps,

@lminer

lminer commented Jan 12, 2023

I'm still getting this warning. Can I ignore it?

@luisfmnunes

Although I'm getting the same warning for instance normalization, the output of the ONNX model is practically equivalent to the PyTorch model's for the same input.

[screenshot comparing PyTorch and ONNX model outputs]

@Bea07

Bea07 commented Feb 27, 2023

I've managed to solve the problem by initializing model with all batchnorms in track_running_stats=True mode, then loading model weights and converting to onnx.

@8scarlet8 could you please provide an example of the code? It should be done after the model.eval() call, right?

If possible - try to fix the issue by initializing dummy track_running_stats tensors when attempting to convert in eval mode and such tensors are not present in batch norms.

In particular, I would like to know how you managed to do that.

@Bea07

Bea07 commented Mar 10, 2023

To quickly unblock the issue, one possible solution is to post process the exported onnx graph to remove the unused outputs from BatchNorm.

@BowenBao could you suggest a way to do it?

@ThibaultGROUEIX

Hi folks :) What is the solution to this problem? @gedeon1310's hack did not work for me.

@Merealtea

Merealtea commented Aug 26, 2023

Hi everyone! Did anyone actually solve this problem? I added this code:

for m in net.modules():
    if 'instancenorm' in m.__class__.__name__.lower():
        m.train(False)

after I loaded the state_dict, but the warning still appeared when exporting to the ONNX model, as this fellow showed:

Although I'm getting the same warning for instance normalization, the output of the ONNX model is practically equivalent to the PyTorch model's for the same input.


Though it could successfully export to an ONNX model, another cuDNN error occurred when I converted it to a TensorRT model for downstream work, showing that the instance norm was still in training mode.

So if anyone has a better solution, please give us a hand with this problem. Thanks!

@michal-kierzynka

A workaround (as already mentioned by @8scarlet8) is to initialize the respective operator explicitly with track_running_stats=True (the default for InstanceNorm is False):

nn.InstanceNorm2d(..., track_running_stats=True, ...)

In this case the operator will track the running mean and variance, as described in the documentation: https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html
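
A minimal sketch (a toy net and a placeholder checkpoint path; note that in eval mode the layer will then normalize with the stored running stats rather than per-instance statistics, so those stats should come from training with tracking enabled):

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    # track_running_stats=True gives the layer running_mean/running_var
    # buffers, so the export no longer hits the training-mode path.
    nn.InstanceNorm2d(16, affine=True, track_running_stats=True),
    nn.ReLU(),
)
# strict=False tolerates checkpoints saved without the stat buffers.
net.load_state_dict(torch.load("checkpoint.pth"), strict=False)
net.eval()

torch.onnx.export(net, torch.randn(1, 3, 64, 64), "net.onnx")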

@M1keZulu

M1keZulu commented Sep 10, 2023

None of the workarounds are working for me. No matter what I do, it still exports in training mode and I have no clue what to do.

@315386775

How do I solve this?

ONNX export mode is set to inference mode, but operator instance_norm is set to training mode. The operators will be exported in training, as specified by the functional operator.

@Merealtea

Merealtea commented Sep 27, 2023 via email
