
Cannot export saved ScriptModule models to ONNX format #14869

Closed
markrogersjr opened this issue Dec 7, 2018 · 11 comments
Labels
module: onnx Related to torch.onnx oncall: jit Add this issue/PR to JIT oncall triage queue

Comments

@markrogersjr

🐛 Bug

I am unable to load a PyTorch model as a ScriptModule from a file and then export to ONNX. However, I am able to export a ScriptModule in memory directly to ONNX.

To Reproduce

Steps to reproduce the behavior:

The script

import torch
import torchvision
m = torchvision.models.resnet50()
x = torch.rand((1, 3, 224, 224))
f = 'model.pt'
torch.jit.save(m, f)
m = torch.jit.load(f)
torch.onnx.export(m, x, 'model.onnx')

yields AssertionError: example_outputs must be provided when exporting a ScriptModule. Replacing the last line with

torch.onnx._export(m, x, 'model.onnx', example_outputs=torch.rand((1, 1000)))

gives me the following error:

Traceback (most recent call last):
  File "jit.py", line 25, in <module>
    torch.onnx._export(m, x, 'model.onnx', example_outputs=torch.rand((1,1000)))
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/__init__.py", line 22, in _export
    return utils._export(*args, **kwargs)
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/utils.py", line 281, in _export
    example_outputs, propagate)
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/utils.py", line 227, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type)
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/utils.py", line 155, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/__init__.py", line 52, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/utils.py", line 504, in _run_symbolic_function
    return fn(g, *inputs, **attrs)
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/symbolic.py", line 89, in wrapper
    return fn(g, *args)
  File "/home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/onnx/symbolic.py", line 736, in batch_norm
    input_sizes = input.type().sizes()
RuntimeError: r ASSERT FAILED at /home/ubuntu/pytorch/aten/src/ATen/core/jit_type.h:128, please report a bug to PyTorch. (expect at /home/ubuntu/pytorch/aten/src/ATen/core/jit_type.h:128)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6a (0x7fbce5084a8a in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: std::shared_ptr<c10::CompleteTensorType> c10::Type::expect<c10::CompleteTensorType>() + 0x2a8 (0x7fbcd48cdad8 in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #2: <unknown function> + 0x362dab (0x7fbcd48a9dab in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #3: <unknown function> + 0x115d8d (0x7fbcd465cd8d in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #17: torch::jit::BlockToONNX(torch::jit::Block*, torch::jit::Block*, torch::onnx::OperatorExportTypes, std::unordered_map<torch::jit::Value*, torch::jit::Value*, std::hash<torch::jit::Value*>, std::equal_to<torch::jit::Value*>, std::allocator<std::pair<torch::jit::Value* const, torch::jit::Value*> > >) + 0x6a1 (0x7fbcd4883171 in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #18: torch::jit::ToONNX(std::shared_ptr<torch::jit::Graph>&, torch::onnx::OperatorExportTypes) + 0x2e7 (0x7fbcd4884987 in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #19: <unknown function> + 0x3393e2 (0x7fbcd48803e2 in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #20: <unknown function> + 0x33949e (0x7fbcd488049e in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0x115d8d (0x7fbcd465cd8d in /home/ubuntu/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #44: __libc_start_main + 0xf0 (0x7fbceb388830 in /lib/x86_64-linux-gnu/libc.so.6)

Expected behavior

Successfully export to ONNX format

Environment

PyTorch version: 1.0.0a0+b5db6ac
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.12.2

Python version: 3.7
Is CUDA available: No
CUDA runtime version: 9.0.176
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/lib64/libcudnn.so.6.0.21
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/cuda-9.0/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.0/lib64/libcudnn_static.a
/usr/local/cuda-9.1/lib64/libcudnn.so.7.0.5
/usr/local/cuda-9.1/lib64/libcudnn_static.a
/usr/local/cuda-9.2/lib64/libcudnn.so.7.3.1
/usr/local/cuda-9.2/lib64/libcudnn_static.a

Versions of relevant libraries:
[pip] numpy (1.15.4)
[pip] torch (1.0.0a0+b5db6ac)
[pip] torchvision (0.2.1)
[conda] mkl 2019.1 144
[conda] mkl-include 2019.1 144
[conda] torch 1.0.0a0+b5db6ac
[conda] torchvision 0.2.1

Additional context

@pytorchbot pytorchbot added the module: onnx Related to torch.onnx label Dec 7, 2018
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Dec 7, 2018
@singhblom

I'm seeing a very similar error message, raised from a different place.

  File "/home/martinsb/anaconda3/envs/v2f/lib/python3.6/site-packages/torch/onnx/symbolic.py", line 513, in softmax
    if len(input.type().sizes()) != dim + 1:
RuntimeError: r ASSERT FAILED at /opt/conda/conda-bld/pytorch_1544174967633/work/aten/src/ATen/core/jit_type.h:128, please report a bug to PyTorch. (expect at /opt/conda/conda-bld/pytorch_1544174967633/work/aten/src/ATen/core/jit_type.h:128)

It seems like input.type().sizes() is the culprit.

To reproduce your version I needed to modify your script slightly.

import torch
import torchvision
m = torchvision.models.resnet50()
x = torch.rand((1, 3, 224, 224))
traced_m = torch.jit.trace(m, x)
f = 'model.pt'
torch.jit.save(traced_m, f)
loaded_m = torch.jit.load(f)
#torch.onnx.export(loaded_m, x, 'model.onnx')
torch.onnx._export(loaded_m, x, 'model.onnx', example_outputs=torch.rand((1, 1000)))

Environment

Tested both on

PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: TITAN V
Nvidia driver version: 390.77
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.5.1.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.20
/usr/lib/x86_64-linux-gnu/libcudnn_static_v5.a

Versions of relevant libraries:
[pip] Could not collect
[conda] cuda90                    1.0                  h6433d27_0    pytorch
[conda] pytorch                   1.0.0           py3.6_cuda9.0.176_cudnn7.4.1_1    pytorch
[conda] pytorch-wavenet           2018.1                    <pip>
[conda] torch                     0.4.1                     <pip>
[conda] torchvision               0.2.1                     <pip>
[conda] torchvision               0.2.1                      py_2    pytorch

and on

Collecting environment information...
PyTorch version: 1.0.0.dev20181216
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: TITAN V
Nvidia driver version: 390.77
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.5.1.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.20
/usr/lib/x86_64-linux-gnu/libcudnn_static_v5.a

Versions of relevant libraries:
[pip] Could not collect
[conda] blas                      1.0                         mkl
[conda] cuda90                    1.0                  h6433d27_0    pytorch
[conda] mkl                       2018.0.3                      1
[conda] mkl-service               1.1.2            py36h90e4bf4_5
[conda] mkl_fft                   1.0.6            py36h7dd41cf_0
[conda] mkl_random                1.0.1            py36h4414c95_1
[conda] pytorch-nightly           1.0.0.dev20181216 py3.6_cuda9.0.176_cudnn7.4.1_0    pytorch
[conda] pytorch-wavenet           2018.1                    <pip>
[conda] torchvision-nightly       0.2.1                     <pip>

@markrogersjr
Author

markrogersjr commented Feb 6, 2019

Bump. Can someone please take a look? The error still occurs in the latest version of PyTorch.

@markrogersjr
Author

markrogersjr commented Feb 18, 2019

Update: I searched through the nightly builds and found that the bug disappears around 11/29/2018. I suspect that the bug was introduced in this PR. @zdevito can you please investigate? It may be related to the serialization/deserialization in import.cpp and export.cpp.

@markrogersjr
Author

@singhblom Thanks for sharing. Yes, I also found an error thrown in torch/onnx/symbolic.py:_convolution(), which attempts to access weight.type().sizes().

@markrogersjr
Author

Tagging reviewers of PR: @suo @houseroad @dzhulgakov

@dzhulgakov
Collaborator

I think this is because we erase shape information when saving a ScriptModule. However, we should repopulate it when the ONNX export invokes tracing again. So it's a legitimate bug.

cc @jamesr66a

@markrogersjr
Author

Can I please get an ETA? @dzhulgakov @jamesr66a

@borisfom
Contributor

@dzhulgakov: any chance for this to get fixed?

@whitelok
Contributor

This issue has already been fixed in the latest commit, 30bc19d.

@bhosmer
Contributor

bhosmer commented Dec 13, 2019

No longer an assert, but still errors on recent master:

import torch
import torchvision
m = torchvision.models.resnet50()
x = torch.rand((1, 3, 224, 224))
traced_m = torch.jit.trace(m, x)
f = 'model.pt'
torch.jit.save(traced_m, f)
loaded_m = torch.jit.load(f)
#torch.onnx.export(loaded_m, x, 'model.onnx')
torch.onnx._export(loaded_m, x, 'model.onnx', example_outputs=torch.rand((1, 1000)))
Traceback (most recent call last):
  File "scratch.py", line 10, in <module>
    torch.onnx._export(loaded_m, x, 'model.onnx', example_outputs=torch.rand((1, 1000)))
  File "/Users/bhosmer/dev/pt/torch/onnx/__init__.py", line 26, in _export
    result = utils._export(*args, **kwargs)
  File "/Users/bhosmer/dev/pt/torch/onnx/utils.py", line 466, in _export
    fixed_batch_size=fixed_batch_size)
  File "/Users/bhosmer/dev/pt/torch/onnx/utils.py", line 336, in _model_to_graph
    fixed_batch_size=fixed_batch_size, params_dict=params_dict)
  File "/Users/bhosmer/dev/pt/torch/onnx/utils.py", line 152, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/Users/bhosmer/dev/pt/torch/onnx/__init__.py", line 187, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/Users/bhosmer/dev/pt/torch/onnx/utils.py", line 710, in _run_symbolic_function
    return op_fn(g, *inputs, **attrs)
  File "/Users/bhosmer/dev/pt/torch/onnx/symbolic_helper.py", line 129, in wrapper
    return fn(g, *args)
  File "/Users/bhosmer/dev/pt/torch/onnx/symbolic_opset9.py", line 1790, in flatten
    end_dim = dim + end_dim
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int' (occurred when translating flatten)

torch/onnx/symbolic_opset9.py, line 1790

@garymm
Collaborator

garymm commented Aug 25, 2021

Seems to work on PyTorch 1.9.0. Note that example_outputs is produced by actually invoking the model, so it's guaranteed to be a valid example.

import torch
print(torch.__version__)

import torchvision

m = torchvision.models.resnet50()
x = torch.rand((1, 3, 224, 224))
traced_m = torch.jit.trace(m, x)
f = 'model.pt'
torch.jit.save(traced_m, f)
loaded_m = torch.jit.load(f)
torch.onnx._export(loaded_m, x, 'model.onnx', example_outputs=loaded_m(x))

Produces output:

1.9.0+cu102
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py:46: UserWarning: You are exporting the model to ONNX while in training mode with 'train' parameter not specified. The model will default to inference mode export. If you wish to export a training amenable ONNX model, specify training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export().
  warnings.warn("You are exporting the model to ONNX while in training mode with "
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py:341: UserWarning: Model has no forward function
  warnings.warn("Model has no forward function")
/usr/local/lib/python3.7/dist-packages/torch/onnx/symbolic_helper.py:715: UserWarning: ONNX export mode is set to inference mode, but operator batch_norm is set to training  mode. The model will be exported in inference, as specified by the export mode.
  training_mode + ", as specified by the export mode.")

@garymm garymm closed this as completed Aug 25, 2021

9 participants