ONNX export of MaxUnpool2d is not supported #25088

Closed
habib-19 opened this issue Aug 23, 2019 · 34 comments
Labels
module: onnx Related to torch.onnx onnx-triaged triaged by ONNX team triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@habib-19

habib-19 commented Aug 23, 2019

🐛 Bug

Unable to convert ENet to ONNX because max_unpool2d has no export support:

ONNX export failed on ATen operator max_unpool2d because torch.onnx.symbolic_opset9.max_unpool2d does not exist

On the ONNX issue tracker I was told to report this to PyTorch.

To Reproduce

Steps to reproduce the behavior:

1. Download the ENet model.
2. Use ONNX export to convert the PyTorch model to ONNX.

Expected behavior

Environment

PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX TITAN X
GPU 1: GeForce GTX TITAN X
GPU 2: GeForce GTX TITAN X
GPU 3: GeForce GTX TITAN X

Nvidia driver version: 410.48
cuDNN version: /usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn.so.7.2.1

Versions of relevant libraries:
[pip3] numpy==1.16.4
[pip3] numpy-indexed==0.3.5
[pip3] torch==1.1.0
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.3.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl-service 2.0.2 py36h7b6447c_0
[conda] mkl_fft 1.0.14 py36ha843d7b_0
[conda] mkl_random 1.0.2 py36hd81dba3_0
[conda] pytorch 1.2.0 py3.6_cuda10.0.130_cudnn7.6.2_0 pytorch
[conda] torchvision 0.4.0 py36_cu100 pytorch

Additional context

I am using a conda environment to run the code.

cc @BowenBao @neginraoof

@pytorchbot pytorchbot added the module: onnx Related to torch.onnx label Aug 23, 2019
@izdeby izdeby added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 23, 2019
@Nerdyvedi

I am facing the same error. It looks like ONNX does not support max_unpool2d.

@habib-19
Author


The ONNX people say it is a PyTorch bug.

@belgraviton

It was mentioned here and here that the 'MaxUnpool' ONNX operator was added in opset 9 of ONNX.

Unfortunately, I couldn't find any 'max_unpool2d' support in the torch ONNX opsets (9, 10, 11) (link).

Is there any guidance on adding unpool support to torch.onnx?


@RasinGue

RasinGue commented Oct 21, 2019

Hi, is there any progress here?
It seems PyTorch does not support max_unpool2d yet. I can't find any ONNX-related max_unpool2d in their source code.


@fumihwh
Contributor

fumihwh commented Apr 15, 2020

Currently, I think supporting max_unpool2d is not easy...
The kernel shape is a problem: aten::max_unpool2d does not need it (it is only used to compute the output shape), but ONNX requires it.

Moreover, MaxUnpool in ONNX is ambiguous. I will create an issue.
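To illustrate the kernel-shape point: as far as I can tell, `F.max_unpool2d` uses `kernel_size`, `stride` and `padding` only to work out the output size, and the underlying ATen op (reachable as `torch.ops.aten.max_unpool2d`, assuming a reasonably recent PyTorch) takes just the input, the indices and the spatial output size. A quick check:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4, 8, 6)
pooled, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

# The functional wrapper needs kernel_size/stride only to infer the output size (8 x 6 here)
via_functional = F.max_unpool2d(pooled, idx, kernel_size=2, stride=2)

# The ATen op receives the spatial output size directly; there is no kernel argument,
# so an exporter has no direct source for ONNX's required kernel_shape attribute
via_aten = torch.ops.aten.max_unpool2d(pooled, idx, [8, 6])

print(via_functional.shape)  # torch.Size([1, 4, 8, 6])
```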

@xingyanan

Hi, Hi, Hi, is there any progress here?

@YiruS

YiruS commented Aug 20, 2020

Any update on this? I'm having the same issue

@jce2090

jce2090 commented Sep 24, 2020

same problem

@PatrickNa

@habib-19 same problem with the exact same use case. Were you able to transform ENet to ONNX? I am aware that 'max_unpool2d' is still not supported, but eventually you have replaced it?

@habib-19
Author

habib-19 commented Oct 5, 2020


Hi PatrickNa, I was not able to use the conversion tool (for PyTorch to TensorFlow).
I rewrote ENet in TensorFlow after learning it.

@PatrickNa

@habib-19 would you mind sharing your implementation? I would be interested in testing it with Tensorflow, too.

@habib-19
Author

habib-19 commented Oct 6, 2020

@PatrickNa I did this last year in my previous company. Sorry, I don't have access to my implementation.

@PatrickNa

That's alright. Was ENet suitable for your application back then, in accuracy and performance?

@habib-19
Author

habib-19 commented Oct 6, 2020

@PatrickNa Yes, for our application it was good; it was a slightly modified version of ENet. Both the PyTorch and TensorFlow implementations were acceptable.

@tlk2abhishek

Exporting the operator max_unpool2d to ONNX opset version 11 is not supported.

Please let me know if this issue has been fixed.
Is there any workaround available for this issue?

@cgnarendiran

Hi, has anyone found a solution yet?

@paulgavrikov

Still an issue with torch-1.8.0.dev20210125+cpu and opset=12. @izdeby is anyone on the PyTorch team looking into supporting this any time soon?

@XiaoLaoDi

Is there any progress on this problem ?

@markson14

markson14 commented Aug 16, 2021

Check out this link. It provides a MaxUnpool2d module that converts to ONNX successfully.
https://medium.com/axinc/pytorch%E3%81%AEnn-maxunpool2d%E3%82%92onnx%E3%81%AB%E3%82%A8%E3%82%AF%E3%82%B9%E3%83%9D%E3%83%BC%E3%83%88%E3%81%99%E3%82%8B-f48deb65580a

@garymm garymm changed the title KeyError: 'max_unpool2d' pytorch to onnx ONNX export of MaxUnpool2d is not supported Oct 6, 2021
@garymm
Collaborator

garymm commented Oct 6, 2021

This is being tracked internally at Microsoft by https://msdata.visualstudio.com/Vienna/_workitems/edit/1444696

@garymm
Collaborator

garymm commented Apr 7, 2022

Copy pasting work-around from @markson14's link. I have not verified this is correct:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.modules.pooling import _MaxUnpoolNd
from torch.nn.modules.utils import _pair

class MaxUnpool2dop(Function):
    """We warp the `torch.nn.functional.max_unpool2d`
    with an extra `symbolic` method, which is needed while exporting to ONNX.
    Users should not call this function directly.
    """

    @staticmethod
    def forward(ctx, input, indices, kernel_size, stride, padding,
                output_size):
        """Forward function of MaxUnpool2dop.
        Args:
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            kernel_size (Tuple): Size of the max pooling window.
            stride (Tuple): Stride of the max pooling window.
            padding (Tuple): Padding that was added to the input.
            output_size (List or Tuple): The shape of output tensor.
        Returns:
            Tensor: Output tensor.
        """
        return F.max_unpool2d(input, indices, kernel_size, stride, padding,
                              output_size)

    @staticmethod
    def symbolic(g, input, indices, kernel_size, stride, padding, output_size):
        # get shape
        input_shape = g.op('Shape', input)
        const_0 = g.op('Constant', value_t=torch.tensor(0))
        const_1 = g.op('Constant', value_t=torch.tensor(1))
        batch_size = g.op('Gather', input_shape, const_0, axis_i=0)
        channel = g.op('Gather', input_shape, const_1, axis_i=0)

        # height = (height - 1) * stride + kernel_size
        height = g.op(
            'Gather',
            input_shape,
            g.op('Constant', value_t=torch.tensor(2)),
            axis_i=0)
        height = g.op('Sub', height, const_1)
        # kernel_size/stride are (H, W) pairs: index 0 is height, index 1 is width
        height = g.op('Mul', height,
                      g.op('Constant', value_t=torch.tensor(stride[0])))
        height = g.op('Add', height,
                      g.op('Constant', value_t=torch.tensor(kernel_size[0])))

        # width = (width - 1) * stride + kernel_size
        width = g.op(
            'Gather',
            input_shape,
            g.op('Constant', value_t=torch.tensor(3)),
            axis_i=0)
        width = g.op('Sub', width, const_1)
        width = g.op('Mul', width,
                     g.op('Constant', value_t=torch.tensor(stride[1])))
        width = g.op('Add', width,
                     g.op('Constant', value_t=torch.tensor(kernel_size[1])))

        # step of channel
        channel_step = g.op('Mul', height, width)
        # step of batch
        batch_step = g.op('Mul', channel_step, channel)

        # channel offset
        range_channel = g.op('Range', const_0, channel, const_1)
        range_channel = g.op(
            'Reshape', range_channel,
            g.op('Constant', value_t=torch.tensor([1, -1, 1, 1])))
        range_channel = g.op('Mul', range_channel, channel_step)
        range_channel = g.op('Cast', range_channel, to_i=7)  # 7 is int64

        # batch offset
        range_batch = g.op('Range', const_0, batch_size, const_1)
        range_batch = g.op(
            'Reshape', range_batch,
            g.op('Constant', value_t=torch.tensor([-1, 1, 1, 1])))
        range_batch = g.op('Mul', range_batch, batch_step)
        range_batch = g.op('Cast', range_batch, to_i=7)  # 7 is int64

        # update indices
        indices = g.op('Add', indices, range_channel)
        indices = g.op('Add', indices, range_batch)

        return g.op(
            'MaxUnpool',
            input,
            indices,
            kernel_shape_i=kernel_size,
            strides_i=stride)


class MaxUnpool2d(_MaxUnpoolNd):
    """This module is modified from Pytorch `MaxUnpool2d` module.
    Args:
      kernel_size (int or tuple): Size of the max pooling window.
      stride (int or tuple): Stride of the max pooling window.
          Default: None (It is set to `kernel_size` by default).
      padding (int or tuple): Padding that is added to the input.
          Default: 0.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super(MaxUnpool2d, self).__init__()
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride or kernel_size)
        self.padding = _pair(padding)

    def forward(self, input, indices, output_size=None):
        """Forward function of MaxUnpool2d.
        Args:
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            output_size (List or Tuple): The shape of output tensor.
                Default: None.
        Returns:
            Tensor: Output tensor.
        """
        return MaxUnpool2dop.apply(input, indices, self.kernel_size,
                                   self.stride, self.padding, output_size)
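Why the symbolic adds `range_channel` and `range_batch` to the indices, as I read it: PyTorch's `max_pool2d` returns indices flattened within each (batch, channel) plane, whereas ONNX `MaxUnpool` (like the indices output of ONNX `MaxPool`) expects indices flattened over the whole N x C x H x W tensor. The eager-mode sketch below mirrors that offset arithmetic and checks it against `F.max_unpool2d`. The helper names are hypothetical, and it assumes non-overlapping pooling (kernel 2, stride 2) so each index is written once.

```python
import torch
import torch.nn.functional as F

def to_global_indices(indices, out_h, out_w):
    # hypothetical helper: per-plane indices -> indices over the whole N*C*H*W tensor
    n, c = indices.shape[:2]
    plane = out_h * out_w                                          # "channel_step" in the symbolic
    channel_offset = torch.arange(c).view(1, -1, 1, 1) * plane
    batch_offset = torch.arange(n).view(-1, 1, 1, 1) * plane * c   # "batch_step" in the symbolic
    return indices + channel_offset + batch_offset

def unpool_with_global_indices(x, indices, out_h, out_w):
    # hypothetical helper: scatter into one flat buffer, the way ONNX MaxUnpool indexes it
    n, c = x.shape[:2]
    flat = torch.zeros(n * c * out_h * out_w, dtype=x.dtype)
    flat[to_global_indices(indices, out_h, out_w).flatten()] = x.flatten()
    return flat.view(n, c, out_h, out_w)

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 6)
pooled, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
reference = F.max_unpool2d(pooled, idx, kernel_size=2, stride=2)
ours = unpool_with_global_indices(pooled, idx, 8, 6)
```

Note that `plane` has to be the size of the unpooled output. If the unpooled size differs from the default (size - 1) * stride + kernel_size, the offsets must be computed from the real output size, otherwise values land in the wrong channel.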

@titaiwangms
Collaborator

Hi @habib-19 ,

We've gone ahead and closed this issue because it has a workaround. This workaround can be found here.
If you still believe this issue is relevant, please feel free to reopen the issue and we will triage it as necessary. Please specify in a comment any updated information you may have so that we can address it effectively. We encourage you to try the latest pytorch-preview (nightly) version to see if it has resolved the issue.

Thanks,
ONNX Converter team

@titaiwangms titaiwangms closed this as not planned Won't fix, can't repro, duplicate, stale Oct 24, 2022
@carpemonf

carpemonf commented Jan 18, 2023

The workaround shared by @markson14 and @garymm didn't work for me when passing the optional output_size, as in:

self.unpool = MaxUnpool2d(2, 2, output_size)

The above gives the following error:

  File "/home/user/test/models/MaxUnpool.py", line 178, in forward
    return MaxUnpool2dop.apply(input, indices, self.kernel_size,
RuntimeError: _Map_base::at

I realized that the error came from output_size. While testing, I redefined the variable with plain integers inside the forward call. I'm not sure why yet, but this avoided the error and the model export finished. The values and types of output_size are the same in both cases, so this looks odd to me.

    def forward(self, input, indices, output_size=None):
        """Forward function of MaxUnpool2d.
        Args:
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            output_size (List or Tuple): The shape of output tensor.
                Default: None.
        Returns:
            Tensor: Output tensor.
        """
        print(output_size)  # torch.Size([1, 32, 21, 9])
        print(type(output_size))  # <class 'torch.Size'>

        # No error with this output_size
        output_size = torch.Size([int(output_size[0]),
                                  int(output_size[1]),
                                  int(output_size[2]),
                                  int(output_size[3])])

        print(output_size)  # torch.Size([1, 32, 21, 9])
        print(type(output_size))  # <class 'torch.Size'>

        return MaxUnpool2dop.apply(input, indices, self.kernel_size,
                                   self.stride, self.padding, output_size)

I then loaded the exported model with ONNX, but it failed when checking inferred shapes, specifically because of the /MaxUnpool_output_* dimensions. The last two dimensions get wrong values, e.g. 20 instead of 21 and 8 instead of 9. Not sure if this is related or not.
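Those exact numbers are what you would get if ONNX fell back to its default output shape. My understanding of the MaxUnpool spec is that, without the optional output-shape input, each spatial dim becomes (in - 1) * stride - pad_begin - pad_end + kernel. Assuming kernel 2, stride 2, no padding and floor-mode pooling of a 21 x 9 map, a quick check:

```python
def pooled_size(size, kernel=2, stride=2):
    # floor-mode max pooling without padding
    return (size - kernel) // stride + 1

def default_unpool_size(size, kernel=2, stride=2, pad_begin=0, pad_end=0):
    # default unpooled size when no explicit output shape is given
    return (size - 1) * stride - pad_begin - pad_end + kernel

original = (21, 9)
pooled = tuple(pooled_size(s) for s in original)          # (10, 4)
restored = tuple(default_unpool_size(s) for s in pooled)  # (20, 8)
print(pooled, restored)  # (10, 4) (20, 8)
```

Odd sizes lose a row or column in pooling, so only an explicit output size can restore 21 x 9.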

Any ideas? Did anyone exported to ONNX using output_size?

Versions: opset 17, torch 1.13.1+cu117 (also tried 1.13.0a0+git49444c3)

egstatsml pushed a commit to egstatsml/BiSeNet that referenced this issue Feb 22, 2023
@ukoehler

Hmm,

I am not sure I understand the problem completely. I am trying to export a pre-trained model to ONNX without being able to modify or re-train it. I have updated to torch 1.13.1 and set opset_version=17 when exporting; trying 18 gave me an error. I still get the following error:

  File "D:\Local\devel\Python\PyTorch\.venv\lib\site-packages\torch\onnx\utils.py", line 1909, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::max_unpool2d' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues

Should this work without the workaround now? How can I apply the work-around in my situation?

@MaryamGabitova

Hello!
Can you tell me what I should change in the code above if I need to convert MaxUnpool1d or MaxUnpool3d?
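Not an authoritative answer, but the core of the 2-D symbolic that would need to change is the offset arithmetic: the per-channel step is the product of the unpooled spatial sizes (H*W in 2-D), so presumably L in 1-D and D*H*W in 3-D, and `kernel_shape`/`strides` would need one entry per spatial dim (e.g. `_single`/`_triple` instead of `_pair`). A pure-Python sketch of that arithmetic for any number of spatial dims (function names are hypothetical):

```python
from math import prod

def global_index(batch, channel, local_index, num_channels, spatial_out):
    # local_index is PyTorch's per-(batch, channel) flattened index;
    # ONNX MaxUnpool expects an index over the whole N x C x spatial tensor
    plane = prod(spatial_out)  # L (1-D), H*W (2-D), D*H*W (3-D)
    return (batch * num_channels + channel) * plane + local_index

def row_major_index(coords, shape):
    # flat index of coords in a C-contiguous tensor of the given shape
    index = 0
    for c, s in zip(coords, shape):
        index = index * s + c
    return index

# 3-D example: N=2, C=3, D=4, H=5, W=6; element at (1, 2, 3, 4, 5)
spatial = (4, 5, 6)
local = row_major_index((3, 4, 5), spatial)
print(global_index(1, 2, local, 3, spatial))  # 719
```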

@besserai

besserai commented Oct 9, 2023

Hey all!

The workaround is not working for me. Trying to use it leads to RuntimeError: _Map_base::at like in carpemonf's post, but my code is not even using the output_size parameter (self.unpool = MaxUnpool2d(kernel_size=2, stride=2) )

I tried the conversion on PyTorch versions 2.0.1 and 2.1

(Error without the workaround: "torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::max_unpool2d' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.")

@carpemonf

carpemonf commented Oct 9, 2023

Hi @besserai!

I solved the issues with dimensions in my post. I'll check and share it for others.

For your RuntimeError: _Map_base::at error, do you have a reproducible example, or more info about input sizes?

@besserai

besserai commented Oct 10, 2023

Hey @carpemonf,
That would be so nice! Thanks!

This is my conversion code:

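from typing import Tuple

import torch

# ENet and number_of_classes come from the iArunava ENet implementation mentioned below


def convert_to_onnx(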
    pth_file_path,
    onnx_file_path,
    batch_size=1,
    input_shape: Tuple[int, int, int] = (3, 512, 512),
):
    # Load the PyTorch model
    checkpoint = torch.load(pth_file_path)
    model = ENet(number_of_classes)
    model.load_state_dict(checkpoint["state_dict"])

    # Set the model to evaluation mode
    model.eval()

    # Create a dummy input tensor for the model with variable batch size
    print(input_shape)
    print((batch_size,))
    batched_input_shape = (batch_size,) + input_shape
    print(batched_input_shape)
    dummy_input = torch.randn(batched_input_shape)

    # Export the model to ONNX format
    torch.onnx.export(model, dummy_input, onnx_file_path, verbose=True)

I am using this ENet implementation by iArunava; MaxUnpool2d is called in the UBNeck block.

I also attached the model I am trying to convert here.

@carpemonf

carpemonf commented Oct 10, 2023

@besserai I think you are seeing RuntimeError: _Map_base::at because of x_copy = self.unpool(x_copy, indices, output_size=x.size()) in UBNeck. As you also pass output_size you are likely having the same error I had.

Below is the code I ended up with. It passes the output size to the MaxUnpool g.op so the export produces the targeted output size, and it includes the hack to avoid RuntimeError: _Map_base::at. The latter shouldn't be needed, but it fixes the issue. If someone knows why, please share.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.modules.pooling import _MaxUnpoolNd
from torch.nn.modules.utils import _pair

class MaxUnpool2dop(Function):
    """We warp the `torch.nn.functional.max_unpool2d`
    with an extra `symbolic` method, which is needed while exporting to ONNX.
    Users should not call this function directly.
    """

    @staticmethod
    def forward(ctx, input, indices, kernel_size, stride, padding,
                output_size):
        """Forward function of MaxUnpool2dop.
        Args:
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            kernel_size (Tuple): Size of the max pooling window.
            stride (Tuple): Stride of the max pooling window.
            padding (Tuple): Padding that was added to the input.
            output_size (List or Tuple): The shape of output tensor.
        Returns:
            Tensor: Output tensor.
        """
        return F.max_unpool2d(input, indices, kernel_size, stride, padding,
                              output_size)


    @staticmethod
    def symbolic(g, input, indices, kernel_size, stride, padding, output_size):
        # get shape
        input_shape = g.op('Shape', input)
        const_0 = g.op('Constant', value_t=torch.tensor(0))
        const_1 = g.op('Constant', value_t=torch.tensor(1))
        output_size_list = list(output_size)
        const_size = g.op('Constant', value_t=torch.tensor(output_size_list))
        batch_size = g.op('Gather', input_shape, const_0, axis_i=0)
        channel = g.op('Gather', input_shape, const_1, axis_i=0)

        # height = (height - 1) * stride + kernel_size
        height = g.op(
            'Gather',
            input_shape,
            g.op('Constant', value_t=torch.tensor(2)),
            axis_i=0)
        height = g.op('Sub', height, const_1)
        height = g.op('Mul', height,
                      g.op('Constant', value_t=torch.tensor(stride[1])))
        height = g.op('Add', height,
                      g.op('Constant', value_t=torch.tensor(kernel_size[1])))

        # width = (width - 1) * stride + kernel_size
        width = g.op(
            'Gather',
            input_shape,
            g.op('Constant', value_t=torch.tensor(3)),
            axis_i=0)
        width = g.op('Sub', width, const_1)
        width = g.op('Mul', width,
                     g.op('Constant', value_t=torch.tensor(stride[0])))
        width = g.op('Add', width,
                     g.op('Constant', value_t=torch.tensor(kernel_size[0])))

        # step of channel
        channel_step = g.op('Mul', height, width)
        # step of batch
        batch_step = g.op('Mul', channel_step, channel)

        # channel offset
        range_channel = g.op('Range', const_0, channel, const_1)
        range_channel = g.op(
            'Reshape', range_channel,
            g.op('Constant', value_t=torch.tensor([1, -1, 1, 1])))
        range_channel = g.op('Mul', range_channel, channel_step)
        range_channel = g.op('Cast', range_channel, to_i=7)  # 7 is int64

        # batch offset
        range_batch = g.op('Range', const_0, batch_size, const_1)
        range_batch = g.op(
            'Reshape', range_batch,
            g.op('Constant', value_t=torch.tensor([-1, 1, 1, 1])))
        range_batch = g.op('Mul', range_batch, batch_step)
        range_batch = g.op('Cast', range_batch, to_i=7)  # 7 is int64

        # update indices
        indices = g.op('Add', indices, range_channel)
        indices = g.op('Add', indices, range_batch)

        return g.op(
            'MaxUnpool',
            input,
            indices,
            const_size,
            kernel_shape_i=kernel_size,
            strides_i=stride)

class MaxUnpool2d(_MaxUnpoolNd):
    """This module is modified from Pytorch `MaxUnpool2d` module.
    Args:
      kernel_size (int or tuple): Size of the max pooling window.
      stride (int or tuple): Stride of the max pooling window.
          Default: None (It is set to `kernel_size` by default).
      padding (int or tuple): Padding that is added to the input.
          Default: 0.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super(MaxUnpool2d, self).__init__()
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride or kernel_size)
        self.padding = _pair(padding)

    def forward(self, input, indices, output_size=None):
        """Forward function of MaxUnpool2d.
        Args:
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            output_size (List or Tuple): The shape of output tensor.
                Default: None.
        Returns:
            Tensor: Output tensor.
        """
        if output_size is not None:
            # int() accepts both plain ints (eager mode) and 0-dim tensors (tracing)
            output_size = tuple(int(s) for s in output_size)

        return MaxUnpool2dop.apply(input, indices, self.kernel_size,
                                   self.stride, self.padding, output_size)

@carpemonf

Upon further investigation, I realized that output_size is expected to be a List or Tuple. However, I was passing in a torch.Size object containing torch.Tensor elements. This mismatch likely caused the translation to the expected C++ type to fail, leading to the RuntimeError: _Map_base::at error.

Check if that's the case for you:

print(type(output_size)) # <class 'torch.Size'>
print(type(output_size[0])) # <class 'torch.Tensor'>

I've updated the code I shared above. To address the situation I convert output_size into a tuple inside the custom MaxUnpool2d.
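The conversion can be written so that it works both in eager mode, where `x.size()` holds plain Python ints, and during export, where 0-dim tensors showed up: `int(s)` handles both, while `.item()` only exists on tensors. A minimal sketch with a hypothetical helper name; as far as I know, during tracing this still bakes the sizes into the graph as constants, which is where the warning comes from.

```python
import torch

def normalize_output_size(output_size):
    # hypothetical helper: torch.Size / list / tuple of ints or 0-dim tensors -> tuple of ints
    if output_size is None:
        return None
    return tuple(int(s) for s in output_size)

print(normalize_output_size(torch.Size([1, 32, 21, 9])))           # (1, 32, 21, 9)
print(normalize_output_size([torch.tensor(1), torch.tensor(32)]))  # (1, 32)
```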

@besserai

Hey @carpemonf,

The conversion did the trick! 🎉 Now the export works, thank you so much! 🙏
I thought I had to go ahead and redo all my training with MobileNet again...

Indeed I got the same types for the output before (<class 'torch.Size'> and <class 'torch.Tensor'>)

What would be the best place to fix the bug for the future? I could copy the workaround to the ENet repo I used, as many people seem to end up here using that. But better would be to fix it in the Pytorch/ONNX repos, no?

@carpemonf

Glad it worked!

I'm not sure about the best place. Ideally this would be a supported operator in Torch instead of the workaround...

How do you guys use the workaround in ENet? I'm doing the following for my model:

  • Import the workaround
from .MaxUnpool import *
  • Define both the custom MaxUnpool2d for exporting to ONNX and the default one:
max_unpool2d_onnx = MaxUnpool2d(kernel_size=2, stride=2)
max_unpool2d = nn.MaxUnpool2d(kernel_size=2, stride=2)
  • Use one or the other:
if torch.onnx.is_in_onnx_export():
    size = tuple(s.item() for s in size)
    x = max_unpool2d_onnx(x, ind, output_size=size)
else:
    x = max_unpool2d(x, ind, output_size=size)

The other option is to put this inside the custom "MaxUnpool" to avoid passing the torch.Size:

    def forward(self, input, indices, output_size=None):
        """Forward function of MaxUnpool2d.
        Args:
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            output_size (List or Tuple): The shape of output tensor.
                Default: None.
        Returns:
            Tensor: Output tensor.
        """
        if output_size is not None:
            # int() accepts both plain ints (eager mode) and 0-dim tensors (tracing)
            output_size = tuple(int(s) for s in output_size)

        return MaxUnpool2dop.apply(input, indices, self.kernel_size,
                                   self.stride, self.padding, output_size)

With either approach you will get a warning because of the size type conversion.

@besserai

I use the workaround like you, just switching to the other import when I want to export the model.

So far I have not been able to use the converted model for inference, but that might be an error in my opencv.js code, which is hard to debug.
If that persists, I'll come back and report on it.
