diff --git a/test/onnx/expect/TestOperators.test_meshgrid_indexing.expect b/test/onnx/expect/TestOperators.test_meshgrid_indexing.expect new file mode 100644 index 000000000000..b39e370cac65 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_meshgrid_indexing.expect @@ -0,0 +1,322 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + output: "onnx::Reshape_3" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\377\377\377\377\377\377\377\377" + } + type: TENSOR + } + } + node { + input: "onnx::Reshape_1" + input: "onnx::Reshape_3" + output: "onnx::Shape_4" + name: "Reshape_1" + op_type: "Reshape" + } + node { + output: "onnx::Reshape_5" + name: "Constant_2" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\377\377\377\377\377\377\377\377" + } + type: TENSOR + } + } + node { + input: "onnx::Reshape_0" + input: "onnx::Reshape_5" + output: "onnx::Shape_6" + name: "Reshape_3" + op_type: "Reshape" + } + node { + output: "onnx::Reshape_7" + name: "Constant_4" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\377\377\377\377\377\377\377\377" + } + type: TENSOR + } + } + node { + input: "onnx::Reshape_2" + input: "onnx::Reshape_7" + output: "onnx::Shape_8" + name: "Reshape_5" + op_type: "Reshape" + } + node { + input: "onnx::Shape_4" + output: "onnx::Concat_9" + name: "Shape_6" + op_type: "Shape" + } + node { + input: "onnx::Shape_6" + output: "onnx::Concat_10" + name: "Shape_7" + op_type: "Shape" + } + node { + input: "onnx::Shape_8" + output: "onnx::Concat_11" + name: "Shape_8" + op_type: "Shape" + } + node { + input: "onnx::Concat_9" + input: "onnx::Concat_10" + input: "onnx::Concat_11" + output: "onnx::Expand_12" + name: "Concat_9" + op_type: "Concat" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + output: "onnx::Concat_13" + name: "Constant_10" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "onnx::Concat_9" + input: "onnx::Concat_13" + input: "onnx::Concat_13" + output: "onnx::Reshape_14" + name: "Concat_11" + op_type: "Concat" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "onnx::Shape_4" + input: "onnx::Reshape_14" + output: "onnx::Expand_15" + name: "Reshape_12" + op_type: "Reshape" + } + node { + input: "onnx::Expand_15" + input: "onnx::Expand_12" + output: "16" + name: "Expand_13" + op_type: "Expand" + } + node { + output: "onnx::Concat_17" + name: "Constant_14" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "onnx::Concat_17" + input: "onnx::Concat_10" + input: "onnx::Concat_17" + output: "onnx::Reshape_18" + name: "Concat_15" + op_type: "Concat" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "onnx::Shape_6" + input: "onnx::Reshape_18" + output: "onnx::Expand_19" + name: "Reshape_16" + op_type: "Reshape" + } + node { + input: "onnx::Expand_19" + input: "onnx::Expand_12" + output: "20" + name: "Expand_17" + op_type: "Expand" + } + node { + output: "onnx::Concat_21" + name: "Constant_18" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "onnx::Concat_21" + input: "onnx::Concat_21" + input: "onnx::Concat_11" + output: "onnx::Reshape_22" + name: "Concat_19" + op_type: "Concat" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "onnx::Shape_8" + input: "onnx::Reshape_22" + output: "onnx::Expand_23" + name: "Reshape_20" + op_type: "Reshape" + } + node { + input: "onnx::Expand_23" + input: "onnx::Expand_12" + output: "24" + name: "Expand_21" + op_type: "Expand" + } + name: "main_graph" + input { + name: "onnx::Reshape_0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "onnx::Reshape_1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "onnx::Reshape_2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "20" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + dim { + dim_value: 3 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "16" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + dim { + dim_value: 3 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "24" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + dim { + dim_value: 3 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index dc9a53a2c91f..c0861a9c32fc 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -993,6 +993,16 @@ def test_meshgrid(self): z = torch.ones(5, requires_grad=True) self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), (x, y, z)) + def test_meshgrid_indexing(self): + x = torch.ones(3, requires_grad=True) + y = torch.zeros(4, requires_grad=True) + z = torch.ones(5, requires_grad=True) + self.assertONNX( + lambda x, y, z: torch.meshgrid(x, y, z, indexing="xy"), + (x, y, z), + opset_version=9, + ) + def test_topk(self): x = torch.arange(1.0, 6.0, requires_grad=True) k = torch.tensor(3) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 20d3dd416310..8373af9e62b3 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -7634,6 +7634,25 @@ def forward(self, x, y, z): z = torch.randn(5, requires_grad=True) self.run_test(Meshgrid(), (x, y, z)) + @skipIfUnsupportedMinOpsetVersion(8) + def test_meshgrid_indexing(self): + class Meshgrid(torch.nn.Module): + def __init__(self, indexing): + super().__init__() + self.indexing = indexing + + def forward(self, x, y, z): + output1, output2, output3 = torch.meshgrid( + x, y, z, indexing=self.indexing + ) + return output1, output2, output3 + + x = torch.randn(5, requires_grad=True) + y = torch.zeros(6, requires_grad=True) + z = torch.randn(7, requires_grad=True) + for indexing in ("xy", "ij"): + self.run_test(Meshgrid(indexing), (x, y, z)) + @skipIfUnsupportedMinOpsetVersion(8) def test_meshgrid_scalar(self): class Meshgrid(torch.nn.Module): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index c89a66ca5387..5a0eef0971e6 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -6152,13 +6152,14 @@ def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: Optional[str] = N raise errors.SymbolicValueError( f"Unsupported indexing: {indexing}", tensor_list ) + unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) if indexing == "xy": - tensor_list[0], tensor_list[1] = tensor_list[1], tensor_list[0] + unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] tensors = [ symbolic_helper._reshape_helper( g, t, g.op("Constant", value_t=torch.LongTensor([-1])) ) - for t in symbolic_helper._unpack_list(tensor_list) + for t in unpacked_tensor_list ] tensors_shape = [g.op("Shape", t) for t in tensors] out_shape = g.op("Concat", *tensors_shape, axis_i=0)