Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Update squeeze test for opset 9 #45369

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,14 +713,19 @@ def forward(self, x):
return torch.squeeze(x)

x2 = [] if x2 is None else [x2]
self.run_test(Squeeze(d), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)
if len(x2) > 0:
self.run_test(Squeeze(d), x1,
input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}},
test_with_inputs=x2)
else:
self.run_test(Squeeze(d), x1)

def test_squeeze_without_no_op(self):
x = torch.randn(2, 1, 4)
self.squeeze_model_tests(1, x, None)

@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze(self):
def test_squeeze_dynamic(self):
x_squeeze = torch.randn(2, 1, 4)
x_noop = torch.randn(2, 2, 3)
self.squeeze_model_tests(1, x_squeeze, x_noop)
Expand All @@ -746,10 +751,6 @@ def test_squeeze_no_op(self):
x_squeeze = torch.randn(2, 2, 1)
self.squeeze_model_tests(2, x_noop, x_squeeze)

def test_squeeze_no_op_without_additional_inputs(self):
x_noop = torch.randn(2, 1, 4)
self.squeeze_model_tests(2, x_noop, None)

@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_runtime_dim(self):
class Squeeze(torch.nn.Module):
Expand Down
25 changes: 24 additions & 1 deletion torch/csrc/jit/passes/onnx/shape_type_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,34 @@ Node* CloneNodeToGraph(Node* n, std::shared_ptr<Graph> n_graph) {
// prim::ListConstruct is converted to onnx::Concat. The conversion should
// eventually be moved to symbolic. For now, treat this operator as
// special case, and change from list type to tensor type. The scalar type
// is preserved.
// is preserved. If the elemtype is Int, insert a onnx::Concat node into
// the graph.
TypePtr elem = v->type()->cast<ListType>()->getElementType();
c10::optional<at::ScalarType> scalar_type = c10::nullopt;
if (elem->cast<IntType>()) {
scalar_type = at::kLong;

auto lc_node = v->node();
// ListConstruct Int[] output case, we need to transform to ONNX
// Concat to ensure the output is a single tensor(dynamic) type in
// order to be consumed as inputs
std::vector<Value*> unsqueezed;
for (auto* input : lc_node->inputs()) {
Node* unsqueezed_node =
n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1));
auto new_input = n_graph->addInput();
new_input->copyMetadata(input);
unsqueezed_node->addInput(new_input);
unsqueezed_node->is_(attr::axes, {0});
unsqueezed.emplace_back(unsqueezed_node->output());
}
Node* concat_node =
n_graph->insertNode(n_graph->create(::c10::onnx::Concat, 1));
concat_node->i_(attr::axis, 0);
for (auto v : unsqueezed) {
concat_node->addInput(v);
}
return concat_node->output();
} else if (elem->cast<FloatType>()) {
scalar_type = at::kFloat;
} else if (elem->cast<BoolType>()) {
Expand Down
47 changes: 32 additions & 15 deletions torch/onnx/symbolic_opset11.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,21 +538,38 @@ def squeeze(g, self, dim=None):
return g.op("Squeeze", self)

dim = sym_help._get_const(dim, 'i', 'dim')
# create 'cond' node (condition is shape[i]==1)
dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
size = sym_help._size_helper(g, self, dim_constant)
const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
cond = g.op("Equal", size, const_one)
# create the 'If' node and add the 'then' and 'else' blocks to it.
if_node_outputs = g.op("If", cond)
if_node = if_node_outputs.node()
if_block = torch.onnx.utils._add_block(if_node)
squeeze_ = if_block.op("Squeeze", self, axes_i=[dim])
torch.onnx.utils._add_output_to_block(if_block, squeeze_)
else_block = torch.onnx.utils._add_block(if_node)
identity_ = else_block.op("Identity", self)
torch.onnx.utils._add_output_to_block(else_block, identity_)
return if_node_outputs

input_shape = self.type().sizes()
from torch.onnx.symbolic_helper import _onnx_shape_inference
if input_shape is None or not _onnx_shape_inference:
# If onnx shape inference is not on, export always as dynamic.
# Because we cannot tell if observed static shape is also static at runtime.
# create 'cond' node (condition is shape[i]==1)
dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
size = sym_help._size_helper(g, self, dim_constant)
const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
cond = g.op("Equal", size, const_one)
# create the 'If' node and add the 'then' and 'else' blocks to it.
if_node_outputs = g.op("If", cond)
if_node = if_node_outputs.node()
if_block = torch.onnx.utils._add_block(if_node)
squeeze_ = if_block.op("Squeeze", self, axes_i=[dim])
torch.onnx.utils._add_output_to_block(if_block, squeeze_)
else_block = torch.onnx.utils._add_block(if_node)
identity_ = else_block.op("Identity", self)
torch.onnx.utils._add_output_to_block(else_block, identity_)
return if_node_outputs

# For static input shape
if dim < 0:
dim += self.type().dim()
if input_shape[dim] > 1:
warnings.warn("This model contains a squeeze operation on dimension " + str(dim) + ". The size of " +
"this dimension in the given input is " + str(input_shape[dim]) + ". The model will " +
"be exported without the squeeze node. If the model is intended to be used with dynamic " +
"input shapes, please export with dynamic_axes argument.")
return self
return g.op("Squeeze", self, axes_i=[dim])


@parse_args('v', 'i')
Expand Down