
upsample_bilinear2d issue when exporting to onnx #22906

Open
bpinaya opened this issue Jul 16, 2019 · 25 comments
@bpinaya commented Jul 16, 2019

🐛 Bug

This issue is related to #20116 and #10942 and concerns upsample_bilinear2d.

To Reproduce

Steps to reproduce the behavior:

  1. This snippet can be used to reproduce the behavior:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        return x

torch_model = TestModel()
dummy_input = torch.randn(1, 3, 256, 256)

torch_out = torch.onnx.export(torch_model, dummy_input, 'test_model.onnx', verbose=True)

Expected behavior

That snippet works under torch.__version__ 1.0.1.post2 (installed in a conda env); there it is able to export to ONNX, since the output is:

graph(%input : Float(1, 3, 256, 256)) {
  %1 : Tensor = onnx::Constant[value= 1  1  2  2 [ CPUFloatType{4} ]](), scope: TestModel
  %2 : Float(1, 3, 512, 512) = onnx::Upsample[mode="linear"](%input, %1), scope: TestModel
  return (%2);
}

Under torch.__version__ : 1.1.0 the error would be this one:

Traceback (most recent call last):
  File "test.py", line 16, in <module>
    torch_out = torch.onnx.export(torch_model, dummy_input, 'test_model.onnx', verbose=True)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/__init__.py", line 25, in export
    return utils.export(*args, **kwargs)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/utils.py", line 131, in export
    strip_doc_string=strip_doc_string)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/utils.py", line 363, in _export
    _retain_param_name, do_constant_folding)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/utils.py", line 278, in _model_to_graph
    _disable_torch_constant_prop=_disable_torch_constant_prop)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/utils.py", line 188, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/__init__.py", line 50, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/utils.py", line 589, in _run_symbolic_function
    return fn(g, *inputs, **attrs)
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/symbolic.py", line 130, in wrapper
    args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/symbolic.py", line 130, in <listcomp>
    args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
  File "/home/bpinaya/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/onnx/symbolic.py", line 90, in _parse_arg
    raise RuntimeError("Failed to export an ONNX attribute, "
RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible

And in a fresh conda env with PyTorch installed from master at commit 3b1c3996e1c82ca8f43af9efa196b33e36efee37 (torch.__version__: 1.2.0a0+3b1c399), the error for upsample_bilinear2d is:

/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/utils.py:662: UserWarning: ONNX export failed on ATen operator upsample_bilinear2d because torch.onnx.symbolic_opset9.upsample_bil
inear2d does not exist
  .format(op_name, opset_version, op_name))
Traceback (most recent call last):
  File "convert_to_onnx.py", line 36, in <module>
    output_names=output_layer_names))
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/__init__.py", line 32, in export
    return utils.export(*args, **kwargs)
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/utils.py", line 170, in export
    example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/utils.py", line 429, in _export
    _retain_param_name, do_constant_folding)
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/utils.py", line 330, in _model_to_graph
    _disable_torch_constant_prop=_disable_torch_constant_prop)
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/utils.py", line 232, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/__init__.py", line 57, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/utils.py", line 663, in _run_symbolic_function
    op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
  File "/home/bpinaya/anaconda3/envs/pytorch-edge/lib/python3.7/site-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
    return _registry[(domain, version)][opname]
KeyError: 'upsample_bilinear2d'

According to #20116, this error should already be fixed on master.

Environment

The output of collect_env:
Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 418.56
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.5.0

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] numpydoc==0.8.0
[pip] torch==1.1.0
[pip] torchvision==0.3.0
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] torchvision 0.3.0 py37_cu10.0.130_1 pytorch

Bear in mind that I tried on pytorch 1.0.1, 1.1.0 and master.

Additional context

If I were to use something like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()

    def forward(self, x):
        x = F.interpolate(x, x.size()[2:], mode='bilinear', align_corners=False)
        return x

torch_model = TestModel()
dummy_input = torch.randn(1, 3, 256, 256)

torch_out = torch.onnx.export(torch_model, dummy_input, 'test_model.onnx', verbose=True)

that is, passing a size taken from the input rather than a scale factor, the issue still remains, and now it also appears on version 1.0.1 as well.
Any insights are appreciated, thanks for the hard work!

@dvd42 commented Jul 17, 2019

Same issue here on pytorch '1.2.0.dev20190716', installed using pip.

@murthy95 commented Jul 18, 2019

I face the same issue in version 1.1.0 while passing size instead of scale_factor, as mentioned by @bpinaya under additional context. Is it fixed?

@bpinaya (Author) commented Jul 18, 2019

@murthy95 no updates yet :/

@GeoffreyChen777 commented Jul 21, 2019

same here.

@bhack commented Jul 21, 2019

Is this more generally related to ONNX dynamic shape at onnx/onnx#654?

@sraczynski-babblelabs commented Aug 1, 2019

Waiting for this one too.

@hackeritchy (Contributor) commented Aug 3, 2019

Same issue here on pytorch '1.2.0.dev20190802', installed using pip.

@vimalthilak commented Aug 5, 2019

One of the issues here is that the definition for upsample_bilinear2d is missing. I was able to work around that by adding the following lines to symbolic_opset9.py

upsample_bilinear1d = _interpolate('upsample_bilinear1d', 3, "linear")
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear")
upsample_bilinear3d = _interpolate('upsample_bilinear3d', 5, "linear")

right below where upsample_nearest* are defined. With the above change, I can now trace my model without encountering the error mentioned above.

@pk-g, @houseroad, @zou3519, or other PyTorch devs: is there a reason why the upsample_bilinear* definitions are missing? Thanks
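For anyone who would rather not edit the installed symbolic_opset9.py under site-packages, the same workaround can be applied as a runtime monkey-patch before calling torch.onnx.export. This is only a sketch: it assumes the private helper `_interpolate` exists in your torch build (it does in the 1.1/1.2-era releases discussed here), and private helpers can move or change between versions.

```python
# Sketch: runtime equivalent of editing symbolic_opset9.py directly.
# Assumes torch's private _interpolate helper is present in this build;
# it is an internal API and may differ in other versions.
import torch.onnx.symbolic_opset9 as opset9

opset9.upsample_bilinear1d = opset9._interpolate('upsample_bilinear1d', 3, "linear")
opset9.upsample_bilinear2d = opset9._interpolate('upsample_bilinear2d', 4, "linear")
opset9.upsample_bilinear3d = opset9._interpolate('upsample_bilinear3d', 5, "linear")
```

Run this once at import time, before tracing; torch.onnx.export then finds the registered symbolic for upsample_bilinear2d.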

@nithinsubbiah commented Aug 5, 2019

@vimalthilak That worked, thanks a lot

@lxgyChen commented Aug 13, 2019

@vimalthilak I put these three lines into symbolic_opset9.py but got another error:
NameError: name '_interpolate' is not defined. My PyTorch version is 1.2.0.

@vimalthilak commented Aug 13, 2019

@vimalthilak I put these three lines to symbolic_opset9.py but got another error:
NameError: name '_interpolate' is not defined, my PyTorch version is 1.2.0

I see a function called _interpolate defined in symbolic_opset9.py at around line 719. I checked both my macOS and my Ubuntu installs and was able to spot that function in each. If you really feel up for it, you can build PyTorch from source after adding the lines from my previous note. _interpolate is available in this file:

def _interpolate(name, dim, interpolate_mode):

@lxgyChen commented Aug 13, 2019

It works, thanks a lot!

@suruoxi commented Aug 17, 2019

@vimalthilak Thanks, it works.
However, it seems align_corners=True is not supported, as reported in #10446.

@KleinXin commented Sep 5, 2019

Changing the upsampling mode to 'nearest' can work around this problem.
'bilinear' interpolation will be supported in a newer version, as mentioned here.

@guissart commented Sep 24, 2019

Hello,
@vimalthilak thanks for your solution.
It appears that it doesn't produce the same output, though: ONNX has a different align_corners=False interpolation implementation than PyTorch.
Here is a piece of code to compare:

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime

class TestModel(nn.Module):
    def __init__(self, align=False):
        super(TestModel, self).__init__()
        self.align=align
    def forward(self, x):
        x = F.interpolate(x, (4,4), mode='bilinear', align_corners=self.align)
        return x

x = torch.tensor([[0.,1.],[2.,3.]]).view([1,1,2,2])
model = TestModel(align=True)
out = model(x)

print("matrix to be interpolated :")
print(x)
print()
print( "pytorch align=True output")
print(out)
print()
model = TestModel(align=False)
out = model(x)
print( "pytorch align=False output")
print(out)
print()

torch.onnx.export(model, x, "test_model.onnx")

ort_session = onnxruntime.InferenceSession("test_model.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
print("ONNX align=False output")
print(ort_outs)

and the output :

matrix to be interpolated :
tensor([[[[0., 1.],
          [2., 3.]]]])

pytorch align=True output
tensor([[[[0.0000, 0.3333, 0.6667, 1.0000],
          [0.6667, 1.0000, 1.3333, 1.6667],
          [1.3333, 1.6667, 2.0000, 2.3333],
          [2.0000, 2.3333, 2.6667, 3.0000]]]])

pytorch align=False output
tensor([[[[0.0000, 0.2500, 0.7500, 1.0000],
          [0.5000, 0.7500, 1.2500, 1.5000],
          [1.5000, 1.7500, 2.2500, 2.5000],
          [2.0000, 2.2500, 2.7500, 3.0000]]]])

ONNX align=False output
[array([[[[0. , 0.5, 1. , 1. ],
         [1. , 1.5, 2. , 2. ],
         [2. , 2.5, 3. , 3. ],
         [2. , 2.5, 3. , 3. ]]]], dtype=float32)]

Those are very different behaviours...
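The mismatch comes down to the coordinate transformation each implementation uses to map an output pixel index back to a (fractional) source coordinate. A summary, assuming the asymmetric mapping that pre-opset-11 ONNX Upsample/Resize used:

```
align_corners=True  (PyTorch):   src = dst * (in - 1) / (out - 1)
align_corners=False (PyTorch):   src = (dst + 0.5) * in / out - 0.5   # half-pixel
ONNX Upsample, opset <= 10:      src = dst * in / out                 # asymmetric
```

Plugging dst = 0..3, in = 2, out = 4 into the last formula gives source coordinates 0, 0.5, 1, 1.5 (clamped to 1), which reproduces the 0, 0.5, 1, 1 row of the ONNX output above. Opset 11's Resize added the coordinate_transformation_mode attribute, which lets the exporter select the half-pixel (or align_corners) mapping and match PyTorch.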

@dashesy (Contributor) commented Oct 8, 2019

@guissart I use this to patch bilinear2d, but unfortunately (as you mentioned) I also get the same mismatched results.

def patch_interpolate_opset10():
    """Patch interpolate in opset10
    """
    import torch.onnx.symbolic_helper as sym_help
    import torch.onnx.symbolic_opset10

    # noinspection PyProtectedMember
    def _interpolate_size_to_scales(g, input, output_size, dim):
        output_size = sym_help._maybe_get_const(output_size, 'is')
        if sym_help._is_value(output_size):
            offset = 2
            offsets = g.op("Constant", value_t=torch.ones(offset))
            dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
            divisor = sym_help._slice_helper(g, g.op("Shape", input), axes=[0], ends=[dim], starts=[offset])
            divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
            scale_dims = g.op("Div", dividend, divisor)
            scales = g.op("Concat", offsets, scale_dims, axis_i=0)
        else:
            scales_constant = [1. if i < 2 else
                               float(output_size[-(dim - i)]) / float(input.type().sizes()[-(dim - i)])
                               for i in range(0, dim)]
            scales = g.op("Constant", value_t=torch.tensor(scales_constant))
        return scales

    # noinspection PyProtectedMember
    def _interpolate(name, dim, interpolate_mode):
        # noinspection PyShadowingBuiltins, PyProtectedMember
        def symbolic_fn(g, input, output_size, align_corners=None):
            align_corners = sym_help._maybe_get_scalar(align_corners)
            if align_corners:
                return sym_help._unimplemented(name, "align_corners == True")
            scales = _interpolate_size_to_scales(g, input, output_size, dim)
            return g.op("Resize", input, scales, mode_s=interpolate_mode)
        return symbolic_fn

    torch.onnx.symbolic_opset10.upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
    torch.onnx.symbolic_opset10.upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
    torch.onnx.symbolic_opset10.upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
    torch.onnx.symbolic_opset10.upsample_linear1d = _interpolate('upsample_linear1d', 3, "linear")
    torch.onnx.symbolic_opset10.upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear")
    torch.onnx.symbolic_opset10.upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, "linear")
And the output:

matrix to be interpolated :
tensor([[[[0., 1.],
          [2., 3.]]]])

pytorch align=True output
tensor([[[[0.0000, 0.2500, 0.7500, 1.0000],
          [0.5000, 0.7500, 1.2500, 1.5000],
          [1.5000, 1.7500, 2.2500, 2.5000],
          [2.0000, 2.2500, 2.7500, 3.0000]]]])

pytorch align=False output
tensor([[[[0.0000, 0.2500, 0.7500, 1.0000],
          [0.5000, 0.7500, 1.2500, 1.5000],
          [1.5000, 1.7500, 2.2500, 2.5000],
          [2.0000, 2.2500, 2.7500, 3.0000]]]])

ONNX align=False output
[array([[[[0. , 0.5, 1. , 1. ],
         [1. , 1.5, 2. , 2. ],
         [2. , 2.5, 3. , 3. ],
         [2. , 2.5, 3. , 3. ]]]], dtype=float32)]

Did you find a workaround?

P.S. I am running nightly PyTorch and noticed that align=True and align=False now produce the same results there.

@guissart commented Oct 9, 2019

@dashesy I didn't find anything. I need it for pyramid pooling, and it strongly affects the behaviour of my model; I can't use ONNX in production for the moment because of that. A fix could be to upsample to (N+1, N+1) and then get rid of the +1; that may produce the same result as align_corners.

@dashesy (Contributor) commented Oct 9, 2019

@guissart I use this simple tensor-based function for bilinear interpolation, and it can be converted to ONNX too. It is easy to change as needed. One less reason to have a special Resize op in ONNX, unless its perf is better than a matrix multiply with vector broadcasts.
If you do manage to change it for your use case with align corners, it would be nice if you shared it :-)
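The exact function isn't posted in the thread, but a minimal pure-Python sketch of bilinear interpolation supporting both align_corners semantics looks like this; the results can be checked against the PyTorch outputs guissart printed above. The function name and structure are illustrative, not dashesy's actual code.

```python
import math

def bilinear_resize(img, out_h, out_w, align_corners=False):
    """Bilinear interpolation on a 2-D list of floats.

    Mirrors F.interpolate(..., mode='bilinear') for both align_corners
    settings, so results can be compared against the values printed
    earlier in this thread. Illustrative sketch only.
    """
    in_h, in_w = len(img), len(img[0])

    def src(dst, in_size, out_size):
        # Map an output index back to a fractional source coordinate.
        if align_corners:
            return 0.0 if out_size == 1 else dst * (in_size - 1) / (out_size - 1)
        return (dst + 0.5) * in_size / out_size - 0.5  # half-pixel mapping

    out = []
    for i in range(out_h):
        y = src(i, in_h, out_h)
        y0 = min(in_h - 1, max(0, int(math.floor(y))))
        y1 = min(in_h - 1, y0 + 1)
        wy = min(1.0, max(0.0, y - y0))   # clamp weight at the borders
        row = []
        for j in range(out_w):
            x = src(j, in_w, out_w)
            x0 = min(in_w - 1, max(0, int(math.floor(x))))
            x1 = min(in_w - 1, x0 + 1)
            wx = min(1.0, max(0.0, x - x0))
            top = img[y0][x0] * (1 - wx) + img[y0][x1] * wx
            bot = img[y1][x0] * (1 - wx) + img[y1][x1] * wx
            row.append(top * (1 - wy) + bot * wy)
        out.append(row)
    return out
```

For the 2x2 input [[0, 1], [2, 3]] resized to 4x4, align_corners=True gives a first row of 0, 1/3, 2/3, 1 and align_corners=False gives 0, 0.25, 0.75, 1, matching the PyTorch outputs shown earlier in the thread.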

@glenn-jocher commented Dec 3, 2019

I cannot export from PyTorch to ONNX using the upsample operator with torch version 1.3.1 and onnx version 1.5.0.

MINIMUM CODE TO REPRODUCE:

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx

print(torch.__version__)
print(onnx.__version__)


class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return x


torch_model = TestModel()
dummy_input = torch.randn(1, 3, 256, 256)

torch_out = torch.onnx.export(torch_model, dummy_input, 'model.onnx', verbose=True, opset_version=11)

onnx_model = onnx.load('model.onnx')
print(onnx_model)
onnx.checker.check_model(onnx_model)

Produces the following error:

Traceback (most recent call last):
  File "/Users/glennjocher/PycharmProjects/iD/upsample_error_reproduce.py", line 26, in <module>
    onnx.checker.check_model(onnx_model)
  File "/Users/glennjocher/.conda/envs/yolov3/lib/python3.7/site-packages/onnx/checker.py", line 86, in check_model
    C.check_model(model.SerializeToString())
onnx.onnx_cpp2py_export.checker.ValidationError: Node () has input size 4 not in range [min=2, max=2].

==> Context: Bad node spec: input: "input" input: "23" input: "23" input: "22" output: "24" op_type: "Resize" attribute { name: "coordinate_transformation_mode" s: "asymmetric" type: STRING } attribute { name: "cubic_coeff_a" f: -0.75 type: FLOAT } attribute { name: "mode" s: "nearest" type: STRING } attribute { name: "nearest_mode" s: "floor" type: STRING }

@tkclimb commented Dec 13, 2019

I found something different, but it might be related to this issue.
I put the error below. Is it a bug or the expected behavior, i.e. is the specific combination of upsample_bilinear2d and align_corners=True simply not convertible to ONNX format?

I used torch 1.3.1 and onnx 1.6.0.

/home/linuxbrew/.linuxbrew/opt/pyenv/versions/pytorch/lib/python3.7/site-packages/torch/onnx/symbolic_helper.py:198: UserWarning: You are trying to export the model with onnx:Upsample for ONNX opset version 9. This operator might cause results to not match the expected results by PyTorch.
ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).
We recommend using opset 11 and above for models using this operator. 
  "" + str(_export_onnx_opset_version) + ". "
/home/linuxbrew/.linuxbrew/opt/pyenv/versions/pytorch/lib/python3.7/site-packages/torch/onnx/symbolic_helper.py:168: UserWarning: ONNX export failed on upsample_bilinear2d because align_corners == True not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
 in 
     30 input_names += [k for k in state_dict.keys()]
     31 output_names = [ "output1" ]
---> 32 torch.onnx.export(net, x, "pspnet50.onnx")
     33 
     34 

/home/linuxbrew/.linuxbrew/opt/pyenv/versions/pytorch/lib/python3.7/site-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    141                         operator_export_type, opset_version, _retain_param_name,
    142                         do_constant_folding, example_outputs,
--> 143                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    144 
    145 

/home/linuxbrew/.linuxbrew/opt/pyenv/versions/pytorch/lib/python3.7/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
     64             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
     65             example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
     67 
     68 

/home/linuxbrew/.linuxbrew/opt/pyenv/versions/pytorch/lib/python3.7/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
    392             proto, export_map = graph._export_onnx(
    393                 params_dict, opset_version, dynamic_axes, defer_weight_export,
--> 394                 operator_export_type, strip_doc_string, val_keep_init_as_ip)
    395         else:
    396             proto, export_map = graph._export_onnx(

RuntimeError: ONNX export failed: Couldn't export operator aten::upsample_bilinear2d

@tkclimb commented Dec 13, 2019

I solved the above problem by just passing the opset version: torch.onnx.export(..., opset_version=11, ...).

@absorbguo commented May 9, 2020

same error occurred with torch==1.2.0

@AminAnsarian commented Aug 20, 2020

I solved above problem by just passing opset version torch.onnx.export(..., opset_version=11, ...) .

Worked! Thanks!

@GraceKafuu commented Aug 26, 2020

I solved above problem by just passing opset version torch.onnx.export(..., opset_version=11, ...) .

Updated torch to 1.6.0 and it worked, thank you!

@dragen1860 commented Sep 17, 2020

The most embarrassing part is:

  1. When align_corners=False is set, you can export to ONNX and convert to CoreML normally, but you cannot get consistent results, since CoreML uses align_corners=True internally.
  2. When align_corners=True and opset_version=11 are set, you can usually struggle through the export to ONNX, but you cannot convert to CoreML.
