Skip to content

Commit

Permalink
Slice op support apache#1297
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Jul 2, 2018
1 parent 9aabd79 commit 2ef13b7
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
44 changes: 43 additions & 1 deletion nnvm/python/nnvm/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,48 @@ def _impl_v1(cls, inputs, attr, params):
inputs[0] = _sym.expand_dims(inputs[0], axis=axes, num_newaxis=1)
return inputs[0]


class Slice(OnnxOpConverter):
""" Operator converter for Slice.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if isinstance(attr['starts'], int):
attr['starts'] = (attr['starts'],)
attr['ends'] = (attr['ends'],)

try:
# Update the starts and ends according to axes if required.
if isinstance(attr['axes'], int):
attr['axes'] = (attr['axes'],)

if (max(attr['axes']) + 1) != len(attr['axes']):
new_axes = []
new_starts = []
new_ends = []
pop_index = 0
for i in range(max(attr['axes']) + 1):
if i in attr['axes']:
new_axes.append(i)
new_starts.append(attr['starts'][pop_index])
new_ends.append(attr['ends'][pop_index])
pop_index += 1
else:
new_axes.append(i)
new_starts.append(0)
new_ends.append(10000) # very big number
attr['axes'] = new_axes
attr['starts'] = new_starts
attr['ends'] = new_ends
except KeyError:
pass

return AttrCvt(op_name='strided_slice',
transforms={'starts': 'begin',
'ends': 'end'},
ignores=['axes'])(inputs, attr)


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -550,7 +592,7 @@ def _get_convert_map(opset):
'Reshape': Reshape.get_converter(opset),
'Concat': Renamer('concatenate'),
'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
# 'Slice'
'Slice': Slice.get_converter(opset),
'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
# 'Gather'
'Squeeze': Renamer('squeeze'),
Expand Down
27 changes: 27 additions & 0 deletions nnvm/tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,32 @@ def test_unsqueeze():

np.testing.assert_allclose(out_shape, tvm_out.shape)

def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
if axes:
y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
else:
y = helper.make_node("Slice", ['in'], ['out'], starts=starts, ends=ends)

graph = helper.make_graph([y],
'slice_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])

model = helper.make_model(graph, producer_name='slice_test')

for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')

np.testing.assert_allclose(outdata, tvm_out)

def test_slice():
x = np.random.randn(20, 10, 5).astype(np.float32)
_test_slice_iteration(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))
#_test_slice_iteration(x, x[:, 1000:1000], (1000), (1000), (1))

if __name__ == '__main__':
# verify_super_resolution_example()
# verify_squeezenet1_1()
Expand All @@ -153,3 +179,4 @@ def test_unsqueeze():
test_reshape_like()
test_squeeze()
test_unsqueeze()
test_slice()

0 comments on commit 2ef13b7

Please sign in to comment.