From ad5e99233f0ed5b289ca736bf303238cbbbf81a9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Mar 2022 20:35:51 +0100 Subject: [PATCH 1/3] Adds support for Expand in python runtime --- .../ut_onnxrt/test_onnxrt_python_runtime_.py | 32 ++++++++++++- mlprodict/onnxrt/ops_cpu/_op_list.py | 1 + mlprodict/onnxrt/ops_cpu/op_expand.py | 46 +++++++++++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 mlprodict/onnxrt/ops_cpu/op_expand.py diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index a2b452843..702be7ecf 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -43,7 +43,7 @@ OnnxDequantizeLinear, OnnxDet, OnnxDiv, OnnxDropout, OnnxDropout_7, - OnnxEinsum, OnnxEqual, OnnxErf, OnnxExp, OnnxEyeLike, + OnnxEinsum, OnnxEqual, OnnxErf, OnnxExp, OnnxExpand, OnnxEyeLike, OnnxFlatten, OnnxFloor, OnnxGreater, OnnxGreaterOrEqual, OnnxGemm, OnnxGlobalAveragePool, OnnxIdentity, OnnxIsNaN, @@ -2132,6 +2132,36 @@ def test_onnxt_runtime_einsum(self): validate_python_inference(oinfpy, {'X': X.astype(numpy.float32), 'Y': Y.astype(numpy.float32)}) + @ignore_warnings(category=(RuntimeWarning, DeprecationWarning)) + @wraplog() + def test_onnxt_runtime_expand(self): + sh = numpy.array([2, 2, 1], dtype=numpy.int64) + onx = OnnxExpand('X', 'sh', output_names=['Y'], + op_version=TARGET_OPSET) + X = numpy.array([[1, 2], [3, -4]], dtype=numpy.float32) + model_def = onx.to_onnx({'X': X.astype(numpy.float32), 'sh': sh}, + target_opset=TARGET_OPSET) + self._check_shape_inference(OnnxExpand, model_def) + oinf = OnnxInference(model_def) + got = oinf.run({'X': X.copy(), 'sh': sh}) + self.assertEqual(list(sorted(got)), ['Y']) + exp = X * numpy.ones(sh.tolist()) + self.assertEqualArray(exp, got['Y']) + + X = numpy.array([[1.], [2.], [3.]], dtype=numpy.float32) + sh = numpy.array([2, 1, 6], dtype=numpy.int64) + exp = X * numpy.ones(sh.tolist()) + got = oinf.run({'X': X.copy(), 'sh': sh}) + self.assertEqualArray(exp, got['Y']) + + X = numpy.array([[1.], [2.], [3.]], dtype=numpy.float32) + sh = numpy.array([3, 4], dtype=numpy.int64) + exp = numpy.tile(X, 4) + got = oinf.run({'X': X.copy(), 'sh': sh}) + self.assertEqualArray(exp, got['Y']) + + python_tested.append(OnnxExpand) + @wraplog() def test_onnxt_runtime_eyelike(self): onx = OnnxEyeLike('X', k=0, output_names=['Y']) diff --git a/mlprodict/onnxrt/ops_cpu/_op_list.py b/mlprodict/onnxrt/ops_cpu/_op_list.py index f8ce868fd..815ab9bff 100644 --- a/mlprodict/onnxrt/ops_cpu/_op_list.py +++ b/mlprodict/onnxrt/ops_cpu/_op_list.py @@ -49,6 +49,7 @@ from .op_equal import Equal from .op_erf import Erf from .op_exp import Exp +from .op_expand import Expand, Expand_13 from .op_eyelike import EyeLike from .op_feature_vectorizer import FeatureVectorizer from .op_fft import FFT diff --git a/mlprodict/onnxrt/ops_cpu/op_expand.py b/mlprodict/onnxrt/ops_cpu/op_expand.py new file mode 100644 index 000000000..5312dfb05 --- /dev/null +++ b/mlprodict/onnxrt/ops_cpu/op_expand.py @@ -0,0 +1,46 @@ +# -*- encoding: utf-8 -*- +# pylint: disable=E0203,E1101,C0111 +""" +@file +@brief Runtime operator. +""" +import numpy +from onnx.defs import onnx_opset_version +from ._op import OpRun +from ..shape_object import ShapeObject + + +def common_reference_implementation(data, shape): + ones = numpy.ones(shape) + return data * ones + + +class CommonExpand(OpRun): + + def __init__(self, onnx_node, desc=None, expected_attributes=None, **options): + OpRun.__init__( + self, onnx_node, desc=desc, + expected_attributes=expected_attributes, **options) + + def _run(self, data, shape): # pylint: disable=W0221 + return (common_reference_implementation(data, shape), ) + + def _infer_shapes(self, data, shape): # pylint: disable=W0221 + return (ShapeObject(None, dtype=data.dtype), ) + + def _infer_types(self, data, shape): # pylint: disable=W0221 + return (data, ) + + def _infer_sizes(self, *args, **kwargs): + res = self.run(*args, **kwargs) + return (dict(temp=0), ) + res + + +class Expand_13(CommonExpand): + + def __init__(self, onnx_node, desc=None, **options): + CommonExpand.__init__( + self, onnx_node, desc=desc, **options) + + +Expand = Expand_13 From 847b494c066028010caba03f72b99e9e3b28d156 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 4 Mar 2022 23:36:59 +0100 Subject: [PATCH 2/3] lint --- _unittests/ut_onnxrt/test_onnxrt_python_runtime_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index 702be7ecf..f74b2ed7f 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -2137,7 +2137,7 @@ def test_onnxt_runtime_einsum(self): def test_onnxt_runtime_expand(self): sh = numpy.array([2, 2, 1], dtype=numpy.int64) onx = OnnxExpand('X', 'sh', output_names=['Y'], - op_version=TARGET_OPSET) + op_version=TARGET_OPSET) X = numpy.array([[1, 2], [3, -4]], dtype=numpy.float32) model_def = onx.to_onnx({'X': X.astype(numpy.float32), 'sh': sh}, target_opset=TARGET_OPSET) From 0ffe66c6ff8df0c36c3bb3c61802dd3e37cabda4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 4 Mar 2022 23:41:16 +0100 Subject: [PATCH 3/3] lint --- mlprodict/onnxrt/ops_cpu/op_expand.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlprodict/onnxrt/ops_cpu/op_expand.py b/mlprodict/onnxrt/ops_cpu/op_expand.py index 5312dfb05..510a09e5e 100644 --- a/mlprodict/onnxrt/ops_cpu/op_expand.py +++ b/mlprodict/onnxrt/ops_cpu/op_expand.py @@ -5,7 +5,6 @@ @brief Runtime operator. """ import numpy -from onnx.defs import onnx_opset_version from ._op import OpRun from ..shape_object import ShapeObject