diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index a2b452843..f74b2ed7f 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -43,7 +43,7 @@ OnnxDequantizeLinear, OnnxDet, OnnxDiv, OnnxDropout, OnnxDropout_7, - OnnxEinsum, OnnxEqual, OnnxErf, OnnxExp, OnnxEyeLike, + OnnxEinsum, OnnxEqual, OnnxErf, OnnxExp, OnnxExpand, OnnxEyeLike, OnnxFlatten, OnnxFloor, OnnxGreater, OnnxGreaterOrEqual, OnnxGemm, OnnxGlobalAveragePool, OnnxIdentity, OnnxIsNaN, @@ -2132,6 +2132,36 @@ def test_onnxt_runtime_einsum(self): validate_python_inference(oinfpy, {'X': X.astype(numpy.float32), 'Y': Y.astype(numpy.float32)}) + @ignore_warnings(category=(RuntimeWarning, DeprecationWarning)) + @wraplog() + def test_onnxt_runtime_expand(self): + sh = numpy.array([2, 2, 1], dtype=numpy.int64) + onx = OnnxExpand('X', 'sh', output_names=['Y'], + op_version=TARGET_OPSET) + X = numpy.array([[1, 2], [3, -4]], dtype=numpy.float32) + model_def = onx.to_onnx({'X': X.astype(numpy.float32), 'sh': sh}, + target_opset=TARGET_OPSET) + self._check_shape_inference(OnnxExpand, model_def) + oinf = OnnxInference(model_def) + got = oinf.run({'X': X.copy(), 'sh': sh}) + self.assertEqual(list(sorted(got)), ['Y']) + exp = X * numpy.ones(sh.tolist()) + self.assertEqualArray(exp, got['Y']) + + X = numpy.array([[1.], [2.], [3.]], dtype=numpy.float32) + sh = numpy.array([2, 1, 6], dtype=numpy.int64) + exp = X * numpy.ones(sh.tolist()) + got = oinf.run({'X': X.copy(), 'sh': sh}) + self.assertEqualArray(exp, got['Y']) + + X = numpy.array([[1.], [2.], [3.]], dtype=numpy.float32) + sh = numpy.array([3, 4], dtype=numpy.int64) + exp = numpy.tile(X, 4) + got = oinf.run({'X': X.copy(), 'sh': sh}) + self.assertEqualArray(exp, got['Y']) + + python_tested.append(OnnxExpand) + @wraplog() def test_onnxt_runtime_eyelike(self): onx = OnnxEyeLike('X', k=0, output_names=['Y']) diff --git a/mlprodict/onnxrt/ops_cpu/_op_list.py b/mlprodict/onnxrt/ops_cpu/_op_list.py index f8ce868fd..815ab9bff 100644 --- a/mlprodict/onnxrt/ops_cpu/_op_list.py +++ b/mlprodict/onnxrt/ops_cpu/_op_list.py @@ -49,6 +49,7 @@ from .op_equal import Equal from .op_erf import Erf from .op_exp import Exp +from .op_expand import Expand, Expand_13 from .op_eyelike import EyeLike from .op_feature_vectorizer import FeatureVectorizer from .op_fft import FFT diff --git a/mlprodict/onnxrt/ops_cpu/op_expand.py b/mlprodict/onnxrt/ops_cpu/op_expand.py new file mode 100644 index 000000000..510a09e5e --- /dev/null +++ b/mlprodict/onnxrt/ops_cpu/op_expand.py @@ -0,0 +1,45 @@ +# -*- encoding: utf-8 -*- +# pylint: disable=E0203,E1101,C0111 +""" +@file +@brief Runtime operator. +""" +import numpy +from ._op import OpRun +from ..shape_object import ShapeObject + + +def common_reference_implementation(data, shape): + ones = numpy.ones(shape) + return data * ones + + +class CommonExpand(OpRun): + + def __init__(self, onnx_node, desc=None, expected_attributes=None, **options): + OpRun.__init__( + self, onnx_node, desc=desc, + expected_attributes=expected_attributes, **options) + + def _run(self, data, shape): # pylint: disable=W0221 + return (common_reference_implementation(data, shape), ) + + def _infer_shapes(self, data, shape): # pylint: disable=W0221 + return (ShapeObject(None, dtype=data.dtype), ) + + def _infer_types(self, data, shape): # pylint: disable=W0221 + return (data, ) + + def _infer_sizes(self, *args, **kwargs): + res = self.run(*args, **kwargs) + return (dict(temp=0), ) + res + + +class Expand_13(CommonExpand): + + def __init__(self, onnx_node, desc=None, **options): + CommonExpand.__init__( + self, onnx_node, desc=desc, **options) + + +Expand = Expand_13