Skip to content

Commit

Permalink
[ONNX] Add index_put operator (apache#8894)
Browse files Browse the repository at this point in the history
* onnx:add index_put

* reformat code

* add parametrize_targets

* change slice to onnx_index instance

* modify test_forward
  • Loading branch information
liaojianjin authored and ylc committed Jan 13, 2022
1 parent 53549a0 commit 0710240
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
42 changes: 42 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,6 +3021,7 @@ def _op_dispatch(cls, operator, inputs, attr, params):
op_map = {
"size": cls._size,
"arange": cls._arange,
"index_put": cls._index_put,
"reshape": cls._reshape,
"embedding_bag": cls._embedding_bag,
}
Expand All @@ -3040,6 +3041,47 @@ def _size(cls, inputs, attr, params):
def _arange(cls, inputs, attr, params):
return _op.arange(inputs[0], inputs[1], inputs[2], dtype="int64")

@classmethod
def _check_index(cls, indices, values):
def unfolding_indices(indices, values):
n = len(indices)
flatten_indices = []
slices_size = []
for index in indices:
flatten_indices.append(_op.reshape(index, _op.const([-1])))
slices_size.append(infer_shape(flatten_indices[-1])[0])
repeat_size = [1]
tile_size = [1]
for i in range(1, n):
repeat_size.append(slices_size[-i] * repeat_size[-1])
tile_size.append(slices_size[i - 1] * tile_size[-1])
repeat_size.reverse()
unflod_slices = []
for i in range(n):
unflod_slices.append(
fold_constant(
_op.repeat(_op.tile(flatten_indices[i], (tile_size[i],)), repeat_size[i], 0)
)
)
return unflod_slices, _op.reshape(values, _op.const([-1]))

values_shape = infer_shape(values)
if len(values_shape) != 1:
return unfolding_indices(indices, values)
return indices, values

@classmethod
def _index_put(cls, inputs, attr, params):
in_tensor = inputs[0]
indices, values = cls._check_index(inputs[1 : len(inputs) - 2], inputs[len(inputs) - 2])
accumulate = inputs[len(inputs) - 1].data.asnumpy() != 0
if not accumulate:
mode = "update"
else:
mode = "add"
index_tensor = _op.stack(indices, axis=0)
return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)

@classmethod
def _reshape(cls, inputs, attr, params):
return _op.reshape(inputs[0], inputs[1])
Expand Down
76 changes: 76 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5040,6 +5040,81 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
verify_embedding_bag(32, 2, [3, 3])


@tvm.testing.parametrize_targets
def test_index_put(target, dev):
class _index_put_model(torch.nn.Module):
def __init__(self, indices, values, accumulate):
super(_index_put_model, self).__init__()
self.indices = indices
self.values = values
self.accumulate = accumulate

def forward(self, x):
return x.index_put(self.indices, self.values, self.accumulate)

def _convert_to_onnx(model, dummy_data):
file_name = "{}.onnx".format("aten_model")
torch.onnx.export(
model,
dummy_data,
file_name,
export_params=True,
verbose=False,
opset_version=11,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
onnx_model = onnx.load(file_name)
return onnx_model

def verify_index_put(data_shape, indices, accumulate):
dummy_data = torch.ones(data_shape)
tvm_inputs = [dummy_data.numpy()]
values = torch.rand(indices[0].size())
model = _index_put_model(indices, values, accumulate)
onnx_model = _convert_to_onnx(model, dummy_data)
torch_out = model(dummy_data)

tvm_out = get_tvm_output_with_vm(
onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True
)
tvm.testing.assert_allclose(torch_out.numpy(), tvm_out)

shape = (3, 5)
xidx = torch.tensor([0, 1, 2, 2])
yidx = torch.tensor([0, 1, 3, 4])
verify_index_put(shape, [xidx, yidx], True)

shape = (3, 5, 3)
xidx = torch.tensor([0, 1, 2, 2, 0])
yidx = torch.tensor([0, 1, 3, 4, 0])
zidx = torch.tensor([0, 1, 1, 2, 0])
verify_index_put(shape, [xidx, yidx, zidx], False)

def verify_index_put_slice(data_shape, value_shape, accumulate):
dummy_data = torch.ones(data_shape)
tvm_inputs = [dummy_data.numpy()]
indices = []
index_shape = [1] * len(value_shape)
index_shape[0] = -1
for i in range(len(value_shape)):
indices.append(torch.arange(0, value_shape[i]).reshape(tuple(index_shape)))
index_shape.pop()
values = torch.rand(value_shape)

model = _index_put_model(indices, values, accumulate)
onnx_model = _convert_to_onnx(model, dummy_data)
torch_out = model(dummy_data)

tvm_out = get_tvm_output_with_vm(
onnx_model, tvm_inputs, target, dev, freeze_params=True, convert_to_static=True
)
tvm.testing.assert_allclose(torch_out.numpy(), tvm_out)

verify_index_put_slice((3, 3), (2, 2), False)
verify_index_put_slice((2, 3, 4), (1, 2, 3), True)
verify_index_put_slice((2, 3, 4, 5), (1, 2, 3, 1), False)


@tvm.testing.parametrize_targets
def test_reverse_sequence(target, dev):
def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis):
Expand Down Expand Up @@ -5616,6 +5691,7 @@ def repeat(N, D):
test_cumsum()
test_wrong_input()
test_aten()
test_index_put()
test_reverse_sequence()
test_eyelike()
test_qlinearconv()
Expand Down

0 comments on commit 0710240

Please sign in to comment.