diff --git a/_doc/sphinxdoc/source/api/onnxrt.rst b/_doc/sphinxdoc/source/api/onnxrt.rst index 9297abbb8..d76aa1a66 100644 --- a/_doc/sphinxdoc/source/api/onnxrt.rst +++ b/_doc/sphinxdoc/source/api/onnxrt.rst @@ -18,7 +18,10 @@ implementated in :epkg:`Python`. The :epkg:`ONNX` model relies on the following operators :ref:`l-onnx-runtime-operators`. .. autosignature:: mlprodict.onnxrt.onnx_inference.OnnxInference - :members: + :members: run, shape_inference, check_model, run2onnx, get_profiling + +.. autosignature:: mlprodict.onnxrt.onnx_micro_inference.OnnxMicroRuntime + :members: run Python to ONNX ++++++++++++++ diff --git a/_doc/sphinxdoc/source/api/tools.rst b/_doc/sphinxdoc/source/api/tools.rst index ac51a43a6..2fa3d5749 100644 --- a/_doc/sphinxdoc/source/api/tools.rst +++ b/_doc/sphinxdoc/source/api/tools.rst @@ -83,13 +83,6 @@ Serialization .. autosignature:: mlprodict.onnx_tools.onnx2py_helper.to_bytes -Runtime -======= - -.. autosignature:: mlprodict.onnxrt.onnx_inference.OnnxInference - -.. autosignature:: mlprodict.tools.onnx_micro_runtime.OnnxMicroRuntime - Validation ++++++++++ diff --git a/_unittests/ut_tools/test_onnx_micro_runtime.py b/_unittests/ut_onnxrt/test_onnx_micro_runtime.py similarity index 96% rename from _unittests/ut_tools/test_onnx_micro_runtime.py rename to _unittests/ut_onnxrt/test_onnx_micro_runtime.py index 8fa35daa3..d82defadb 100644 --- a/_unittests/ut_tools/test_onnx_micro_runtime.py +++ b/_unittests/ut_onnxrt/test_onnx_micro_runtime.py @@ -1,121 +1,121 @@ -""" -@brief test log(time=3s) -""" -import unittest -import numpy -from pyquickhelper.pycode import ExtTestCase -from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 - OnnxAdd, OnnxTranspose, OnnxShape, OnnxPow, OnnxMatMul, OnnxGemm, - OnnxSqueeze, OnnxUnsqueeze) -from mlprodict.tools.onnx_micro_runtime import OnnxMicroRuntime - - -class TestOnnxMicroRuntime(ExtTestCase): - - opset = 15 # opset=13, 14, ... - - def test_onnx_micro_runtime(self): - opset = TestOnnxMicroRuntime.opset - dtype = numpy.float32 - x = numpy.array([1, 2, 4, 5, 5, 4]).astype( - numpy.float32).reshape((3, 2)) - cop = OnnxAdd('X', numpy.array([1], dtype=dtype), op_version=opset) - cop4 = OnnxAdd(cop, numpy.array([2], dtype=dtype), op_version=opset, - output_names=['Y']) - model_def = cop4.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - out = rt.run({'X': x}) - self.assertIn('X', out) - self.assertIn('Y', out) - self.assertIn('Ad_Addcst', out) - self.assertEqual(len(out), 5) - - def test_onnx_micro_runtime_exc1(self): - self.assertRaise(lambda: OnnxMicroRuntime(None), TypeError) - - def test_onnx_micro_runtime_exc2(self): - opset = TestOnnxMicroRuntime.opset - dtype = numpy.float32 - x = numpy.array([1, 2, 4, 5, 5, 4]).astype( - numpy.float32).reshape((3, 2)) - cop = OnnxAdd('X', numpy.array([1], dtype=dtype), op_version=opset) - cop4 = OnnxPow(cop, numpy.array([2], dtype=dtype), op_version=opset, - output_names=['Y']) - model_def = cop4.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - self.assertRaise(lambda: rt.run({'X': x}), NotImplementedError) - self.assertRaise(lambda: rt.run(x), TypeError) - - def test_onnx_micro_runtime_shape(self): - opset = TestOnnxMicroRuntime.opset - x = numpy.array([1, 2, 4, 5, 5, 4]).astype( - numpy.float32).reshape((3, 2)) - cop = OnnxShape('X', op_version=opset, output_names=['Y']) - model_def = cop.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - out = rt.run({'X': x}) - self.assertEqual(numpy.array(x.shape, dtype=numpy.int64), out['Y']) - - def test_onnx_micro_runtime_transpose(self): - opset = TestOnnxMicroRuntime.opset - x = numpy.array([1, 2, 4, 5, 5, 4]).astype( - numpy.float32).reshape((3, 2)) - cop = OnnxTranspose('X', perm=[1, 0], op_version=opset, - output_names=['Y']) - model_def = cop.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - out = rt.run({'X': x}) - self.assertEqual(x.T, out['Y']) - - def test_onnx_micro_runtime_matmul(self): - opset = TestOnnxMicroRuntime.opset - x = numpy.array([1, 2, 4, 5]).astype( - numpy.float32).reshape((2, 2)) - cop = OnnxMatMul('X', 'X', op_version=opset, - output_names=['Y']) - model_def = cop.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - out = rt.run({'X': x}) - self.assertEqual(numpy.matmul(x, x), out['Y']) - - def test_onnx_micro_runtime_squeeze(self): - opset = TestOnnxMicroRuntime.opset - x = numpy.array([1, 2, 4, 5]).astype( - numpy.float32).reshape((2, 2, 1)) - cop = OnnxSqueeze('X', numpy.array([2], dtype=numpy.int64), - op_version=opset, output_names=['Y']) - model_def = cop.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - out = rt.run({'X': x}) - self.assertEqual(numpy.squeeze(x), out['Y']) - - def test_onnx_micro_runtime_unsqueeze(self): - opset = TestOnnxMicroRuntime.opset - x = numpy.array([1, 2, 4, 5]).astype( - numpy.float32).reshape((2, 2)) - cop = OnnxUnsqueeze('X', numpy.array([2], dtype=numpy.int64), - op_version=opset, output_names=['Y']) - model_def = cop.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - out = rt.run({'X': x}) - self.assertEqual(x.reshape((2, 2, 1)), out['Y']) - - def test_onnx_micro_runtime_gemm(self): - opset = TestOnnxMicroRuntime.opset - x = numpy.array([1, 2, 4, 5]).astype( - numpy.float32).reshape((2, 2)) - for ta in [0, 1]: - for tb in [0, 1]: - cop = OnnxGemm( - 'X', 'X', 'X', op_version=opset, alpha=1., beta=1., - output_names=['Y'], transA=ta, transB=tb) - model_def = cop.to_onnx({'X': x}, target_opset=opset) - rt = OnnxMicroRuntime(model_def) - out = rt.run({'X': x}) - xa = x.T if ta else x - xb = x.T if tb else x - self.assertEqual(numpy.matmul(xa, xb) + x, out['Y']) - - -if __name__ == "__main__": - unittest.main() +""" +@brief test log(time=3s) +""" +import unittest +import numpy +from pyquickhelper.pycode import ExtTestCase +from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 + OnnxAdd, OnnxTranspose, OnnxShape, OnnxPow, OnnxMatMul, OnnxGemm, + OnnxSqueeze, OnnxUnsqueeze) +from mlprodict.onnxrt.onnx_micro_runtime import OnnxMicroRuntime + + +class TestOnnxMicroRuntime(ExtTestCase): + + opset = 15 # opset=13, 14, ... + + def test_onnx_micro_runtime(self): + opset = TestOnnxMicroRuntime.opset + dtype = numpy.float32 + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxAdd('X', numpy.array([1], dtype=dtype), op_version=opset) + cop4 = OnnxAdd(cop, numpy.array([2], dtype=dtype), op_version=opset, + output_names=['Y']) + model_def = cop4.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertIn('X', out) + self.assertIn('Y', out) + self.assertIn('Ad_Addcst', out) + self.assertEqual(len(out), 5) + + def test_onnx_micro_runtime_exc1(self): + self.assertRaise(lambda: OnnxMicroRuntime(None), TypeError) + + def test_onnx_micro_runtime_exc2(self): + opset = TestOnnxMicroRuntime.opset + dtype = numpy.float32 + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxAdd('X', numpy.array([1], dtype=dtype), op_version=opset) + cop4 = OnnxPow(cop, numpy.array([2], dtype=dtype), op_version=opset, + output_names=['Y']) + model_def = cop4.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + self.assertRaise(lambda: rt.run({'X': x}), NotImplementedError) + self.assertRaise(lambda: rt.run(x), TypeError) + + def test_onnx_micro_runtime_shape(self): + opset = TestOnnxMicroRuntime.opset + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxShape('X', op_version=opset, output_names=['Y']) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertEqual(numpy.array(x.shape, dtype=numpy.int64), out['Y']) + + def test_onnx_micro_runtime_transpose(self): + opset = TestOnnxMicroRuntime.opset + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxTranspose('X', perm=[1, 0], op_version=opset, + output_names=['Y']) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertEqual(x.T, out['Y']) + + def test_onnx_micro_runtime_matmul(self): + opset = TestOnnxMicroRuntime.opset + x = numpy.array([1, 2, 4, 5]).astype( + numpy.float32).reshape((2, 2)) + cop = OnnxMatMul('X', 'X', op_version=opset, + output_names=['Y']) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertEqual(numpy.matmul(x, x), out['Y']) + + def test_onnx_micro_runtime_squeeze(self): + opset = TestOnnxMicroRuntime.opset + x = numpy.array([1, 2, 4, 5]).astype( + numpy.float32).reshape((2, 2, 1)) + cop = OnnxSqueeze('X', numpy.array([2], dtype=numpy.int64), + op_version=opset, output_names=['Y']) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertEqual(numpy.squeeze(x), out['Y']) + + def test_onnx_micro_runtime_unsqueeze(self): + opset = TestOnnxMicroRuntime.opset + x = numpy.array([1, 2, 4, 5]).astype( + numpy.float32).reshape((2, 2)) + cop = OnnxUnsqueeze('X', numpy.array([2], dtype=numpy.int64), + op_version=opset, output_names=['Y']) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertEqual(x.reshape((2, 2, 1)), out['Y']) + + def test_onnx_micro_runtime_gemm(self): + opset = TestOnnxMicroRuntime.opset + x = numpy.array([1, 2, 4, 5]).astype( + numpy.float32).reshape((2, 2)) + for ta in [0, 1]: + for tb in [0, 1]: + cop = OnnxGemm( + 'X', 'X', 'X', op_version=opset, alpha=1., beta=1., + output_names=['Y'], transA=ta, transB=tb) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + xa = x.T if ta else x + xb = x.T if tb else x + self.assertEqual(numpy.matmul(xa, xb) + x, out['Y']) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/onnxrt/__init__.py b/mlprodict/onnxrt/__init__.py index b3b53d01c..1ac63bc46 100644 --- a/mlprodict/onnxrt/__init__.py +++ b/mlprodict/onnxrt/__init__.py @@ -4,3 +4,5 @@ @brief Shortcut to *onnxrt*. """ from .onnx_inference import OnnxInference +from .onnx_micro_runtime import OnnxMicroRuntime + diff --git a/mlprodict/tools/onnx_micro_runtime.py b/mlprodict/onnxrt/onnx_micro_runtime.py similarity index 97% rename from mlprodict/tools/onnx_micro_runtime.py rename to mlprodict/onnxrt/onnx_micro_runtime.py index 6a2217cda..250f881af 100644 --- a/mlprodict/tools/onnx_micro_runtime.py +++ b/mlprodict/onnxrt/onnx_micro_runtime.py @@ -1,193 +1,193 @@ -""" -@file -@brief Micro runtime for ONNX. - -.. versionadded:: 0.6 -""" -import numpy -from ..onnx_tools.onnx2py_helper import _var_as_dict - - -class OnnxMicroRuntime: - """ - Implements a micro runtime for ONNX graphs. - It does not implements all the operator types. - - :param model_onnx: ONNX model - """ - - def __init__(self, model_onnx): - if not hasattr(model_onnx, 'graph'): - raise TypeError( - "model_onnx is not an ONNX graph but %r." % type(model_onnx)) - self.model_onnx = model_onnx - - def run(self, inputs): - """ - Computes the outputs of the graph. - - :param inputs: dictionary - :return: all intermediates results and output as a dictionary - """ - if not isinstance(inputs, dict): - raise TypeError( - "inputs must be a dictionary not %r." % type(inputs)) - results = inputs.copy() - - for init in self.model_onnx.graph.initializer: - name = init.name - mat = _var_as_dict(init)['value'] - results[name] = mat - - for node in self.model_onnx.graph.node: - op_type = node.op_type - inp = [results[n] for n in node.input] - meth_name = "_op_%s" % op_type.lower() - if not hasattr(self, meth_name): - raise NotImplementedError( - "OnnxMicroRuntime does not implement operator %r." % op_type) - kwargs = {} - for at in node.attribute: - var = _var_as_dict(at) - kwargs[at.name] = var['value'] - out = getattr(self, meth_name)(*inp, **kwargs) - for n, o in zip(node.output, out): - results[n] = o - - return results - - ######################## - # Runtime for operators - ######################## - - def _op_add(self, x, y): - "Runtime for operator :epkg:`Op:Add`." - return (x + y, ) - - def _op_concat(self, *args, axis=None): - "Runtime for operator :epkg:`Op:Concat`." - def _preprocess(a, axis): - if axis >= len(a.shape): - new_shape = a.shape + (1, ) * (axis + 1 - len(a.shape)) - return a.reshape(new_shape) - return a - - targs = tuple(_preprocess(a, axis) for a in args) - return (numpy.concatenate(targs, axis), ) - - def _op_gemm(self, a, b, c=None, alpha=None, beta=None, - transA=False, transB=False): - "Runtime for operator :epkg:`Op:Gemm`." - - def _gemm00(a, b, c, alpha, beta): - o = numpy.dot(a, b) * alpha - if beta != 0: - o += c * beta - return o - - def _gemm01(a, b, c, alpha, beta): - o = numpy.dot(a, b.T) * alpha - if beta != 0: - o += c * beta - return o - - def _gemm10(a, b, c, alpha, beta): - o = numpy.dot(a.T, b) * alpha - if beta != 0: - o += c * beta - return o - - def _gemm11(a, b, c, alpha, beta): - o = numpy.dot(a.T, b.T) * alpha - if beta != 0: - o += c * beta - return o - - if not isinstance(transA, (int, bool, numpy.int64)): - raise TypeError( # pragma: no cover - "Unexpected type for transA: %r." % type(transA)) - if not isinstance(transB, (int, bool, numpy.int64)): - raise TypeError( # pragma: no cover - "Unexpected type for transA: %r." % type(transB)) - if transA: - fct = _gemm11 if transB else _gemm10 - else: - fct = _gemm01 if transB else _gemm00 - return (fct(a, b, c, alpha=alpha, beta=beta), ) - - def _op_gather(self, x, indices, axis=None): - "Runtime for operator :epkg:`Op:Gather`." - if not x.flags['C_CONTIGUOUS']: - x = numpy.ascontiguousarray(x) - if not indices.flags['C_CONTIGUOUS']: - indices = indices.ascontiguousarray() - return (numpy.take(x, indices, axis=axis), ) - - def _op_identity(self, x): - "Runtime for operator :epkg:`Op:Identity`." - return (x, ) - - def _op_matmul(self, x, y): - "Runtime for operator :epkg:`Op:MatMul`." - return (numpy.matmul(x, y), ) - - def _op_max(self, *inps): - "Runtime for operator :epkg:`Op:Max`." - return (numpy.maximum(*inps), ) - - def _op_mul(self, x, y): - "Runtime for operator :epkg:`Op:Mul`." - return (x * y, ) - - def _op_reduceprod(self, data, axes=None, keepdims=None): - "Runtime for operator :epkg:`Op:ReduceProd`." - if axes is not None and not isinstance(axes, int): - if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0: - axes = int(axes) - else: - axes = tuple(axes) if len(axes) > 0 else None - return (numpy.prod(data, axis=axes, - keepdims=keepdims, - dtype=data.dtype), ) - - def _op_reducesum(self, data, axes, keepdims=None, - noop_with_empty_axes=None): - "Runtime for operator :epkg:`Op:ReduceSum`." - if axes is None and noop_with_empty_axes: - return (data, ) - if axes is not None and not isinstance(axes, int): - if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0: - axes = int(axes) - else: - axes = tuple(axes) if len(axes) > 0 else None - return (numpy.sum(data, axis=axes, - keepdims=keepdims, - dtype=data.dtype), ) - - def _op_reshape(self, x, shape): - "Runtime for operator :epkg:`Op:Reshape`." - return (x.reshape(shape), ) - - def _op_shape(self, x): - "Runtime for operator :epkg:`Op:Shape`." - return (numpy.array(list(x.shape), dtype=numpy.int64), ) - - def _op_squeeze(self, x, axes=None): - "Runtime for operator :epkg:`Op:Squeeze`." - if axes is None: - return (x, ) - if hasattr(axes, '__iter__'): - return (numpy.squeeze(x, axis=tuple(axes)), ) - return (numpy.squeeze(x, axis=axes), ) - - def _op_transpose(self, x, perm=None): - "Runtime for operator :epkg:`Op:Transpose`." - return (numpy.transpose(x, perm), ) - - def _op_unsqueeze(self, x, axes=None): - "Runtime for operator :epkg:`Op:Unsqueeze`." - if axes is None: - return (x, ) - if hasattr(axes, '__iter__'): - return (numpy.expand_dims(x, axis=tuple(axes)), ) - return (numpy.expand_dims(x, axis=axes), ) +""" +@file +@brief Micro runtime for ONNX. + +.. versionadded:: 0.6 +""" +import numpy +from ..onnx_tools.onnx2py_helper import _var_as_dict + + +class OnnxMicroRuntime: + """ + Implements a micro runtime for ONNX graphs. + It does not implements all the operator types. + + :param model_onnx: ONNX model + """ + + def __init__(self, model_onnx): + if not hasattr(model_onnx, 'graph'): + raise TypeError( + "model_onnx is not an ONNX graph but %r." % type(model_onnx)) + self.model_onnx = model_onnx + + def run(self, inputs): + """ + Computes the outputs of the graph. + + :param inputs: dictionary + :return: all intermediates results and output as a dictionary + """ + if not isinstance(inputs, dict): + raise TypeError( + "inputs must be a dictionary not %r." % type(inputs)) + results = inputs.copy() + + for init in self.model_onnx.graph.initializer: + name = init.name + mat = _var_as_dict(init)['value'] + results[name] = mat + + for node in self.model_onnx.graph.node: + op_type = node.op_type + inp = [results[n] for n in node.input] + meth_name = "_op_%s" % op_type.lower() + if not hasattr(self, meth_name): + raise NotImplementedError( + "OnnxMicroRuntime does not implement operator %r." % op_type) + kwargs = {} + for at in node.attribute: + var = _var_as_dict(at) + kwargs[at.name] = var['value'] + out = getattr(self, meth_name)(*inp, **kwargs) + for n, o in zip(node.output, out): + results[n] = o + + return results + + ######################## + # Runtime for operators + ######################## + + def _op_add(self, x, y): + "Runtime for operator :epkg:`Op:Add`." + return (x + y, ) + + def _op_concat(self, *args, axis=None): + "Runtime for operator :epkg:`Op:Concat`." + def _preprocess(a, axis): + if axis >= len(a.shape): + new_shape = a.shape + (1, ) * (axis + 1 - len(a.shape)) + return a.reshape(new_shape) + return a + + targs = tuple(_preprocess(a, axis) for a in args) + return (numpy.concatenate(targs, axis), ) + + def _op_gemm(self, a, b, c=None, alpha=None, beta=None, + transA=False, transB=False): + "Runtime for operator :epkg:`Op:Gemm`." + + def _gemm00(a, b, c, alpha, beta): + o = numpy.dot(a, b) * alpha + if beta != 0: + o += c * beta + return o + + def _gemm01(a, b, c, alpha, beta): + o = numpy.dot(a, b.T) * alpha + if beta != 0: + o += c * beta + return o + + def _gemm10(a, b, c, alpha, beta): + o = numpy.dot(a.T, b) * alpha + if beta != 0: + o += c * beta + return o + + def _gemm11(a, b, c, alpha, beta): + o = numpy.dot(a.T, b.T) * alpha + if beta != 0: + o += c * beta + return o + + if not isinstance(transA, (int, bool, numpy.int64)): + raise TypeError( # pragma: no cover + "Unexpected type for transA: %r." % type(transA)) + if not isinstance(transB, (int, bool, numpy.int64)): + raise TypeError( # pragma: no cover + "Unexpected type for transA: %r." % type(transB)) + if transA: + fct = _gemm11 if transB else _gemm10 + else: + fct = _gemm01 if transB else _gemm00 + return (fct(a, b, c, alpha=alpha, beta=beta), ) + + def _op_gather(self, x, indices, axis=None): + "Runtime for operator :epkg:`Op:Gather`." + if not x.flags['C_CONTIGUOUS']: + x = numpy.ascontiguousarray(x) + if not indices.flags['C_CONTIGUOUS']: + indices = indices.ascontiguousarray() + return (numpy.take(x, indices, axis=axis), ) + + def _op_identity(self, x): + "Runtime for operator :epkg:`Op:Identity`." + return (x, ) + + def _op_matmul(self, x, y): + "Runtime for operator :epkg:`Op:MatMul`." + return (numpy.matmul(x, y), ) + + def _op_max(self, *inps): + "Runtime for operator :epkg:`Op:Max`." + return (numpy.maximum(*inps), ) + + def _op_mul(self, x, y): + "Runtime for operator :epkg:`Op:Mul`." + return (x * y, ) + + def _op_reduceprod(self, data, axes=None, keepdims=None): + "Runtime for operator :epkg:`Op:ReduceProd`." + if axes is not None and not isinstance(axes, int): + if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0: + axes = int(axes) + else: + axes = tuple(axes) if len(axes) > 0 else None + return (numpy.prod(data, axis=axes, + keepdims=keepdims, + dtype=data.dtype), ) + + def _op_reducesum(self, data, axes, keepdims=None, + noop_with_empty_axes=None): + "Runtime for operator :epkg:`Op:ReduceSum`." + if axes is None and noop_with_empty_axes: + return (data, ) + if axes is not None and not isinstance(axes, int): + if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0: + axes = int(axes) + else: + axes = tuple(axes) if len(axes) > 0 else None + return (numpy.sum(data, axis=axes, + keepdims=keepdims, + dtype=data.dtype), ) + + def _op_reshape(self, x, shape): + "Runtime for operator :epkg:`Op:Reshape`." + return (x.reshape(shape), ) + + def _op_shape(self, x): + "Runtime for operator :epkg:`Op:Shape`." + return (numpy.array(list(x.shape), dtype=numpy.int64), ) + + def _op_squeeze(self, x, axes=None): + "Runtime for operator :epkg:`Op:Squeeze`." + if axes is None: + return (x, ) + if hasattr(axes, '__iter__'): + return (numpy.squeeze(x, axis=tuple(axes)), ) + return (numpy.squeeze(x, axis=axes), ) + + def _op_transpose(self, x, perm=None): + "Runtime for operator :epkg:`Op:Transpose`." + return (numpy.transpose(x, perm), ) + + def _op_unsqueeze(self, x, axes=None): + "Runtime for operator :epkg:`Op:Unsqueeze`." + if axes is None: + return (x, ) + if hasattr(axes, '__iter__'): + return (numpy.expand_dims(x, axis=tuple(axes)), ) + return (numpy.expand_dims(x, axis=axes), ) diff --git a/mlprodict/testing/einsum/einsum_fct.py b/mlprodict/testing/einsum/einsum_fct.py index 1830146de..589698cf7 100644 --- a/mlprodict/testing/einsum/einsum_fct.py +++ b/mlprodict/testing/einsum/einsum_fct.py @@ -10,7 +10,7 @@ from onnx import helper from skl2onnx.common.data_types import FloatTensorType from ...onnx_tools.onnx2py_helper import guess_proto_dtype -from ...tools.onnx_micro_runtime import OnnxMicroRuntime +from ...onnxrt.onnx_micro_runtime import OnnxMicroRuntime from ...tools.asv_options_helper import ( get_opset_number_from_onnx, get_ir_version_from_onnx) from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence