
[Bug or Requirement?] ONNX unpacked the tuple of parameters of the model #11456

Closed
thlinh opened this issue on Sep 10, 2018 · 2 comments
Labels: module: onnx (Related to torch.onnx)

Comments

thlinh commented on Sep 10, 2018

Issue description

Hi,
I have a problem when converting a PyTorch model to ONNX. My model expects a single input that is a tuple of tensors. But when running torch.onnx.export(), the tuple is unpacked and passed to the model as several separate tensors, which raises an error (since forward() expects only one parameter). A simplified version is shown below:

Code example

import torch
import torch.nn as nn

class sim_test(nn.Module):
    def __init__(self, args):
        super(sim_test, self).__init__()
        print(args)

    def forward(self, inp_tuple):
        x = inp_tuple[0]
        print(x.type())
        print(x.size())
        y = inp_tuple[1]
        z = inp_tuple[2]
        return x + y + z

def main():
    model = sim_test("Test me")
    inp_x = torch.FloatTensor(1, 5).uniform_()
    inp_y = torch.FloatTensor(1, 5).uniform_()
    inp_z = torch.FloatTensor(1, 5).uniform_()
    total_inp = (inp_x,inp_y,inp_z)
    result = model(total_inp)
    print ("Result is: ", result)

    # This call fails: export unpacks total_inp into three positional tensors
    torch.onnx.export(model, total_inp, "simple_test.onnx", verbose=True)

if __name__ == '__main__':
    main()

The result screen is:

torch.FloatTensor
torch.Size([1, 5])
Result is:  tensor([[1.2815, 1.8511, 0.9751, 1.1096, 1.4900]])
Traceback (most recent call last):
  File "simple_test.py", line 30, in <module>
    main()
  File "simple_test.py", line 27, in main
    torch.onnx.export(model, total_inp, "simple_test.onnx", verbose=True)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/onnx/__init__.py", line 26, in export
    return utils.export(*args, **kwargs)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
    operator_export_type=operator_export_type)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 227, in _export
    example_outputs, propagate)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 178, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 145, in _trace_and_get_graph_from_model
    trace, torch_out = torch.jit.get_trace_graph(model, args)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/jit/__init__.py", line 77, in get_trace_graph
    return LegacyTracedModule(f)(*args, **kwargs)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/jit/__init__.py", line 109, in forward
    out = self.inner(*trace_inputs)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 475, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/h00472437/.conda/envs/Py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 465, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

The problem can currently be worked around by changing the definition of forward() to:

def forward(self, *inp_tuple):

and call it as:

result = model(*total_inp)
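Putting those two changes together, here is a minimal self-contained sketch of that workaround (the class name SimTestUnpacked and the output file name are just illustrative):

import torch
import torch.nn as nn

class SimTestUnpacked(nn.Module):
    # Illustrative variant of sim_test: forward takes *inputs instead of a
    # single tuple, so the unpacking done by torch.onnx.export matches it.
    def forward(self, *inputs):
        x, y, z = inputs
        return x + y + z

model = SimTestUnpacked()
inputs = (torch.rand(1, 5), torch.rand(1, 5), torch.rand(1, 5))

print(model(*inputs))  # call with the tuple unpacked

# Export now succeeds: export unpacks `inputs` into three positional
# tensors, which is exactly what forward(*inputs) accepts.
torch.onnx.export(model, inputs, "simple_test_unpacked.onnx", verbose=True)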

Is it intended that a tuple cannot be passed directly as a single input to a model?
Thank you very much for your help! 😃

System Info

Collecting environment information...
PyTorch version: 0.4.1.post2
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.5.0-12ubuntu1~16.04) 5.5.0 20171010
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 7.5.17
GPU models and configuration: GPU 0: Tesla K40c
Nvidia driver version: 396.51
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.1.1
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn_static.a
/usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn.so.7.1.4
/usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip3] numpy (1.15.1)
[pip3] torch (0.4.1.post2)
[pip3] torchsummary (1.4)
[pip3] torchtext (0.2.3)
[pip3] torchvision (0.2.1)
[conda] pytorch                   0.4.1           py36_py35_py27__9.0.176_7.1.2_2    pytorch
[conda] torchsummary              1.4                       <pip>
[conda] torchtext                 0.2.3                     <pip>
[conda] torchvision               0.2.1                    py36_1    pytorch
soumith added the module: onnx (Related to torch.onnx) label on Sep 10, 2018
houseroad (Member) commented

Fix: pass total_inp wrapped in an extra tuple, since export unpacks the args tuple:
torch.onnx.export(model, (total_inp,), "simple_test.onnx", verbose=True)
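For clarity, a short sketch of what that wrapping does (the exact handling of nested tuples may vary across PyTorch versions):

# torch.onnx.export treats its second argument as the tuple of positional
# arguments for the model, i.e. the tracer effectively calls model(*args).
# Wrapping total_inp in an outer tuple therefore makes the unpacking yield
# the original tuple as the single argument that forward(self, inp_tuple)
# expects.
args = (total_inp,)              # args == ((inp_x, inp_y, inp_z),)
torch.onnx.export(model, args, "simple_test.onnx", verbose=True)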

thlinh (Author) commented on Sep 13, 2018

@houseroad: Is the args tuple always unpacked by torch.onnx.export when it calls the model?
