Skip to content

Commit

Permalink
[ONNX] Fix indexing issue of meshgrid op (#109350)
Browse files Browse the repository at this point in the history
Should unpack tensor_list before swapping the elements for indexing 'xy'.

Pull Request resolved: #109350
Approved by: https://github.com/thiagocrepaldi
  • Loading branch information
CYuxian authored and pytorchmergebot committed Sep 15, 2023
1 parent 4c208c1 commit 504dcea
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 2 deletions.
322 changes: 322 additions & 0 deletions test/onnx/expect/TestOperators.test_meshgrid_indexing.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "CURRENT_VERSION"
graph {
node {
output: "onnx::Reshape_3"
name: "Constant_0"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\377\377\377\377\377\377\377\377"
}
type: TENSOR
}
}
node {
input: "onnx::Reshape_1"
input: "onnx::Reshape_3"
output: "onnx::Shape_4"
name: "Reshape_1"
op_type: "Reshape"
}
node {
output: "onnx::Reshape_5"
name: "Constant_2"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\377\377\377\377\377\377\377\377"
}
type: TENSOR
}
}
node {
input: "onnx::Reshape_0"
input: "onnx::Reshape_5"
output: "onnx::Shape_6"
name: "Reshape_3"
op_type: "Reshape"
}
node {
output: "onnx::Reshape_7"
name: "Constant_4"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\377\377\377\377\377\377\377\377"
}
type: TENSOR
}
}
node {
input: "onnx::Reshape_2"
input: "onnx::Reshape_7"
output: "onnx::Shape_8"
name: "Reshape_5"
op_type: "Reshape"
}
node {
input: "onnx::Shape_4"
output: "onnx::Concat_9"
name: "Shape_6"
op_type: "Shape"
}
node {
input: "onnx::Shape_6"
output: "onnx::Concat_10"
name: "Shape_7"
op_type: "Shape"
}
node {
input: "onnx::Shape_8"
output: "onnx::Concat_11"
name: "Shape_8"
op_type: "Shape"
}
node {
input: "onnx::Concat_9"
input: "onnx::Concat_10"
input: "onnx::Concat_11"
output: "onnx::Expand_12"
name: "Concat_9"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
output: "onnx::Concat_13"
name: "Constant_10"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "onnx::Concat_9"
input: "onnx::Concat_13"
input: "onnx::Concat_13"
output: "onnx::Reshape_14"
name: "Concat_11"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
input: "onnx::Shape_4"
input: "onnx::Reshape_14"
output: "onnx::Expand_15"
name: "Reshape_12"
op_type: "Reshape"
}
node {
input: "onnx::Expand_15"
input: "onnx::Expand_12"
output: "16"
name: "Expand_13"
op_type: "Expand"
}
node {
output: "onnx::Concat_17"
name: "Constant_14"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "onnx::Concat_17"
input: "onnx::Concat_10"
input: "onnx::Concat_17"
output: "onnx::Reshape_18"
name: "Concat_15"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
input: "onnx::Shape_6"
input: "onnx::Reshape_18"
output: "onnx::Expand_19"
name: "Reshape_16"
op_type: "Reshape"
}
node {
input: "onnx::Expand_19"
input: "onnx::Expand_12"
output: "20"
name: "Expand_17"
op_type: "Expand"
}
node {
output: "onnx::Concat_21"
name: "Constant_18"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "onnx::Concat_21"
input: "onnx::Concat_21"
input: "onnx::Concat_11"
output: "onnx::Reshape_22"
name: "Concat_19"
op_type: "Concat"
attribute {
name: "axis"
i: 0
type: INT
}
}
node {
input: "onnx::Shape_8"
input: "onnx::Reshape_22"
output: "onnx::Expand_23"
name: "Reshape_20"
op_type: "Reshape"
}
node {
input: "onnx::Expand_23"
input: "onnx::Expand_12"
output: "24"
name: "Expand_21"
op_type: "Expand"
}
name: "main_graph"
input {
name: "onnx::Reshape_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "onnx::Reshape_1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
}
}
}
}
input {
name: "onnx::Reshape_2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
output {
name: "20"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "16"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "24"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 9
}
10 changes: 10 additions & 0 deletions test/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,16 @@ def test_meshgrid(self):
z = torch.ones(5, requires_grad=True)
self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), (x, y, z))

def test_meshgrid_indexing(self):
x = torch.ones(3, requires_grad=True)
y = torch.zeros(4, requires_grad=True)
z = torch.ones(5, requires_grad=True)
self.assertONNX(
lambda x, y, z: torch.meshgrid(x, y, z, indexing="xy"),
(x, y, z),
opset_version=9,
)

def test_topk(self):
x = torch.arange(1.0, 6.0, requires_grad=True)
k = torch.tensor(3)
Expand Down
19 changes: 19 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7634,6 +7634,25 @@ def forward(self, x, y, z):
z = torch.randn(5, requires_grad=True)
self.run_test(Meshgrid(), (x, y, z))

@skipIfUnsupportedMinOpsetVersion(8)
def test_meshgrid_indexing(self):
class Meshgrid(torch.nn.Module):
def __init__(self, indexing):
super().__init__()
self.indexing = indexing

def forward(self, x, y, z):
output1, output2, output3 = torch.meshgrid(
x, y, z, indexing=self.indexing
)
return output1, output2, output3

x = torch.randn(5, requires_grad=True)
y = torch.zeros(6, requires_grad=True)
z = torch.randn(7, requires_grad=True)
for indexing in ("xy", "ij"):
self.run_test(Meshgrid(indexing), (x, y, z))

@skipIfUnsupportedMinOpsetVersion(8)
def test_meshgrid_scalar(self):
class Meshgrid(torch.nn.Module):
Expand Down
5 changes: 3 additions & 2 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -6152,13 +6152,14 @@ def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: Optional[str] = N
raise errors.SymbolicValueError(
f"Unsupported indexing: {indexing}", tensor_list
)
unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list)
if indexing == "xy":
tensor_list[0], tensor_list[1] = tensor_list[1], tensor_list[0]
unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1]
tensors = [
symbolic_helper._reshape_helper(
g, t, g.op("Constant", value_t=torch.LongTensor([-1]))
)
for t in symbolic_helper._unpack_list(tensor_list)
for t in unpacked_tensor_list
]
tensors_shape = [g.op("Shape", t) for t in tensors]
out_shape = g.op("Concat", *tensors_shape, axis_i=0)
Expand Down

0 comments on commit 504dcea

Please sign in to comment.