Closed
Labels
module: onnx: Related to torch.onnx
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Description
ONNX export fails for BatchNormalization with track_running_stats=False
Setting track_running_stats=False on a BatchNormalization layer causes ONNX export to fail.
To Reproduce
- Example code:
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.cov1 = nn.Conv2d(3, 32, 3)
        self.bn1 = nn.BatchNorm2d(32, track_running_stats=False)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        h = self.cov1(x)
        h = self.bn1(h)
        h = self.relu1(h)
        return h

model = Model()
model.eval()
x = torch.randn(1, 3, 64, 64, requires_grad=True)
torch_outs = model(x)
torch.onnx.export(model,
                  x,
                  "model.onnx",
                  export_params=True,
                  opset_version=12,
                  do_constant_folding=True,
                  example_outputs=torch_outs,
                  input_names=['input'],
                  output_names=['output'],
                  verbose=True)
```
- Output:

```
RuntimeError: Node (BatchNormalization_1)'s input 3 is marked single but has an empty string in the graph
==> Context: Bad node spec: input: "23" input: "bn1.weight" input: "bn1.bias" input: "" input: "" output: "26" name: "BatchNormalization_1" op_type: "BatchNormalization" attribute { name: "epsilon" f: 1e-05 type: FLOAT } attribute { name: "momentum" f: 0.9 type: FLOAT }
```
Expected behavior
Model converts without throwing an error.
Environment
- PyTorch Version (e.g., 1.0): 1.6
- OS (e.g., Linux): Linux Ubuntu 18.04
- How you installed PyTorch (conda, pip, source): pip
- Python version: 3.6.9
- CUDA/cuDNN version: 10.2
Additional context
The error happens only with track_running_stats=False; setting it to True removes the error.
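That matches the error message: with track_running_stats=False the module keeps no running statistics, so the exporter emits empty strings for inputs 3 and 4 (mean and var) of the ONNX BatchNormalization node. A minimal sketch of a possible workaround is below; this is an assumption on my part, not an official fix, and it changes semantics: the exported graph will normalize with the constant statistics you register rather than per-batch statistics.

```python
import torch
import torch.nn as nn

# With track_running_stats=False the module holds no running statistics,
# which is what leaves inputs 3 and 4 of the ONNX node empty:
bn = nn.BatchNorm2d(32, track_running_stats=False)
assert bn.running_mean is None and bn.running_var is None

# Hedged workaround: register constant statistics before export so the
# BatchNormalization node gets all five inputs. First drop the None
# placeholders (depending on the torch version they live in _parameters
# or _buffers), then register real buffers.
for name in ("running_mean", "running_var", "num_batches_tracked"):
    if hasattr(bn, name):
        delattr(bn, name)
bn.track_running_stats = True
bn.register_buffer("running_mean", torch.zeros(bn.num_features))
bn.register_buffer("running_var", torch.ones(bn.num_features))
bn.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))

assert bn.running_mean is not None and bn.running_var is not None
```

Note the caveat: in eval mode the module (and the exported graph) will now use these constants instead of the batch statistics that track_running_stats=False computes, so outputs can differ from the original model.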
cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof