Skip to content

Commit

Permalink
[ONNX] Fix export of copy_ operator (#51938)
Browse files Browse the repository at this point in the history
copy_operator before going into onnx exporter is being decomposed into aten::expand_as and aten::index_put.
There is a scenario where inputs to copy are not of the same type, but copy op in torch does implicit casting that is not currently reflected inside onnx exporter. This PR is adding casting inside index_put symbolic in case when tensor self is not of the same type as values.
  • Loading branch information
KsenijaS committed Mar 29, 2021
1 parent c699a7e commit 78042da
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
11 changes: 11 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Expand Up @@ -1936,6 +1936,17 @@ def forward(self, x, update):
update = torch.arange(3 * 5).to(torch.float).view(3, 5)
self.run_test(IndexPutModel8(), (x, update))

class IndexPutModel9(torch.nn.Module):
def forward(self, poses):
w = 32
x = poses[:, :, 0] - (w - 1) // 2
boxes = torch.zeros([poses.shape[0], 17, 4])
boxes[:, :, 0] = x
return boxes

x = torch.zeros([2, 17, 3], dtype=torch.int64)
self.run_test(IndexPutModel9(), (x,))

@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Ellipses followed by tensor indexing not scriptable
def test_index_put_ellipsis(self):
Expand Down
9 changes: 6 additions & 3 deletions torch/onnx/symbolic_opset11.py
Expand Up @@ -142,10 +142,13 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
values = expand(g, values, values_shape, None)
values = g.op("Reshape", values, values_shape)

dtype = self.type().scalarType()
if dtype is not None and dtype != values.type().scalarType():
values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
dtype = sym_help.scalar_type_to_pytorch_type[dtype]

if accumulate:
dtype = self.type().scalarType()
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
dtype = sym_help.scalar_type_to_pytorch_type[dtype]
zeros = g.op("ConstantOfShape", g.op("Shape", self), value_t=torch.tensor([0], dtype=dtype))
result = g.op("ScatterND", zeros, index, values)
result = add(g, self, result)
Expand Down

0 comments on commit 78042da

Please sign in to comment.