Skip to content

Commit

Permalink
Add TopK to ONNX Frontend (apache#5441)
Browse files Browse the repository at this point in the history
* Add TopK to ONNX Frontend

* respond to review comments
  • Loading branch information
mbrookhart authored and Trevor Morris committed Jun 8, 2020
1 parent bbd2f04 commit 8b9691e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,22 @@ def _impl_v9(cls, inputs, attr, params):
output = AttrCvt(op_name='argwhere')(inputs, attr, params)
return _op.transpose(output, axes=(1, 0))

class TopK(OnnxOpConverter):
"""Operator converter for TopK
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if len(inputs) != 2:
raise ValueError("Expect 2 input only")
axis = attr.get("axis", -1)
largest = attr.get("largest", 1)

if largest == 0:
raise ValueError("TVM only supports finding TopK largest elements")

K = int(infer_value(inputs[1], params).asnumpy()[0])

return _op.topk(inputs[0], k=K, axis=axis)

# compatible operators that do NOT require any conversion.
_identity_list = []
Expand Down Expand Up @@ -1573,8 +1589,11 @@ def _get_convert_map(opset):
'ReduceProd': ReduceProd.get_converter(opset),
# 'ReduceProd'
# 'ReduceLogSumExp'

#defs/sorting
'ArgMax': ArgMax.get_converter(opset),
'ArgMin': ArgMin.get_converter(opset),
'TopK': TopK.get_converter(opset),

# defs/tensor
'Cast': Cast.get_converter(opset),
Expand Down
38 changes: 38 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,6 +2330,43 @@ def verify_nonzero(indata, outdata, dtype):
result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 2, 2], [0, 1, 0, 1]]
verify_nonzero(input_data, result, dtype=np.int64)

def test_topk():
def verify_topk(input_dims, K, axis=-1):
output_dims = list(input_dims)
output_dims[axis] = K

node = helper.make_node('TopK',
inputs=['X', 'K'],
outputs=['Values', 'Indicies'],
axis=axis)

graph = helper.make_graph([node],
"topk_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
helper.make_tensor_value_info("K", TensorProto.INT64, [1,])],
initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims),
helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)])

model = helper.make_model(graph, producer_name='topk_test')

indata = np.random.uniform(-10, 10, input_dims).astype(np.float32)
onnx_out = get_onnxruntime_output(model, [indata, k])

for target, ctx in [('llvm', tvm.cpu())]:
tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims],
output_dtype=['float32', 'int64'])
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)

for n in [12, 32]:
for shape in [[n], [n, n], [n, n, n]]:
for k in [1, 5, 10]:
verify_topk(shape, k)

verify_topk([n, n, n], 5, 0)
verify_topk([n, n, n], 5, 1)
verify_topk([n, n, n], 5, 2)


if __name__ == '__main__':
test_flatten()
Expand Down Expand Up @@ -2392,3 +2429,4 @@ def verify_nonzero(indata, outdata, dtype):
test_lstm()
test_resize()
test_nonzero()
test_topk()

0 comments on commit 8b9691e

Please sign in to comment.