Skip to content

Commit

Permalink
[ONNX] Fix flatten operator (#45632)
Browse files Browse the repository at this point in the history
Summary:
Even when the input rank (dim) is None, there are cases in which flatten can still be exported.
Also enables test_densenet in scripting mode.

Pull Request resolved: #45632

Reviewed By: VitalyFedyunin

Differential Revision: D24116994

Pulled By: bzinodev

fbshipit-source-id: 76da6c073ddf79bba64397fd56b592de850034c4
  • Loading branch information
KsenijaS authored and facebook-github-bot committed Oct 14, 2020
1 parent d655341 commit 6ca03ae
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
2 changes: 0 additions & 2 deletions test/onnx/test_models.py
Expand Up @@ -105,7 +105,6 @@ def test_alexnet(self):
)
self.exportTest(toC(alexnet()), toC(x))

@disableScriptTest()
def test_mnist(self):
x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
self.exportTest(toC(MNIST()), toC(x))
Expand Down Expand Up @@ -158,7 +157,6 @@ def test_squeezenet(self):
sqnet_v1_1 = SqueezeNet(version=1.1)
self.exportTest(toC(sqnet_v1_1), toC(x))

@disableScriptTest()
def test_densenet(self):
# Densenet-121 model
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
Expand Down
13 changes: 6 additions & 7 deletions torch/onnx/symbolic_opset11.py
Expand Up @@ -725,18 +725,17 @@ def narrow(g, input, dim, start, length):
@parse_args('v', 'i', 'i')
def flatten(g, input, start_dim, end_dim):
dim = input.type().dim()
if dim is None:
return _unimplemented("dim",
"ONNX and PyTorch use different strategies to split the input. "
"Input rank must be known at export time.")

# use ONNX's Flatten operator for cases where the output shape is 2D
if start_dim == 1:
if (end_dim == -1 or end_dim == dim - 1):
if (end_dim == -1 or (dim is not None and end_dim == dim - 1)):
return g.op("Flatten", input, axis_i=start_dim)
elif start_dim == 0:
if (end_dim == -2 or end_dim == dim - 2):
if (end_dim == -2 or (dim is not None and end_dim == dim - 2)):
return g.op("Flatten", input, axis_i=end_dim + 1)
if dim is None:
return _unimplemented("dim",
"ONNX and PyTorch use different strategies to split the input. "
"Input rank must be known at export time.")
# if end_dim is negative add dim
if end_dim < 0 :
end_dim = dim + end_dim
Expand Down

0 comments on commit 6ca03ae

Please sign in to comment.