
[ONNX] Scripted reshape incorrect if shape is dynamically calculated #78721

Open
vbogach opened this issue Jun 2, 2022 · 1 comment
Labels
bug module: onnx Related to torch.onnx onnx-triaged triaged by ONNX team triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@vbogach

vbogach commented Jun 2, 2022

🐛 Describe the bug

torch.onnx.export produces an incorrect export of a reshape call after scripting if the shape is calculated dynamically. It looks like one of the shape arguments is not converted to an integer and remains a float instead.
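For context, the likely culprit is the float division in the batch-size computation: under scripting, `windows.shape[0] / (H * W / window_size / window_size)` is a float expression that is only cast back with `int(...)`, so the exporter can end up feeding a float tensor into the shape `Concat`. A plain-Python illustration, using the sizes from the repro below (the all-integer variant is a hypothetical rewrite, not verified against the exporter):

```python
window_size, H, W = 7, 28, 28
num_windows_times_B = 16  # windows.shape[0] in the repro

# What the scripted code computes: float division, then int()
B_float_path = int(num_windows_times_B / (H * W / window_size / window_size))

# An all-integer alternative that never leaves int arithmetic
B_int_path = num_windows_times_B // ((H // window_size) * (W // window_size))

print(B_float_path, B_int_path)  # 1 1
```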

#sample code 

import torch
import torch.nn as nn
import onnx
import onnxruntime

TEST_WINDOW_SIZE = 7
TEST_H = 28
TEST_W = 28
TEST_B = 1
TEST_NUM_WINDOWS = 16
TEST_C = 96


class Model(nn.Module):
    def __init__(self, window_size: int, H: int, W: int):
        
        super(Model, self).__init__()
        
        self.window_size = window_size
        self.H = H
        self.W = W

    def window_reverse(self, windows, window_size: int, H: int, W: int):
        """
        Args:
            windows: (num_windows*B, window_size, window_size, C)
            window_size (int): Window size
            H (int): Height of image
            W (int): Width of image
        Returns:
            x: (B, H, W, C)
        """
        B:int = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, H, W, -1)
        return x

    def forward(self, windows):
        return self.window_reverse(windows, self.window_size, self.H, self.W)


model = Model(TEST_WINDOW_SIZE, TEST_H, TEST_W)
model.eval()
model.cpu()

windows = torch.randn(TEST_NUM_WINDOWS * TEST_B, TEST_WINDOW_SIZE, TEST_WINDOW_SIZE, TEST_C)

jit_model = torch.jit.script(model, example_inputs=[(windows,)])
jit_model.eval()
jit_model.cpu()

torch.testing.assert_allclose(model(windows), jit_model(windows))

torch.onnx.export(jit_model,
    windows,
    "bug.onnx",
    export_params=False,
    opset_version=11,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        'input' : {0 : 'batch_size'}, 
        'output' : {0 : 'batch_size'}
    }, verbose=True)

onnx_model = onnx.load("bug.onnx")
onnx.checker.check_model(onnx_model)

ort_session = onnxruntime.InferenceSession("bug.onnx", verbose=True) 

Error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from bug.onnx failed:Type Error: Type parameter (T) of Optype (Concat) bound to different types (tensor(float) and tensor(int64) in node (Concat_11).
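As a possible workaround (a sketch only, not verified against torch.onnx.export), `window_reverse` can be rewritten with integer-only arithmetic so every reshape dimension stays an int and no float should enter the `Concat`. The `window_reverse_int` name and the restructured computation of `B` are my additions; the permute/reshape sequence is unchanged from the repro:

```python
import torch

def window_reverse_int(windows, window_size: int, H: int, W: int):
    # Integer division throughout: B never passes through a float.
    num_windows = (H // window_size) * (W // window_size)
    B = windows.shape[0] // num_windows
    x = windows.reshape(B, H // window_size, W // window_size,
                        window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5)
    return x.reshape(B, H, W, -1)
```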

Attachment: bug.onnx

Versions

PyTorch version: 1.11.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.12 (main, Apr 5 2022, 06:56:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-25-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 Ti
Nvidia driver version: 495.29.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.4
[pip3] torch==1.11.0
[pip3] torchvision==0.12.0
[conda] numpy 1.22.4 pypi_0 pypi
[conda] torch 1.11.0 pypi_0 pypi
[conda] torchvision 0.12.0 pypi_0 pypi

@justinchuby justinchuby added module: onnx Related to torch.onnx onnx-triaged triaged by ONNX team labels Jun 2, 2022
@justinchuby
Collaborator

Thanks for reporting this issue! We are able to reproduce it.

@justinchuby justinchuby changed the title torch.onnx.export error in reshape function after scripting [ONNX] Scripted reshape is exported incorrectly if shape is dynamically calculated Jun 2, 2022
@justinchuby justinchuby changed the title [ONNX] Scripted reshape is exported incorrectly if shape is dynamically calculated [ONNX] Scripted reshape incorrect if shape is dynamically calculated Jun 2, 2022
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jun 3, 2022