Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Enable remaining failed tests in opset13 #50806

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 2 additions & 17 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,6 @@ def get_test_images(self):

return [image], [image2]

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Faster RCNN model is not scriptable
def test_faster_rcnn(self):
Expand Down Expand Up @@ -487,7 +486,6 @@ def test_paste_mask_in_image(self):

assert torch.all(out2.eq(out_trace2))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_mask_rcnn(self):
Expand Down Expand Up @@ -532,7 +530,6 @@ def test_heatmaps_to_keypoints(self):
assert torch.all(out2[0].eq(out_trace2[0]))
assert torch.all(out2[1].eq(out_trace2[1]))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_keypoint_rcnn(self):
Expand All @@ -554,7 +551,6 @@ def test_keypoint_rcnn(self):
dynamic_axes={"images_tensors": [0, 1, 2]},
rtol=5e-3, atol=1e-5)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_shufflenet_v2_dynamic_axes(self):
Expand Down Expand Up @@ -1241,8 +1237,8 @@ def forward(self, x):
x = torch.randn(2, 3, 4)
self.run_test(FloatingPoint(), x)

@unittest.skip("If operator rank mismatch between outputs of two branches.")
@skipIfUnsupportedMinOpsetVersion(9)
# Operator rank mismatch between outputs of two branches for opsets below 11.
@skipIfUnsupportedMinOpsetVersion(11)
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
@skipIfONNXShapeInference(False)
def test_floating_point_infer_dtype(self):
class FloatingPoint(torch.jit.ScriptModule):
Expand Down Expand Up @@ -1795,7 +1791,6 @@ def forward(self, hidden_states):
dynamic_axes={'x': {0: 'seq_length', 1: 'batch_size'}}, test_with_inputs=[y])

@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([13])
def test_copy_(self):
class CopyModel(torch.nn.Module):
def forward(self, x, data):
Expand Down Expand Up @@ -3602,7 +3597,6 @@ def forward(self, x, y, z, ind):
ind = torch.tensor(-2, dtype=torch.long)
self.run_test(GetItemModel(), (x, y, z, ind))

@skipIfUnsupportedOpsetVersion([13])
@disableScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable.
@skipIfUnsupportedMinOpsetVersion(9)
def test_nonzero(self):
Expand Down Expand Up @@ -3834,7 +3828,6 @@ def forward(self, x):
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
self.run_test(model, inputs)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_loop_with_list(self):
class ListLoopModel(torch.jit.ScriptModule):
Expand Down Expand Up @@ -4382,7 +4375,6 @@ def forward(self, input):
x = torch.randint(10, (2, 3))
self.run_test(FModModel(), x)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
def test_glu(self):
class GluModel(torch.nn.Module):
Expand Down Expand Up @@ -5476,7 +5468,6 @@ def forward(self, cond, input, other):
z = torch.ones(2, 3, 1)
self.run_test(Model(), (x, y, z))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
def test_where_condition(self):
class Model1(torch.nn.Module):
Expand Down Expand Up @@ -6070,7 +6061,6 @@ def forward(self, x, y):
"ScriptModel - Initializers' sequence is not as same as named_parameters(). Expected: (" \
+ ', '.join(named_params_list) + "). Actual:(" + ', '.join(actual_list) + ")."

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_nms(self):
boxes = torch.rand(5, 4)
Expand Down Expand Up @@ -6100,15 +6090,13 @@ def forward(self, boxes, size):
dynamic_axes={"size": [0, 1]},
test_with_inputs=[(boxes, size), (boxes, size_2)])

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2)
self.run_test(model, (x, single_roi))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
Expand Down Expand Up @@ -6183,7 +6171,6 @@ def get_features(self, images):
features = OrderedDict(features)
return features

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_rpn(self):

Expand Down Expand Up @@ -6212,7 +6199,6 @@ def forward(self, images, features):
test_with_inputs=[(images, features), (images2, test_features)],
dict_check=False)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_multi_scale_roi_align(self):

Expand All @@ -6239,7 +6225,6 @@ def forward(self, input, boxes):

self.run_test(TransformModule(), (i, [boxes],), test_with_inputs=[(i, [boxes],), (i1, [boxes1],)])

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module):
Expand Down
4 changes: 3 additions & 1 deletion torch/onnx/symbolic_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,10 @@ def __interpolate_helper(g, input, size, scale_factor, mode, align_corners, reco
def _unbind_helper(g, self, dim, _outputs):
if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import unbind
else:
elif _export_onnx_opset_version <= 12:
from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
else:
from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef]
return unbind(g, self, dim, _outputs)


Expand Down
19 changes: 15 additions & 4 deletions torch/onnx/symbolic_opset13.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero


# EDITING THIS FILE? READ THIS FIRST!
Expand Down Expand Up @@ -106,9 +106,20 @@ def unbind(g, self, dim=0, _outputs=None):
return squeezed_outputs


def glu(g, input, dim):
first, second = g.op('Split', input, dim, outputs=2)
return g.op('Mul', first, g.op('Sigmoid', second))
# Emitted from `torch.nonzero(x, as_tuple=True)`
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
def nonzero_numpy(g, input, _outputs=None):
return unbind(g, nonzero(g, input), 1, _outputs=_outputs)


@parse_args('v', 'v', 'v', 'i')
def where(g, condition, self=None, other=None, _outputs=None):
# Assumes that torch.where's first argument takes only Bool and Byte tensors.
if condition.type().scalarType() != 'Bool':
condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx['Bool'])
if self is None:
condition = nonzero(g, condition)
return sym_help._unbind_helper(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs)
return g.op("Where", condition, self, other)


def _reduce_op_symbolic(onnx_op_name):
Expand Down