
Can't convert Upsample to onnx #18113

Closed
E1eMenta opened this issue Mar 17, 2019 · 41 comments
Assignees
Labels
high priority module: onnx Related to torch.onnx triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@E1eMenta

🐛 Bug

pytorch == 1.0.1.post2
onnx == 1.4.1
I'm trying to convert the 'upsample' op from PyTorch to ONNX.
Code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Test(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):

        return F.upsample(x, size=(x.shape[2]*2,x.shape[3]*2), mode='bilinear', align_corners=True)

model = Test()
x = torch.zeros((1, 3, 300, 300))
torch.onnx._export(model, x, "upsample.onnx", export_params=True, verbose=True)

Output error:

Traceback (most recent call last):
  File "/mnt/media2/renat/sync/research/detection/test.py", line 30, in <module>
    torch.onnx._export(model, x, "upsample.onnx", export_params=True, verbose=True)
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 22, in _export
    return utils._export(*args, **kwargs)
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 281, in _export
    example_outputs, propagate)
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 227, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type)
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 155, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 52, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 504, in _run_symbolic_function
    return fn(g, *inputs, **attrs)
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/symbolic.py", line 88, in wrapper
    args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/symbolic.py", line 88, in <listcomp>
    args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
  File "/home/renatkhiz/anaconda3/lib/python3.6/site-packages/torch/onnx/symbolic.py", line 45, in _parse_arg
    raise RuntimeError("ONNX symbolic expected a constant value in the trace")
RuntimeError: ONNX symbolic expected a constant value in the trace

I've tried different modes and nn.functional.interpolate; the result is the same. What is the problem?

@pytorchbot pytorchbot added the module: onnx Related to torch.onnx label Mar 17, 2019
@Protocal13

Either you're forgetting the " " somewhere, or did you forget to add *args **kwargs somewhere? I think you should take your time to look over this error output and carefully read your code to spot any mistakes.

@Protocal13

;)

@E1eMenta
Author

@Protocal13, thank you for the answer, but can you clarify what the problem is with the 'space' and *args **kwargs?

@soumith
Member

soumith commented Mar 18, 2019

@Protocal13 your answers on this issue and other issues are wrong or not helpful. Please only make comments if you think they are the right answer / direction.

@pepsin

pepsin commented Mar 19, 2019

I have this issue too in my ResNet-based project, and this bug has existed from 1.0 until now.

@yxchng

yxchng commented Mar 21, 2019

Any updates?

@libohit

libohit commented Mar 21, 2019

I face the same problem with upsample.

@Sean0123456789

I am trying to convert UNet to Caffe2 using ONNX and I am also facing the same problem as others. Will it be solved soon?

@ezyang ezyang added high priority triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Apr 7, 2019
@luoan

luoan commented Apr 10, 2019

@E1eMenta

I face the problem too.
Here are my test and results. It seems that, at present, we can only use specific numbers for the size parameter, not variable parameters, and align_corners=True is not supported yet:

import torch.nn as nn
import torch
import torch.nn.functional as F

class Test(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):

        #return F.upsample(x, size=(x.shape[2] * 2, x.shape[3] * 2), mode='bilinear', align_corners=True)
                                # RuntimeError: ONNX symbolic expected a constant value in the trace

        #return F.interpolate(x, size=(x.shape[2] * 2, x.shape[3] * 2), mode='bilinear', align_corners=True)
                                # RuntimeError: ONNX symbolic expected a constant value in the trace

        #return F.upsample(x, size=(600, 600), mode='bilinear', align_corners=False)
                                # UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.

        #return F.interpolate(x, size=(600, 600), mode='bilinear', align_corners=True)
                                # UserWarning: ONNX export failed on upsample_bilinear2d because align_corners == True not supported
                                # RuntimeError: ONNX export failed: Couldn't export operator aten::upsample_bilinear2d

        return F.interpolate(x, size=(600, 600), mode='bilinear', align_corners=False) #no warning, all clear

model = Test()
x = torch.zeros((1, 3, 300, 300))
torch.onnx._export(model, x, "test.onnx", verbose=True)
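
For what it's worth, here is a minimal sketch of two workarounds consistent with the findings above. Under tracing, x.shape[2] is a Tensor, so sizes computed from it are not constants in the trace; casting them to int bakes them in (valid only for a fixed input size), and scale_factor avoids Tensor-valued sizes entirely. Class names here are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TestIntCast(nn.Module):
    def forward(self, x):
        # Cast the Tensor-valued sizes to Python ints so the trace
        # records constants (only valid for a fixed input size)
        return F.interpolate(x, size=(int(x.shape[2]) * 2, int(x.shape[3]) * 2),
                             mode='bilinear', align_corners=False)

class TestScaleFactor(nn.Module):
    def forward(self, x):
        # Express the resize as a scale factor; 2 is a constant by construction
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

x = torch.zeros((1, 3, 300, 300))
torch.onnx._export(TestIntCast(), x, "upsample_int.onnx", verbose=True)
torch.onnx._export(TestScaleFactor(), x, "upsample_scale.onnx", verbose=True)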

@WenmuZhou

Me too.

@aidonchuk

Pls!!! Fix this! ASAP!!!

@daquexian
Contributor

I believe it has been fixed in 93d5503 four hours ago :D

@Aeroxander

The conversion from PyTorch to ONNX works fine; only when I go from ONNX to OpenVINO do I get this error:

[ ERROR ] Unexpected exception happened during extracting attributes for node 43.
Original exception message: One/both of widths_scale = None and height_scale = None is not defined for Upsample node 44.

Don't know if this has anything to do with it.

@ezyang
Contributor

ezyang commented Jun 18, 2019

cc @pk-g @houseroad re Aeroxander's latest comment

@houseroad
Member

@Aeroxander do you want to provide your onnx model? I think it may be due to a different opset version.

@Aeroxander

Aeroxander commented Jun 18, 2019

@houseroad sure! Here it is: decoder.onnx
Upsample is supported in OpenVINO, so it seems the scale parameters aren't converted into the ONNX model. The opset version is v9!

Edit: I was able to convert it from v9 to v8 with ONNX and now the "Upsample" problem is solved!

@jjhw

jjhw commented Jun 30, 2019

@Aeroxander How did you do the conversion to v8? When I tried with my model, using a Python script with ONNX 1.5.0, the line

converted_model = version_converter.convert_version(original_model, 8)

gives the following error:

adapt_upsample_9_8: Assertion false failed: Unsuppported conversion due to unavailable input: scale

@ftaralle

ftaralle commented Jul 2, 2019

@houseroad sure! Here it is: decoder.onnx
Upsample is supported in OpenVINO, so it seems the scale parameters aren't converted into the ONNX model. The opset version is v9!

Edit: I was able to convert it from v9 to v8 with ONNX and now the "Upsample" problem is solved!

Hi,
I am trying to use a UNet-like network.
It uses some "Upsample" operators and I get the same error you described.

[ ERROR ] Unexpected exception happened during extracting attributes for node 83.

Here is the code to export Pytorch to ONNX:

torch_out = torch.onnx._export(net, x, args.output, input_names=['input'], output_names=['output'], export_params=True, verbose=args.verbose)

Here is the verbose information of the onnx export:

%174 : Float(1, 1, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1]](%173, %109), scope: UNet3/Conv2d[f_layer]
%175 : Tensor = onnx::Constant[value= 1  1  2  2 [ Variable[CPUType]{4} ]](), scope: UNet3
%output : Float(1, 1, 512, 512) = onnx::Upsample[mode="nearest"](%174, %175), scope: UNet3

If I understand your solution properly, it consists of forcing v8 of the Upsample operator when exporting from PyTorch to ONNX. But how do you achieve this?

Thanks a lot

@Aeroxander

Aeroxander commented Jul 2, 2019

@ftaralle it should work as @jjhw described, but he gets an error. I wouldn't know why though, as scale should be supported; maybe it's under a different name?

@jjhw

jjhw commented Jul 2, 2019

@Aeroxander @ftaralle The problem is that PyTorch does not put the scale values in the Upsample layer. I have not tried to change the PyTorch code that generates the ONNX output; since I am using ONNX only as an intermediate stage to OpenVINO, I have hacked the OpenVINO code to set the scale values to 2.0. If you wanted to change the ONNX file instead, you could either rewrite the PyTorch exporter to add the scale values, or write a script that deserializes the file afterwards (ONNX is a protobuf), makes the corrections, and serializes it back out again.

@ftaralle

ftaralle commented Jul 4, 2019

Hi, following the suggestion of @jjhw I finally managed to get the Upsample nodes accepted by OpenVINO.
Here is the Python 3 code, in case someone needs it too:

import onnx
from onnx import version_converter, helper

# load model
original_model  = onnx.load(model_path)

# convert opset v9 to v8
converted_model = version_converter.convert_version(original_model, 8)

# change the attributes of all Upsample nodes
for node in converted_model.graph.node:
    if node.op_type == 'Upsample':
        # map attribute names to their indices
        attr_idx = {attribute.name: i for i, attribute in enumerate(node.attribute)}
        # get & remove the "scales" attribute
        att_scales = node.attribute.pop(attr_idx['scales'])
        _, _, scale_height, scale_width = att_scales.floats  # CARE: ORDER-DEPENDENT. [N, C, H, W] IS EXPECTED HERE
        # append new attributes 'width_scale' & 'height_scale'
        node.attribute.extend([
            helper.make_attribute('width_scale', scale_width),
            helper.make_attribute('height_scale', scale_height)
        ])

# save
onnx.save(converted_model, result_path)

Here are OpenVINO's error messages that I followed:

  • providing opset v9 --> Upsample scales attribute is defined for node 148. Only scale_width and scale_height are supported.
  • providing opset v8 --> One/both of widths_scale = None and height_scale = None is not defined for Upsampe node 157.
  • using widths_scale --> One/both of widths_scale = None and height_scale = 2.0 is not defined for Upsampe node 148.

Note the misspelling in the last error message :p widths_scale -> width_scale

@kzjeef

kzjeef commented Jul 18, 2019

Thanks @ftaralle, your python code works for me.

@guoguangchao

@ftaralle thanks for your code. When I run this code, I get this error:

converted_model = version_converter.convert_version(original_model, 8)

File "/usr/local/lib/python3.6/dist-packages/onnx/version_converter.py", line 166, in convert_version
converted_model_str = C.convert_version(model_str, target_version)
RuntimeError: /onnx/onnx/version_converter/adapters/upsample_9_8.h:78: adapt_upsample_9_8: Assertion false failed: Unsuppported conversion due to unavailable input: scale

I don't know how to fix it. Can you give me some advice? Thanks

@ftaralle

ftaralle commented Sep 30, 2019

Hi @guoguangchao
Could you provide a print of your original_model.graph (for an upsample layer)?
My guess is that a "scale" attribute is missing for an 'upsample' layer.
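
For instance, a quick way to dump the Upsample nodes from the exported model (the path is a placeholder):

import onnx

model = onnx.load("model.onnx")
for node in model.graph.node:
    if node.op_type == 'Upsample':
        print(node)  # shows the node's inputs, outputs and attributes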

@guoguangchao

guoguangchao commented Sep 30, 2019

@ftaralle thanks for your reply. In the original_model.graph, one of the upsample layers is as follows:

node {
  input: "905"
  input: "921"
  output: "922"
  op_type: "Upsample"
  attribute {
    name: "mode"
    s: "nearest"
    type: STRING
  }
}

I used the FPN structure in the model; there is an F.interpolate call in it. The definition is as follows:

import torch.nn as nn
import torch.nn.functional as F

# conv_bn1X1 / conv_bn are Conv2d+BatchNorm2d+ReLU helpers defined elsewhere in the project
class FPN(nn.Module):
    def __init__(self,in_channels_list,out_channels):
        super(FPN,self).__init__()
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1)
        self.merge1 = conv_bn(out_channels, out_channels)
        self.merge2 = conv_bn(out_channels, out_channels)
    def forward(self, input):
        names = list(input.keys())
        input = list(input.values())
        output1 = self.output1(input[0])
        output2 = self.output2(input[1])
        output3 = self.output3(input[2])
        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
        output2 = output2 + up3
        output2 = self.merge2(output2)
        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
        output1 = output1 + up2
        output1 = self.merge1(output1)
        out = [output1, output2, output3]
        return out

Printing of FPN:

(fpn): FPN(
    (output1): Sequential(
      (0): Conv2d(40, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (output2): Sequential(
      (0): Conv2d(160, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (output3): Sequential(
      (0): Conv2d(960, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (merge1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (merge2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )

@ftaralle

ftaralle commented Oct 1, 2019

So I guess it is about your up3 & up2:

  • up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
  • up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")

You are using a fixed-size resizing operation, so indeed, there is no scaling factor here.
In your graph's node, you have 2 inputs, @905 & @921.
I'm not totally sure, but I guess that one is the input (the result of the previous layer) and the other is a vector containing the target shape.

I'm not sure, but I think that resizing to a fixed size is not supported (yet).
I remember losing part of my mind on such strange behaviors...

If I understand properly, you are processing 3 inputs in parallel and then merging them.
It is not the solution you came for, but would it be possible to reshape them before feeding them to the network? That way you would avoid the reshaping inside the network.

;p

@guoguangchao

@ftaralle Thanks for your advice.

@houseroad
Member

Does the problem still exist in master?

cc: @lara-hdr @spandantiwari

@lara-hdr
Contributor

lara-hdr commented Oct 2, 2019

@guoguangchao, the Upsample node you copied in your comment seems correct for opset 9.
The first input (905) is your 'output3' and the second one (921) is the scales.

If I understand correctly, you are trying to convert the model from opset 9 to opset 8?
Since PyTorch 1.2, you can export your PyTorch models directly to opset 8 without going through the converter (by adding the arg opset_version=8 in the export API).
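
For example (a minimal sketch; net and x stand for your model and an example input):

import torch

# Since PyTorch 1.2, export directly to opset 8 and skip the
# ONNX version converter entirely
torch.onnx.export(net, x, "model_v8.onnx", opset_version=8)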

@guoguangchao

@lara-hdr Thanks for your reply. My PyTorch version is 1.2.0. I try to export the model in the following way:
torch_out = torch.onnx.export(net, x, "detect.onnx", opset_version=8)
I got this error:

/torch/onnx/symbolic_helper.py:185: UserWarning: ONNX export failed on upsample_nearest2d because torch._C.Value (output_size) indexing not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
Traceback (most recent call last):
  File "export_to_onnx.py", line 77, in <module>
    torch_out = torch.onnx.export(net, x, "detect.onnx", opset_version=8)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 132, in export
    strip_doc_string, dynamic_axes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 64, in export
    example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 340, in _export
    params_dict, opset_version, dynamic_axes, defer_weight_export, operator_export_type, strip_doc_string)
RuntimeError: ONNX export failed: Couldn't export operator aten::upsample_nearest2d

@kealennieh

kealennieh commented Nov 6, 2019

I have the same problem when I export F.interpolate. My torch version is 1.2.0. When I set opset_version to 9, everything is fine. However, the problem happens when I set opset_version to 7.

Any suggestions?

@spandantiwari

@kealennieh - the ONNX version of the operator that supports F.interpolate, onnx::Resize, has undergone significant changes since opset 7. Not all scenarios of F.interpolate can be supported in the opset 7 version, which is one reason why the op was upgraded in onnx in subsequent versions. My suggestion would be to consider using opset 9 (or even higher) with the latest PyTorch 1.3 (or even the nightly build if possible). Is there any reason you cannot use opset 9?

@kealennieh

kealennieh commented Nov 7, 2019

@spandantiwari Thanks for your suggestion. The reason is that my current TensorRT can only support opset 7.
By the way, I've just found a trick to solve the problem a few minutes ago.

up = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")

to

up = F.interpolate(output2, size=[int(output1.size(2)), int(output1.size(3))], mode="nearest")

Then F.interpolate can be exported correctly in opset 7. (The int() casts turn the Tensor-valued sizes into Python ints, so they are recorded as constants in the trace.)

@voqtuyen

@houseroad sure! Here it is: decoder.onnx
Upsample is supported in OpenVINO, so it seems the scale parameters aren't converted into the ONNX model. The opset version is v9!

Edit: I was able to convert it from v9 to v8 with ONNX and now the "Upsample" problem is solved!

Could you show me how to convert it from v9 to v8? I have the same problem.

@Aeroxander

@voqtuyen It's quite easy, just do it with the ONNX Python API: https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md#converting-version-of-an-onnx-model-within-default-domain-aionnx
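
A minimal sketch following that doc (file names are placeholders):

import onnx
from onnx import version_converter

model = onnx.load("decoder.onnx")                        # opset 9 model
converted = version_converter.convert_version(model, 8)  # Upsample v9 -> v8
onnx.save(converted, "decoder_v8.onnx")

Note that, as @jjhw and @guoguangchao saw above, this conversion asserts when the scales input of an Upsample node is not available as a constant.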

@nieyan

nieyan commented Apr 3, 2020

ENV:
onnx 1.6.0
onnxruntime 0.5.0
pytorch 1.1.0

I can convert the model to ONNX, but the outputs of PyTorch and ONNX do not match.
code:

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

import onnx
import onnxruntime

class Test(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return F.interpolate(x, size=(400, 600), mode='bilinear', align_corners=False) #no warning, all clear

model = Test()
x = torch.rand((1, 3, 200, 300))
torch.onnx._export(model, x, "test.onnx", verbose=True)

model.eval()
with torch.no_grad():
    torch_out = model(x)

ort_session = onnxruntime.InferenceSession("test.onnx")
ort_input = {ort_session.get_inputs()[0].name: x.cpu().numpy()}
ort_out = ort_session.run(None, ort_input)[0]
np.testing.assert_allclose(torch_out.cpu().numpy(), ort_out, rtol=1e-03, atol=1e-05)

output

graph(%x : Float(1, 3, 200, 300)):
  %1 : Tensor = onnx::Constant[value= 1  1  2  2 [ Variable[CPUType]{4} ]](), scope: Test
  %2 : Float(1, 3, 400, 600) = onnx::Upsample[mode="linear"](%x, %1), scope: Test
  return (%2)

Traceback (most recent call last):
  File "git_test.py", line 29, in <module>
    np.testing.assert_allclose(torch_out.cpu().numpy(), ort_out, rtol=1e-03, atol=1e-05)
  File "/opt/conda/envs/pytorch1.1.0/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1501, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/opt/conda/envs/pytorch1.1.0/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 827, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatch: 99.7%
Max absolute difference: 0.42252138
Max relative difference: 15440.802
 x: array([[[[8.301054e-01, 8.599483e-01, 9.196341e-01, ..., 9.512389e-01,
          8.782693e-01, 8.417845e-01],
         [7.183217e-01, 7.415588e-01, 7.880329e-01, ..., 8.880928e-01,...
 y: array([[[[8.301054e-01, 8.897913e-01, 9.494771e-01, ..., 9.147540e-01,
          8.417845e-01, 8.417845e-01],
         [6.065379e-01, 6.398005e-01, 6.730631e-01, ..., 7.911202e-01,...

Any suggestions?

@blueardour

@nieyan I met a similar problem; have you made any progress on it?

@nieyan

nieyan commented Apr 13, 2020

@blueardour

I think converting to ONNX opset version 11 gives the correct result. But I do need an ONNX file with opset version 9 for the next step, which is converting to an embedded device's special model format.
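
For what it's worth, opset 11 replaced Upsample with Resize, which carries an explicit coordinate_transformation_mode attribute, so PyTorch's align_corners semantics can be represented exactly. A minimal sketch, assuming PyTorch >= 1.3 and a recent onnxruntime, reusing the model and x from the test script above:

# Re-export the test model with opset 11; the numerical comparison
# against onnxruntime should then pass
torch.onnx.export(model, x, "test11.onnx", opset_version=11, verbose=True)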

@blueardour

hi, @nieyan

Thanks for pointing that out.
After I add opset_version=11 in the torch.onnx._export call of your sample code above, I get the following error. Could you share how you got the test to pass?
My pytorch: 1.4.0 onnx: 1.6.0

Traceback (most recent call last):
  File "onnx-test.py", line 40, in <module>
    ort_session = onnxruntime.InferenceSession("test.onnx")
  File "/home/linux/.pyenv/versions/3.6.8/lib/python3.6/site-packages/onnxruntime/capi/session.py", line 29, in __init__
    self._sess.load_model(path_or_bytes)
RuntimeError: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from test.onnx failed:Node: Node () has input size 4 not in range [min=2, max=2]

@zdaiot

zdaiot commented Jun 7, 2020

@blueardour try to upgrade your onnx.
@nieyan I have tried using opset version 11 but also get some mismatches. I also need an ONNX file with opset version 9 for the next step. Have you solved it?

@garymm
Collaborator

garymm commented Nov 19, 2021

The original model reported by @E1eMenta can be exported with opset_version=11 or higher.
Looking at the ONNX 1.6.0 release notes, there were changes to the Resize operator that I think are being used to do the conversion.

If someone still needs Upsample to be exported specifically for opset version 9, please open a new issue and please note what is going to consume the ONNX model so that we can prioritize the issue.

@garymm garymm closed this as completed Nov 19, 2021