diff --git a/test/onnx/test_models.py b/test/onnx/test_models.py
index f91f6bea165b..0613c69c0867 100644
--- a/test/onnx/test_models.py
+++ b/test/onnx/test_models.py
@@ -105,7 +105,6 @@ def test_alexnet(self):
         )
         self.exportTest(toC(alexnet()), toC(x))
 
-    @disableScriptTest()
     def test_mnist(self):
         x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
         self.exportTest(toC(MNIST()), toC(x))
@@ -158,7 +157,6 @@ def test_squeezenet(self):
         sqnet_v1_1 = SqueezeNet(version=1.1)
         self.exportTest(toC(sqnet_v1_1), toC(x))
 
-    @disableScriptTest()
     def test_densenet(self):
         # Densenet-121 model
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index 83b0da0aef5d..595ebffb2901 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -725,18 +725,17 @@ def narrow(g, input, dim, start, length):
 
 @parse_args('v', 'i', 'i')
 def flatten(g, input, start_dim, end_dim):
     dim = input.type().dim()
-    if dim is None:
-        return _unimplemented("dim",
-                              "ONNX and PyTorch use different strategies to split the input. "
-                              "Input rank must be known at export time.")
-    # use ONNX's Flatten operator for cases where the output shape is 2D
     if start_dim == 1:
-        if (end_dim == -1 or end_dim == dim - 1):
+        if (end_dim == -1 or (dim is not None and end_dim == dim - 1)):
             return g.op("Flatten", input, axis_i=start_dim)
     elif start_dim == 0:
-        if (end_dim == -2 or end_dim == dim - 2):
+        if (end_dim == -2 or (dim is not None and end_dim == dim - 2)):
             return g.op("Flatten", input, axis_i=end_dim + 1)
+    if dim is None:
+        return _unimplemented("dim",
+                              "ONNX and PyTorch use different strategies to split the input. "
+                              "Input rank must be known at export time.")
     # if end_dim is negative add dim
     if end_dim < 0 :
         end_dim = dim + end_dim