[ONNX] Update squeeze test for opset 9 (#45369)
Summary:
Opset 9 supports a no-op squeeze (squeeze on a dim whose size is not 1) only under static axes.
Update the test case that was setting dynamic axes.
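For context, a minimal repro sketch (not part of this commit) of the behavior the summary describes: at opset 9, a squeeze that is a no-op because the targeted dim has size != 1 exports fine with static shapes, while the dynamic-axes variant is the unsupported configuration the updated test now avoids. Model and file names are illustrative.

import torch

class NoOpSqueeze(torch.nn.Module):
    def forward(self, x):
        # dim 1 has size 2, so squeeze(x, 1) removes nothing (a no-op).
        return torch.squeeze(x, 1)

x = torch.randn(2, 2, 3)

# Static axes: opset 9 can export the no-op squeeze.
torch.onnx.export(NoOpSqueeze(), x, "squeeze_static.onnx", opset_version=9)

# Dynamic axes: dim 1 is no longer known to stay != 1 at runtime, so this
# is the configuration opset 9 does not support for a no-op squeeze.
# torch.onnx.export(NoOpSqueeze(), x, "squeeze_dynamic.onnx", opset_version=9,
#                   input_names=["input"],
#                   dynamic_axes={"input": {0: "0", 1: "1", 2: "2"}})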

Pull Request resolved: #45369

Reviewed By: anjali411

Differential Revision: D24280180

Pulled By: bzinodev

fbshipit-source-id: d7cda88ab338a1c41a68052831dcebe739a3843c
BowenBao authored and facebook-github-bot committed Oct 14, 2020
1 parent 6ca03ae commit b28b5d3
Showing 3 changed files with 63 additions and 22 deletions.
13 changes: 7 additions & 6 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -713,14 +713,19 @@ def forward(self, x):
                return torch.squeeze(x)

        x2 = [] if x2 is None else [x2]
-       self.run_test(Squeeze(d), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)
+       if len(x2) > 0:
+           self.run_test(Squeeze(d), x1,
+                         input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}},
+                         test_with_inputs=x2)
+       else:
+           self.run_test(Squeeze(d), x1)

    def test_squeeze_without_no_op(self):
        x = torch.randn(2, 1, 4)
        self.squeeze_model_tests(1, x, None)

    @skipIfUnsupportedMinOpsetVersion(11)
-   def test_squeeze(self):
+   def test_squeeze_dynamic(self):
        x_squeeze = torch.randn(2, 1, 4)
        x_noop = torch.randn(2, 2, 3)
        self.squeeze_model_tests(1, x_squeeze, x_noop)
@@ -746,10 +751,6 @@ def test_squeeze_no_op(self):
        x_squeeze = torch.randn(2, 2, 1)
        self.squeeze_model_tests(2, x_noop, x_squeeze)

-   def test_squeeze_no_op_without_additional_inputs(self):
-       x_noop = torch.randn(2, 1, 4)
-       self.squeeze_model_tests(2, x_noop, None)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_runtime_dim(self):
        class Squeeze(torch.nn.Module):
25 changes: 24 additions & 1 deletion torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -200,11 +200,34 @@ Node* CloneNodeToGraph(Node* n, std::shared_ptr<Graph> n_graph) {
    // prim::ListConstruct is converted to onnx::Concat. The conversion should
    // eventually be moved to symbolic. For now, treat this operator as a
    // special case, and change from list type to tensor type. The scalar type
-   // is preserved.
+   // is preserved. If the element type is Int, insert an onnx::Concat node
+   // into the graph.
    TypePtr elem = v->type()->cast<ListType>()->getElementType();
    c10::optional<at::ScalarType> scalar_type = c10::nullopt;
    if (elem->cast<IntType>()) {
      scalar_type = at::kLong;
+
+     auto lc_node = v->node();
+     // In the ListConstruct Int[] output case, transform to onnx::Concat so
+     // that the output is a single (dynamic) tensor type that downstream
+     // nodes can consume as input.
+     std::vector<Value*> unsqueezed;
+     for (auto* input : lc_node->inputs()) {
+       Node* unsqueezed_node =
+           n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1));
+       auto new_input = n_graph->addInput();
+       new_input->copyMetadata(input);
+       unsqueezed_node->addInput(new_input);
+       unsqueezed_node->is_(attr::axes, {0});
+       unsqueezed.emplace_back(unsqueezed_node->output());
+     }
+     Node* concat_node =
+         n_graph->insertNode(n_graph->create(::c10::onnx::Concat, 1));
+     concat_node->i_(attr::axis, 0);
+     for (auto v : unsqueezed) {
+       concat_node->addInput(v);
+     }
+     return concat_node->output();
    } else if (elem->cast<FloatType>()) {
      scalar_type = at::kFloat;
    } else if (elem->cast<BoolType>()) {
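For intuition, a minimal sketch (not part of this commit) of the ONNX pattern the pass above emits for an Int[] ListConstruct with two scalar elements: each element is lifted to a rank-1 tensor with Unsqueeze(axes=[0]), and the results are fused into a single dynamic tensor with Concat(axis=0). Built with the standard onnx.helper API; all node and tensor names are illustrative, and note that Unsqueeze takes axes as an attribute only up to opset 12.

import onnx
from onnx import helper, TensorProto

# Lift each scalar list element to a 1-D tensor.
u0 = helper.make_node("Unsqueeze", ["e0"], ["u0"], axes=[0])
u1 = helper.make_node("Unsqueeze", ["e1"], ["u1"], axes=[0])
# Fuse the elements into one tensor along axis 0.
cat = helper.make_node("Concat", ["u0", "u1"], ["out"], axis=0)

graph = helper.make_graph(
    [u0, u1, cat], "listconstruct_as_concat",
    inputs=[helper.make_tensor_value_info("e0", TensorProto.INT64, []),
            helper.make_tensor_value_info("e1", TensorProto.INT64, [])],
    outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, [2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
onnx.checker.check_model(model)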
47 changes: 32 additions & 15 deletions torch/onnx/symbolic_opset11.py
@@ -538,21 +538,38 @@ def squeeze(g, self, dim=None):
        return g.op("Squeeze", self)

    dim = sym_help._get_const(dim, 'i', 'dim')
-   # create 'cond' node (condition is shape[i] == 1)
-   dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
-   size = sym_help._size_helper(g, self, dim_constant)
-   const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
-   cond = g.op("Equal", size, const_one)
-   # create the 'If' node and add the 'then' and 'else' blocks to it.
-   if_node_outputs = g.op("If", cond)
-   if_node = if_node_outputs.node()
-   if_block = torch.onnx.utils._add_block(if_node)
-   squeeze_ = if_block.op("Squeeze", self, axes_i=[dim])
-   torch.onnx.utils._add_output_to_block(if_block, squeeze_)
-   else_block = torch.onnx.utils._add_block(if_node)
-   identity_ = else_block.op("Identity", self)
-   torch.onnx.utils._add_output_to_block(else_block, identity_)
-   return if_node_outputs

+   input_shape = self.type().sizes()
+   from torch.onnx.symbolic_helper import _onnx_shape_inference
+   if input_shape is None or not _onnx_shape_inference:
+       # If ONNX shape inference is off, always export as dynamic, because
+       # we cannot tell whether the observed static shape is also static at
+       # runtime.
+       # Create the 'cond' node (condition is shape[i] == 1).
+       dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
+       size = sym_help._size_helper(g, self, dim_constant)
+       const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
+       cond = g.op("Equal", size, const_one)
+       # Create the 'If' node and add the 'then' and 'else' blocks to it.
+       if_node_outputs = g.op("If", cond)
+       if_node = if_node_outputs.node()
+       if_block = torch.onnx.utils._add_block(if_node)
+       squeeze_ = if_block.op("Squeeze", self, axes_i=[dim])
+       torch.onnx.utils._add_output_to_block(if_block, squeeze_)
+       else_block = torch.onnx.utils._add_block(if_node)
+       identity_ = else_block.op("Identity", self)
+       torch.onnx.utils._add_output_to_block(else_block, identity_)
+       return if_node_outputs
+
+   # For static input shapes:
+   if dim < 0:
+       dim += self.type().dim()
+   if input_shape[dim] > 1:
+       warnings.warn("This model contains a squeeze operation on dimension " + str(dim) + ". The size of " +
+                     "this dimension in the given input is " + str(input_shape[dim]) + ". The model will " +
+                     "be exported without the squeeze node. If the model is intended to be used with dynamic " +
+                     "input shapes, please export with the dynamic_axes argument.")
+       return self
+   return g.op("Squeeze", self, axes_i=[dim])
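A usage sketch (not part of this commit) of how the two branches above surface at export time, assuming ONNX shape inference is enabled: a static export emits a plain Squeeze node, while a dynamic_axes export falls back to the If(size == 1) -> Squeeze/Identity pattern built above. File and axis names are illustrative.

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.squeeze(x, 1)

x = torch.randn(2, 1, 4)

# Static shapes: dim 1 is known to be 1, so a plain Squeeze is emitted.
torch.onnx.export(M(), x, "squeeze_static.onnx", opset_version=11)

# Dynamic shapes: dim 1 may differ at runtime, so the exporter builds the
# conditional If/Squeeze/Identity subgraph constructed above.
torch.onnx.export(M(), x, "squeeze_dynamic.onnx", opset_version=11,
                  input_names=["input"],
                  dynamic_axes={"input": {0: "b", 1: "d1", 2: "d2"}})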


@parse_args('v', 'i')
