
[feature request] Support output size parameter in Upsample ONNX operator #7732

Closed
aachigorin opened this issue May 21, 2018 · 7 comments
Labels
feature A request for a proper, new feature. module: onnx Related to torch.onnx onnx-needs-info needs information from the author / reporter before ONNX team can take action triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@aachigorin

aachigorin commented May 21, 2018

Right now the ONNX Upsample operator only supports fixed scale factors, which are computed during graph tracing from the particular input and output sizes seen at trace time.

At the same time, F.upsample allows specifying an exact output size, which is very convenient when evaluating models with top-down connections (like a Feature Pyramid Network) on inputs of varying sizes.

Could the ONNX Upsample operator be extended to accept the desired output size (exactly as in F.upsample) as an additional input?
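To make the problem concrete, here is a small pure-Python sketch (illustrative numbers only; the `conv_out` helper is made up for this example, not part of PyTorch or ONNX) showing why a scale factor baked in at trace time yields the wrong output size once the input size changes:

```python
import math

def conv_out(h, kernel=3, stride=2, padding=1):
    # Output height of a strided conv, like a downsampling branch in an FPN.
    return (h + 2 * padding - kernel) // stride + 1

# At trace time the input is 10 pixels tall; the exporter bakes in
# scale = target_size / input_size as a constant.
trace_h = 10
traced_scale = conv_out(trace_h) / trace_h    # 5 / 10 = 0.5

# At inference time the input is 11 pixels tall instead.
run_h = 11
wanted = conv_out(run_h)                      # size we actually need: 6
got = math.floor(run_h * traced_scale)        # what the fixed scale gives: 5

print(wanted, got)  # 6 5 -> the baked-in scale no longer matches
```

Passing the target size itself as a dynamic input, instead of a constant scale, avoids this mismatch entirely.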

Simple example to check (this should export a graph in which the upsample takes its output size dynamically as h1.size()[-2:]):

import torch
import torch.nn.functional as F
import onnx
from torch import nn
from torch.autograd import Variable

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        h1 = self.conv1(x)
        h2 = F.upsample(x, size=h1.size()[-2:], mode='bilinear', align_corners=False)
        h = torch.cat([h1, h2], dim=1)
        return h


model = SimpleNet()
save_path = '/Users/aachigorin/temp/test_bilinear_onnx_export.txt'
x = Variable(torch.randn(1, 3, 10, 10)).cpu()
torch_out = torch.onnx._export(model, x, save_path, export_params=True, verbose=True)
print('finished onnx export')

model = onnx.load(save_path)
print('check = {}'.format(onnx.checker.check_model(model)))
print(onnx.helper.printable_graph(model.graph))

cc @BowenBao @neginraoof

@zou3519 zou3519 added the module: onnx Related to torch.onnx label May 21, 2018
@ahirner
Contributor

ahirner commented Oct 15, 2018

Dynamic upsampling is a pending PR on the ONNX side: onnx/onnx#1467

@wadimkehl

wadimkehl commented Dec 4, 2018

Any update on this, now that the dependent PR has been merged? The code above still produces an error with the latest nightly conda package (1.0.0.dev20181204-py3.6_cuda9.0.176_cudnn7.4.1_0):

raise RuntimeError("ONNX symbolic expected a constant value in the trace")

Running the code with 0.4.1 gives me

ONNX export failed on upsample_bilinear2d because align_corners == True not supported

even though align_corners=False is set. This is probably a separate bug.

Lastly, if I run the code with 0.4.1 and mode='nearest', I get

TypeError: 'torch._C.Value' object does not support indexing (occurred when translating upsample_nearest2d)

So this still seems to be a persistent problem.

@glenn-jocher

I'm experiencing the same problem when exporting from PyTorch 1.0.0 to ONNX to CoreML. Any known solution?

@Kay-Tian

I'm experiencing the same issue, and by the way, the nightly version '1.2.0.dev20190725+cpu' raises another fault:

/anaconda3/envs/torch-nightly/lib/python3.6/site-packages/torch/nn/modules/conv.py(342)conv2d_forward()
    341         return F.conv2d(input, weight, self.bias, self.stride,
--> 342                         self.padding, self.dilation, self.groups)
    343
Illegal instruction (core dumped)

@soumith
Member

soumith commented Jul 25, 2019

@Kay-Tian this will be fixed in tomorrow's nightlies

@izdeby izdeby added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module feature A request for a proper, new feature. labels Jul 26, 2019
@garymm
Collaborator

garymm commented Aug 13, 2021

F.upsample has been deprecated in favor of torch.nn.functional.interpolate, and the Upsample ONNX op has been deprecated since ONNX op set 10. The exporter now emits Resize instead.
Does this code produce the output you are asking for with the latest nightly PyTorch? If not, could you please explain what you expect to see vs. what you actually see?

import torch
from torch import nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h1 = self.conv1(x)
        h2 = nn.functional.interpolate(x, size=h1.size()[-2:], mode='bilinear', align_corners=False)
        h = torch.cat([h1, h2], dim=1)
        return h


model = SimpleNet()
x = torch.randn(1, 3, 10, 10).cpu()
torch.onnx.export(model, x, "/dev/null", verbose=True, opset_version=13)

If it's not correct, what happens if you replace the last line with the following?

torch.onnx.export(torch.jit.script(model), x, "/dev/null", example_outputs=model(x), verbose=True, opset_version=13)

@garymm garymm added the onnx-needs-info needs information from the author / reporter before ONNX team can take action label Aug 13, 2021
@garymm
Collaborator

garymm commented Oct 16, 2021

Closing due to lack of replies. If this is still an issue please open a new issue.


9 participants