"Unsupported: ONNX export of transpose for tensor of unknown rank." with dynamic axes #50686

Closed
ghost opened this issue Jan 18, 2021 · 8 comments
Labels
module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ghost

ghost commented Jan 18, 2021

🐛 Bug

When exporting a model with dynamic_axes, I get the error "Unsupported: ONNX export of transpose for tensor of unknown rank." even though the tensor itself looks correct. The error comes from the transpose(1,2) call in the model. Please see the sample below.

To Reproduce

Steps to reproduce the behavior:

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        # Swap dims 1 and 2; this transpose is what triggers the export error.
        xt = x.transpose(1, 2)
        return xt

m = Model().cuda()
i = (torch.randn(1, 2, 3).cuda(),)

# Mark dim 0 of the input and the output as a dynamic batch dimension.
torch.onnx.export(m, i, "model.onnx",
                  input_names=["INPUT_0"],
                  output_names=["OUTPUT_0"],
                  dynamic_axes={"INPUT_0": {0: "batch_size"},
                                "OUTPUT_0": {0: "batch_size"}})

Expected behavior

Successful model export to ONNX.

Environment

  • PyTorch Version: 1.8.0a0+1606899
  • OS: Ubuntu 20.04.1 LTS (x86_64)
  • How you installed PyTorch: from docker nvcr.io/nvidia/pytorch:20.12-py3
  • Python version: 3.8 (64-bit runtime)
  • CUDA/cuDNN version: 11.1/8.0.5
  • GPU models and configuration: GPU 0: TITAN V

cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof

@ngimel ngimel added module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jan 18, 2021
@neginraoof
Contributor

Thanks for reporting this issue. The transpose shape inference logic was recently updated in #50163, and it looks like this issue is resolved in the nightly build.
Can you confirm which build you are using? Would it be possible to test the latest nightly build?
Thanks.
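
For anyone answering the build question above, here is a minimal sketch that prints the version details usually asked for. It assumes only that torch may be importable; describe_build is a hypothetical helper name, not part of PyTorch.

```python
def describe_build():
    """Collect version details useful when reporting an export issue."""
    try:
        import torch
    except ImportError:
        # torch is not installed in this environment.
        return {"torch": None}
    return {
        "torch": torch.__version__,        # release or nightly version string
        "git": torch.version.git_version,  # commit the build was made from
        "cuda": torch.version.cuda,        # None for CPU-only builds
    }


for key, value in describe_build().items():
    print(f"{key}: {value}")
```

Pasting this output into a comment makes it easier to tell whether a build already includes a given fix.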

@spandantiwari

Thanks @neginraoof .

@grzegorzkarchnv - please reopen this issue if you still see this problem.

@ghost
Author

ghost commented Jan 28, 2021

I confirm the issue is resolved and I can export the model successfully. Thanks for your support!

@spandantiwari

Thanks for confirming @grzegorzkarchnv .

@Aloento

Aloento commented Aug 17, 2022

Same issue for the ONNX 16
RuntimeError: Unsupported: ONNX export of transpose for tensor of unknown rank.
1.12.1+cu116

@N1njaG

N1njaG commented Sep 26, 2022

Same issue for the ONNX 16 RuntimeError: Unsupported: ONNX export of transpose for tensor of unknown rank. 1.12.1+cu116

Have you solved the issue? If so, how did you do it?

@fernandorovai

I'm facing the same problem when transposing a tensor inside a for loop. Any idea on how to fix it?
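
One thing that may be worth trying, when the tensor's rank is known to you in advance, is replacing transpose(dim0, dim1) with an equivalent permute(...) that spells out every dimension, so the full permutation is given explicitly instead of being derived from the tensor's rank at export time. This is a sketch under that assumption, not a confirmed fix for the loop case; transpose_as_permutation is a hypothetical helper, not a PyTorch function.

```python
def transpose_as_permutation(rank, dim0, dim1):
    """Return the dims list for permute() that matches transpose(dim0, dim1)."""
    for dim in (dim0, dim1):
        if not -rank <= dim < rank:
            raise IndexError(f"dim {dim} is out of range for rank {rank}")
    # Normalize negative dims (e.g. -1 becomes rank - 1).
    d0, d1 = dim0 % rank, dim1 % rank
    perm = list(range(rank))
    perm[d0], perm[d1] = perm[d1], perm[d0]
    return perm


perm = transpose_as_permutation(3, 1, 2)
print(perm)  # [0, 2, 1]

try:
    import torch
except ImportError:
    torch = None  # the eager-mode check below is skipped without torch

if torch is not None:
    x = torch.randn(1, 2, 3)
    # In eager mode both calls produce the same tensor.
    assert torch.equal(x.transpose(1, 2), x.permute(*perm))
```

For the model in the original report this would mean writing x.permute(0, 2, 1) in place of x.transpose(1, 2). Whether it avoids the error depends on the exporter version and on why the rank is unknown inside the loop.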
