
Failed to export an ONNX attribute 'onnx::Gather' Bug #34780

Closed
ghost opened this issue Mar 15, 2020 · 4 comments
Labels
module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments


ghost commented Mar 15, 2020

🐛 Bug

While trying to convert a ResNet18-based PyTorch model (from here) to ONNX, the following error comes up:

RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible

To Reproduce

Steps to reproduce the behavior:

  1. Download this repo.
    Also download the model weights 79999_iter.pth from the repo page and place it in res/cp.
  2. In Google Colab, run the conversion script:

from model import BiSeNet
import torch.onnx
import torch

net = BiSeNet(19)
net.cuda()
net.load_state_dict(torch.load('/content/drive/My Drive/Collab/fp/res/cp/79999_iter.pth'))
net.eval()

dummy = torch.rand(1,3,512,512).cuda()
torch.onnx.export(net, dummy, "Model.onnx", input_names=["image"], output_names=["output"])

Expected behavior

The conversion should complete without errors.

Environment

collect_env output:

Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.12.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.17.5
[pip3] torch==1.4.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.5.0
[conda] Could not collect

Additional context

I added print(v.node()) to symbolic_helper.py just before the runtime error is raised, to see which node causes the error.

This is the output: %595 : Long() = onnx::Gather[axis=0](%592, %594) # /content/drive/My Drive/Collab/fp/model.py:111:0

Line 111 of model.py is: avg = F.avg_pool2d(feat32, feat32.size()[2:])

After further research, I found this source, which states:

Both resNet50 and 32 are fine, but to resNet18, ONNX model cannot be exported.

The source suggests the following changes:

From this:

import torch.nn.functional as F
def forward(self, x):
    feat = self.base(x)
    feat = F.avg_pool2d(feat, feat.size()[2:])

To this:

class Model(nn.Module):
    def __init__(self):

        self.avg_pool2d = nn.AvgPool2d(kernel_size=k_s, ceil_mode=False)

    def forward(self, x):

        feat = self.avg_pool2d(feat, feat.size()[2:])

This change, however, yields other errors.
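The thread does not say which errors the module-based variant produced. One thing worth checking, though, is that nn.AvgPool2d fixes its kernel size at construction and its forward pass takes only the input tensor, so passing feat.size()[2:] as a second argument is not valid. The following is a minimal sketch of the module form under the assumption that the feature map size is known ahead of time; the class name PooledHead and the 16x16 feature map are hypothetical and not taken from the BiSeNet code.

```python
import torch
import torch.nn as nn


class PooledHead(nn.Module):
    """Hypothetical module: global average pooling with a fixed kernel size."""

    def __init__(self, kernel_size):
        super().__init__()
        # The kernel size is a plain Python constant chosen at construction time.
        self.avg_pool2d = nn.AvgPool2d(kernel_size=kernel_size, ceil_mode=False)

    def forward(self, feat):
        # The pooling module is called with the input tensor only.
        return self.avg_pool2d(feat)


# Assumed feature map size (16x16); adjust to the real size in your model.
head = PooledHead(kernel_size=(16, 16))
feat = torch.rand(1, 8, 16, 16)
out = head(feat)
print(tuple(out.shape))
```

Because the kernel covers the whole feature map, the result equals the per-channel mean over the spatial dimensions. The trade-off is that the module only pools globally for inputs that produce exactly that feature map size.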

cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof

@VitalyFedyunin VitalyFedyunin added the module: onnx Related to torch.onnx label Mar 19, 2020
@spandantiwari

@Malemo - I see you have another issue similar to this #34743. Is this a duplicate of that?

@yf225 yf225 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 20, 2020
@juimoisnono

I have just solved a problem similar to yours. I suggest changing "feat.size()[2:]" in the call to a constant; it may work.


ysnan commented Jul 9, 2021

To elaborate a little: change instances of

x = F.avg_pool2d(x, x.shape[2:])

to

x_shape = [int(s) for s in x.shape[2:]]
x = F.avg_pool2d(x, x_shape)

Make sure x_shape is a list of int rather than a torch.Size.
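The suggestion above can be put into a small, self-contained sketch. It is not the BiSeNet model from the issue: GlobalPool and export_to_onnx are hypothetical names, and the 32x32 input is an arbitrary choice. The export helper reuses the torch.onnx.export call from the original report and is defined but not invoked here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalPool(nn.Module):
    """Hypothetical stand-in for the pooling step at model.py line 111."""

    def forward(self, x):
        # Convert each spatial dimension to a plain Python int so that the
        # kernel size is passed as a list of int instead of a torch.Size.
        x_shape = [int(s) for s in x.shape[2:]]
        return F.avg_pool2d(x, x_shape)


def export_to_onnx(model, dummy, path):
    # Same export call as in the report; not executed in this sketch.
    torch.onnx.export(model, dummy, path,
                      input_names=["image"], output_names=["output"])


model = GlobalPool().eval()
dummy = torch.rand(1, 3, 32, 32)
out = model(dummy)
print(tuple(out.shape))
```

Note the design consequence: since the kernel size is recorded as a constant, the exported model should be expected to pool globally only for inputs with the same spatial size as the dummy input.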

Collaborator

garymm commented Mar 1, 2022

Duplicate of #34743.
