[feature request] Support output size parameter in Upsample ONNX operator #7732
Dynamic upsampling is a PR on the ONNX side: onnx/onnx#1467
Any update on this, seeing that the dependent PR has been merged? The code from above still produces an error with the latest nightly conda package.
Running the code with 0.4.1 gives me an error. Lastly, if I run the code with 0.4.1 and mode='nearest', I also get an error. So it seems that this is still a persistent problem.
I'm experiencing the same problem when exporting from PyTorch 1.0.0 to ONNX to CoreML. Any known solution?
I'm experiencing the same issue, and by the way, the nightly version '1.2.0.dev20190725+cpu' raises another fault:
@Kay-Tian this will be fixed in tomorrow's nightlies
import torch
from torch import nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h1 = self.conv1(x)
        h2 = nn.functional.interpolate(x, size=h1.size()[-2:], mode='bilinear', align_corners=False)
        h = torch.cat([h1, h2], dim=1)
        return h

model = SimpleNet()
x = torch.randn(1, 3, 10, 10).cpu()
torch.onnx.export(model, x, "/dev/null", verbose=True, opset_version=13)

If it's not correct, what if you replace the last line with?
Closing due to lack of replies. If this is still an issue, please open a new issue.
Right now, Upsample in ONNX supports only fixed scale factors, which are calculated during graph tracing from the specific input and output sizes.
At the same time, F.upsample allows specifying an exact output size, which is very convenient when evaluating models with backward connections (like a Feature Pyramid Network) on inputs of varying sizes.
Could we extend the ONNX Upsample operator to accept the desired output size (exactly as in F.upsample) as an additional input to the operator?
A simple example to check (this should export the graph with upsample taking its output size dynamically as h1.size()[-2:]) is the SimpleNet code shown above.
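To make the motivation concrete, here is a small illustrative sketch (the helper name upsample_to is mine, not from the thread) of the size-based call adapting to varying input resolutions, which a scale factor baked in at trace time cannot do:

```python
import torch
import torch.nn.functional as F

def upsample_to(x, ref):
    # Upsample x to the spatial size of a reference feature map,
    # as an FPN lateral connection does.
    return F.interpolate(x, size=ref.shape[-2:], mode='bilinear',
                         align_corners=False)

# The same call works for any input resolution; a scale factor computed
# during tracing on a 10x10 input would be wrong for 17x23.
for h, w in [(10, 10), (17, 23)]:
    ref = torch.randn(1, 8, h, w)
    small = torch.randn(1, 8, (h + 1) // 2, (w + 1) // 2)
    out = upsample_to(small, ref)
    assert out.shape == (1, 8, h, w)
```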
cc @BowenBao @neginraoof