Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Fix flatten operator #45632

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 0 additions & 2 deletions test/onnx/test_models.py
Expand Up @@ -105,7 +105,6 @@ def test_alexnet(self):
)
self.exportTest(toC(alexnet()), toC(x))

@disableScriptTest()
def test_mnist(self):
x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
self.exportTest(toC(MNIST()), toC(x))
Expand Down Expand Up @@ -158,7 +157,6 @@ def test_squeezenet(self):
sqnet_v1_1 = SqueezeNet(version=1.1)
self.exportTest(toC(sqnet_v1_1), toC(x))

@disableScriptTest()
def test_densenet(self):
# Densenet-121 model
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
Expand Down
13 changes: 6 additions & 7 deletions torch/onnx/symbolic_opset11.py
Expand Up @@ -725,18 +725,17 @@ def narrow(g, input, dim, start, length):
@parse_args('v', 'i', 'i')
def flatten(g, input, start_dim, end_dim):
dim = input.type().dim()
if dim is None:
return _unimplemented("dim",
"ONNX and PyTorch use different strategies to split the input. "
"Input rank must be known at export time.")

# use ONNX's Flatten operator for cases where the output shape is 2D
if start_dim == 1:
if (end_dim == -1 or end_dim == dim - 1):
if (end_dim == -1 or (dim is not None and end_dim == dim - 1)):
return g.op("Flatten", input, axis_i=start_dim)
elif start_dim == 0:
if (end_dim == -2 or end_dim == dim - 2):
if (end_dim == -2 or (dim is not None and end_dim == dim - 2)):
return g.op("Flatten", input, axis_i=end_dim + 1)
if dim is None:
return _unimplemented("dim",
"ONNX and PyTorch use different strategies to split the input. "
"Input rank must be known at export time.")
# if end_dim is negative add dim
if end_dim < 0 :
end_dim = dim + end_dim
Expand Down