From c19db671363ce29d0f2957b73f0e640999f9a394 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 27 Feb 2021 01:58:49 +0100 Subject: [PATCH 1/2] Test against onnxruntime --- _doc/sphinxdoc/source/tutorial/onnx_numpy.rst | 2 +- _unittests/ut_npy/test_numpy_onnx_pyrt.py | 45 +++++++++++++------ mlprodict/npy/numpy_onnx_pyrt.py | 6 +-- mlprodict/npy/onnx_numpy_annotation.py | 26 ++++++++++- mlprodict/npy/onnx_numpy_wrapper.py | 3 +- mlprodict/onnxrt/ops_cpu/op_cast.py | 4 ++ mlprodict/onnxrt/shape_object.py | 1 + 7 files changed, 68 insertions(+), 19 deletions(-) diff --git a/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst b/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst index d875025ce..e6a1dd8f1 100644 --- a/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst +++ b/_doc/sphinxdoc/source/tutorial/onnx_numpy.rst @@ -44,7 +44,7 @@ the first examples of `sklearn-onnx tutorial`. # Conversion to ONNX try: onx = to_onnx(log_scale_transformer, X) - except RuntimeError as e: + except (RuntimeError, TypeError) as e: print(e) The first step is a `FunctionTransformer` with a custom function diff --git a/_unittests/ut_npy/test_numpy_onnx_pyrt.py b/_unittests/ut_npy/test_numpy_onnx_pyrt.py index bc44531ed..75f79393d 100644 --- a/_unittests/ut_npy/test_numpy_onnx_pyrt.py +++ b/_unittests/ut_npy/test_numpy_onnx_pyrt.py @@ -6,21 +6,39 @@ import numpy import scipy.special as sp from pyquickhelper.pycode import ExtTestCase +from mlprodict.onnxrt import OnnxInference import mlprodict.npy.numpy_onnx_pyrt as nxnpy +try: + numpy_bool = numpy.bool_ +except AttributeError: + numpy_bool = bool + + class TestNumpyOnnxFunction(ExtTestCase): - def common_test1(self, x, npfct, nxfct, dtype, **kwargs): + def common_test1(self, x, npfct, nxfct, dtype, dtype_out=None, **kwargs): xt = x.astype(dtype) - if kwargs is None or len(kwargs) == 0: + if dtype_out is None and (kwargs is None or len(kwargs) == 0): expected = npfct(xt) got = nxfct[dtype](xt) + compiled = nxfct[dtype].compiled else: expected = npfct(xt, **kwargs) kwargs['dtype_onnx'] = dtype + if dtype_out is not None: + kwargs['dtype_onnx_out'] = dtype_out got = nxfct[kwargs](xt) + compiled = nxfct[kwargs].compiled self.assertEqualArray(expected, got) + onx = compiled.onnx_ + rt2 = OnnxInference(onx, runtime="onnxruntime1") + inputs = rt2.input_names + outputs = rt2.output_names + data = {inputs[0]: xt} + got2 = rt2.run(data)[outputs[0]] + self.assertEqualArray(expected, got2, decimal=6) def test_abs_float32(self): x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) @@ -42,29 +60,29 @@ def test_amax_float32(self): self.common_test1(x, numpy.amax, nxnpy.amax, numpy.float32, **kw) - def test_argmax_float32(self): - kwargs = [{'axis': 0}, {'axis': 1}] + def test_amin_float32(self): + kwargs = [{'axis': 0}, {}, {'axis': 1}] for kw in kwargs: with self.subTest(kw=kw): x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) - self.common_test1(x, numpy.argmax, nxnpy.argmax, + self.common_test1(x, numpy.amin, nxnpy.amin, numpy.float32, **kw) - def test_argmin_float32(self): + def test_argmax_float32(self): kwargs = [{'axis': 0}, {'axis': 1}] for kw in kwargs: with self.subTest(kw=kw): x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) - self.common_test1(x, numpy.argmin, nxnpy.argmin, - numpy.float32, **kw) + self.common_test1(x, numpy.argmax, nxnpy.argmax, + numpy.float32, dtype_out=numpy.int64, **kw) - def test_amin_float32(self): - kwargs = [{'axis': 0}, {}, {'axis': 1}] + def test_argmin_float32(self): + kwargs = [{'axis': 0}, {'axis': 1}] for kw in kwargs: with self.subTest(kw=kw): x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) - self.common_test1(x, numpy.amin, nxnpy.amin, - numpy.float32, **kw) + self.common_test1(x, numpy.argmin, nxnpy.argmin, + numpy.float32, dtype_out=numpy.int64, **kw) def test_asin_float32(self): x = numpy.array([[0.5, 0.1], [-0.5, -0.1]], dtype=numpy.float32) @@ -97,7 +115,8 @@ def test_exp_float32(self): def test_isnan_float32(self): x = numpy.array([[6.1, 5], [3.5, numpy.nan]], dtype=numpy.float32) - self.common_test1(x, numpy.isnan, nxnpy.isnan, numpy.float32) + self.common_test1(x, numpy.isnan, nxnpy.isnan, numpy.float32, + dtype_out=numpy_bool) def test_log_float32(self): x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32) diff --git a/mlprodict/npy/numpy_onnx_pyrt.py b/mlprodict/npy/numpy_onnx_pyrt.py index a2e159fe2..3db2ab6ec 100644 --- a/mlprodict/npy/numpy_onnx_pyrt.py +++ b/mlprodict/npy/numpy_onnx_pyrt.py @@ -71,13 +71,13 @@ def amin(x, axis=None, keepdims=0): return nx_min(x, axis=axis, keepdims=keepdims) -@onnxnumpy_np(signature=NDArraySameType("all")) +@onnxnumpy_np(signature=NDArraySameType("all_int")) def argmax(x, axis=None, keepdims=0): "argmax" return nx_argmax(x, axis=axis, keepdims=keepdims) -@onnxnumpy_np(signature=NDArraySameType("all")) +@onnxnumpy_np(signature=NDArraySameType("all_int")) def argmin(x, axis=None, keepdims=0): "argmin" return nx_argmin(x, axis=axis, keepdims=keepdims) @@ -131,7 +131,7 @@ def exp(x): return nx_exp(x) -@onnxnumpy_np(signature=NDArraySameTypeSameShape("all")) +@onnxnumpy_np(signature=NDArraySameTypeSameShape("all_bool")) def isnan(x): "isnan" return nx_isnan(x) diff --git a/mlprodict/npy/onnx_numpy_annotation.py b/mlprodict/npy/onnx_numpy_annotation.py index f8339dbf6..61cad3a89 100644 --- a/mlprodict/npy/onnx_numpy_annotation.py +++ b/mlprodict/npy/onnx_numpy_annotation.py @@ -9,6 +9,11 @@ from typing import TypeVar, Generic import numpy +try: + numpy_bool = numpy.bool_ +except AttributeError: + numpy_bool = bool + Shape = TypeVar("Shape") DType = TypeVar("DType") @@ -56,13 +61,23 @@ def __class_getitem__(cls, params): class _NDArrayAlias: def __init__(self, dtypes=None): self.dtypes = dtypes + self.dtypes_out = dtypes if isinstance(self.dtypes, str): if self.dtypes == "all": self.dtypes = all_dtypes + self.dtypes_out = self.dtypes + elif self.dtypes == "all_int": + self.dtypes = all_dtypes + self.dtypes_out = (numpy.int64, ) + elif self.dtypes == "all_bool": + self.dtypes = all_dtypes + self.dtypes_out = (numpy_bool, ) elif self.dtypes == "floats": self.dtypes = (numpy.float32, numpy.float64) + self.dtypes_out = self.dtypes elif self.dtypes == "ints": self.dtypes = (numpy.int32, numpy.int64) + self.dtypes_out = self.dtypes else: raise ValueError( "Unexpected shortcut for dtype %r." % self.dtypes) @@ -103,11 +118,20 @@ def _possible_names(): else: dtype = version + if isinstance(dtype, tuple): + dtype, dtype_out = dtype + else: + dtype_out = dtype if dtype not in self.dtypes: raise TypeError( "Unexpected version %r, it should be in %r." % ( version, self.dtypes)) + if dtype_out not in self.dtypes_out: + raise TypeError( + "Unexpected version %r, it should be in %r." % ( + version, self.dtypes_out)) onnx_type = self._to_onnx_dtype(dtype, None) + onnx_type_out = self._to_onnx_dtype(dtype_out, None) inputs = [(a, onnx_type) for a in args] names_in = set(inp[0] for inp in inputs) name_out = None @@ -115,7 +139,7 @@ def _possible_names(): if name not in names_in: name_out = name break - outputs = [(name_out, onnx_type)] + outputs = [(name_out, onnx_type_out)] return inputs, outputs def shape_calculator(self, dims): diff --git a/mlprodict/npy/onnx_numpy_wrapper.py b/mlprodict/npy/onnx_numpy_wrapper.py index 2f27cba03..e232e59e3 100644 --- a/mlprodict/npy/onnx_numpy_wrapper.py +++ b/mlprodict/npy/onnx_numpy_wrapper.py @@ -98,7 +98,8 @@ def __getitem__(self, dtype): raise RuntimeError( "Signature does not have any arguments, use directly dtypes.") others = tuple(dtype.get(k, self.kwargs[k]) for k in self.kwargs) - key = (dtype['dtype_onnx'], ) + others + key = ((dtype['dtype_onnx'], + dtype.get('dtype_onnx_out', dtype['dtype_onnx'])), ) + others self._populate(key) elif dtype not in self.signed_compiled: self._populate(dtype) diff --git a/mlprodict/onnxrt/ops_cpu/op_cast.py b/mlprodict/onnxrt/ops_cpu/op_cast.py index 390afda19..377015fcf 100644 --- a/mlprodict/onnxrt/ops_cpu/op_cast.py +++ b/mlprodict/onnxrt/ops_cpu/op_cast.py @@ -36,6 +36,10 @@ def __init__(self, onnx_node, desc=None, **options): self._dtype = numpy.str elif self.to == TensorProto.FLOAT16: # pylint: disable=E1101 self._dtype = numpy.float16 + elif self.to == TensorProto.COMPLEX64: # pylint: disable=E1101 + self._dtype = numpy.complex64 + elif self.to == TensorProto.COMPLEX128: # pylint: disable=E1101 + self._dtype = numpy.complex128 else: raise ValueError( # pragma: no cover "Unexpected value for to='{}'.".format( diff --git a/mlprodict/onnxrt/shape_object.py b/mlprodict/onnxrt/shape_object.py index 51d6c59b8..37cddad52 100644 --- a/mlprodict/onnxrt/shape_object.py +++ b/mlprodict/onnxrt/shape_object.py @@ -514,6 +514,7 @@ def __init__(self, shape, dtype=None, use_n1=False, name=None): elif self._dtype not in { numpy.float32, numpy.float64, numpy.int32, numpy.int64, numpy.str, numpy.bool, numpy.float16, None, + numpy.complex64, numpy.complex128, 'map'}: raise ValueError( # pragma: no cover "dtype has an unexpected value: '{}'.".format(self._dtype)) From 15deedb7c9ec8e9d78304303addeeb3b003781be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 27 Feb 2021 12:32:44 +0100 Subject: [PATCH 2/2] complete test against ort --- _unittests/ut_npy/test_onnx_variable.py | 86 ++-- _unittests/ut_npy/test_onnx_variable_ort.py | 499 ++++++++++++++++++++ mlprodict/npy/onnx_numpy_annotation.py | 2 + mlprodict/npy/onnx_variable.py | 2 +- 4 files changed, 545 insertions(+), 44 deletions(-) create mode 100644 _unittests/ut_npy/test_onnx_variable_ort.py diff --git a/_unittests/ut_npy/test_onnx_variable.py b/_unittests/ut_npy/test_onnx_variable.py index 77ce67c1e..59f19daff 100644 --- a/_unittests/ut_npy/test_onnx_variable.py +++ b/_unittests/ut_npy/test_onnx_variable.py @@ -105,15 +105,15 @@ def test_abs_matmul(x: NDArray[Any, numpy.float32], @onnxnumpy_default def test_abs_div(x: NDArray[Any, numpy.float32], ) -> NDArray[Any, numpy.float32]: - "onnx numpy addition" + "onnx numpy division" return nxnp.abs(x) / x @onnxnumpy_default def test_abs_idiv(x: NDArray[Any, numpy.float32], - ) -> NDArray[Any, numpy.float32]: - "onnx numpy addition" - return nxnp.abs(x) // x + ) -> NDArray[Any, numpy.int64]: + "onnx numpy int division" + return nxnp.abs(x).astype(numpy.int64) // x.astype(numpy.int64) @onnxnumpy_default @@ -145,14 +145,14 @@ def test_abs_less(x: NDArray[Any, numpy.float32], @onnxnumpy_default -def test_abs_and(x: NDArray[Any, numpy_bool], +def test_abs_and(x: NDArray[Any, numpy.float32], ) -> NDArray[Any, numpy_bool]: "onnx numpy and" return (nxnp.abs(x) < x) and (nxnp.abs(x) < numpy.float32(0)) @onnxnumpy_default -def test_abs_or(x: NDArray[Any, numpy_bool], +def test_abs_or(x: NDArray[Any, numpy.float32], ) -> NDArray[Any, numpy_bool]: "onnx numpy or" return (nxnp.abs(x) < x) or (nxnp.abs(x) < numpy.float32(0)) @@ -245,7 +245,7 @@ def test_abs_filter(x: NDArray[Any, numpy.float32], @onnxnumpy_default def test_abs_set2(x: NDArray[Any, numpy.float32], - ) -> NDArray[Any, numpy.bool]: + ) -> NDArray[Any, numpy.float32]: "onnx numpy set" temp = nxnp.abs(x).copy() temp[:2, 0] = numpy.float32(-1) @@ -254,7 +254,7 @@ def test_abs_set2(x: NDArray[Any, numpy.float32], @onnxnumpy_default def test_abs_set3(x: NDArray[Any, numpy.float32], - ) -> NDArray[Any, numpy.bool]: + ) -> NDArray[Any, numpy.float32]: "onnx numpy set" temp = nxnp.abs(x).copy() temp[:2, :1] = numpy.array([[-1.5, -1.5]], dtype=numpy.float32).T @@ -290,7 +290,7 @@ def test_abs_size(x: NDArray[Any, numpy.float32], class TestOnnxVariable(ExtTestCase): - def test_onnx_variable_abs(self): + def test_py_abs(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs(x) self.assertEqualArray(y, numpy.abs(x)) @@ -298,54 +298,54 @@ def test_onnx_variable_abs(self): self.assertTrue(hasattr(test_abs, 'compiled')) self.assertIsInstance(test_abs.compiled, ONC) - def test_onnx_variable_abs_add(self): + def test_py_abs_add(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_add(x) self.assertEqualArray(y, numpy.abs(x) + x) - def test_onnx_variable_abs_addm(self): + def test_py_abs_addm(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_addm(x, x) self.assertEqualArray(y, numpy.abs(x) + x) - def test_onnx_variable_abs_add_cst(self): + def test_py_abs_add_cst(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_add2(x) self.assertEqualArray(y, numpy.abs(x) + 2) - def test_onnx_variable_abs_add4(self): + def test_py_abs_add4(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_add4(x) text = str(test_abs_add4.compiled.onnx_).split('op_type: "Mul"') self.assertEqual(len(text), 3) self.assertEqualArray(y, (x * x) * (x * x)) - def test_onnx_variable_abs_sub(self): + def test_py_abs_sub(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_sub(x) self.assertEqualArray(y, numpy.abs(x) - x) - def test_onnx_variable_abs_mul(self): + def test_py_abs_mul(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_mul(x) self.assertEqualArray(y, numpy.abs(x) * x) - def test_onnx_variable_abs_mod(self): + def test_py_abs_mod(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_mod(x) self.assertEqualArray(y, numpy.abs(x) % 2) - def test_onnx_variable_abs_pox(self): + def test_py_abs_pox(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_pow(x) self.assertEqualArray(y, numpy.abs(x) ** 2) - def test_onnx_variable_abs_matmul(self): + def test_py_abs_matmul(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_matmul(x) self.assertEqualArray(y, numpy.abs(x) @ x) - def test_onnx_variable_abs_div(self): + def test_py_abs_div(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_div(x) self.assertEqualArray(y, numpy.abs(x) / x) @@ -353,7 +353,7 @@ def test_onnx_variable_abs_div(self): y = test_abs_div(x) self.assertEqualArray(y, numpy.abs(x) / x) - def test_onnx_variable_abs_idiv(self): + def test_py_abs_idiv(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_idiv(x) self.assertEqualArray(y, numpy.abs(x) // x) @@ -362,69 +362,69 @@ def test_onnx_variable_abs_idiv(self): self.assertEqualArray(y, numpy.abs(x) // x) @ignore_warnings(DeprecationWarning) - def test_onnx_variable_abs_equal(self): + def test_py_abs_equal(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_equal(x) self.assertEqualArray(y, numpy.abs(x) == x) @ignore_warnings(DeprecationWarning) - def test_onnx_variable_abs_not_equal(self): + def test_py_abs_not_equal(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_not_equal(x) self.assertEqualArray(y, numpy.abs(x) != x) @ignore_warnings(DeprecationWarning) - def test_onnx_variable_abs_greater(self): + def test_py_abs_greater(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_greater(x) self.assertEqualArray(y, numpy.abs(x) > x) @ignore_warnings(DeprecationWarning) - def test_onnx_variable_abs_less(self): + def test_py_abs_less(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_less(x) self.assertEqualArray(y, numpy.abs(x) < x) @ignore_warnings(DeprecationWarning) - def test_onnx_variable_abs_and(self): + def test_py_abs_and(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_and(x) self.assertEqualArray( y, (numpy.abs(x) < x) & (numpy.abs(x) < 0)) @ignore_warnings(DeprecationWarning) - def test_onnx_variable_abs_or(self): + def test_py_abs_or(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_or(x) self.assertEqualArray( y, (numpy.abs(x) < x) | (numpy.abs(x) < 0)) - def test_onnx_variable_abs_sum1(self): + def test_py_abs_sum1(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_sum1(x) self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=0)) - def test_onnx_variable_abs_sum2(self): + def test_py_abs_sum2(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_sum2(x) self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=1, keepdims=1)) - def test_onnx_variable_transpose_t(self): + def test_py_transpose_t(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_transpose_t(x) self.assertEqualArray(y, numpy.abs(x).T) - def test_onnx_variable_abs_cast(self): + def test_py_abs_cast(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_cast(x) self.assertEqualArray(y, numpy.abs(x).astype(numpy.int64)) - def test_onnx_variable_abs_reshape(self): + def test_py_abs_reshape(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_reshape(x) self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1))) - def test_onnx_variable_abs_reshape_11(self): + def test_py_abs_reshape_11(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_reshape(x) self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1))) @@ -435,61 +435,61 @@ def test_onnx_variable_abs_reshape_11(self): compiled = test_abs_reshape_11.compiled self.assertIn("version: 11", str(compiled.onnx_)) - def test_onnx_variable_abs_slice(self): + def test_py_abs_slice(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_slice(x) self.assertEqualArray(y, numpy.abs(x)[:, 1:]) - def test_onnx_variable_abs_slice23(self): + def test_py_abs_slice23(self): x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) y = test_abs_slice23(x) self.assertEqualArray(y, numpy.abs(x)[::2, ::3]) - def test_onnx_variable_abs_neg(self): + def test_py_abs_neg(self): x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) y = test_abs_neg(x) self.assertEqualArray(y, -numpy.abs(x)) - def test_onnx_variable_abs_not(self): + def test_py_abs_not(self): x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) y = test_abs_not(x) self.assertEqualArray(y, numpy.abs(x) <= 0) - def test_onnx_variable_abs_filter(self): + def test_py_abs_filter(self): x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) y = test_abs_filter(x) self.assertEqualArray(y, numpy.abs(x)[x[:, 0] > 15]) - def test_onnx_variable_abs_set(self): + def test_py_abs_set(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_set2(x) temp = numpy.abs(x) temp[:, 0] = -1 self.assertEqualArray(y, temp) - def test_onnx_variable_abs_set3(self): + def test_py_abs_set3(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_set3(x) temp = numpy.abs(x) temp[:, 0] = -1.5 self.assertEqualArray(y, temp) - def test_onnx_variable_log(self): + def test_py_log(self): x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32) y = test_log(x) self.assertEqualArray(y, numpy.log(x)) - def test_onnx_variable_abs_log_multi(self): + def test_py_abs_log_multi(self): x = numpy.array([[6.1, -5], [-3.5, 7.8]], dtype=numpy.float32) y = test_abs_log_multi(x) self.assertEqualArray(y, numpy.log(numpy.abs(x))) - def test_onnx_variable_abs_shape(self): + def test_py_abs_shape(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_shape(x) self.assertEqualArray(y, numpy.abs(x).shape) - def test_onnx_variable_abs_size(self): + def test_py_abs_size(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) y = test_abs_size(x) self.assertEqualArray(y, numpy.abs(x).size) diff --git a/_unittests/ut_npy/test_onnx_variable_ort.py b/_unittests/ut_npy/test_onnx_variable_ort.py new file mode 100644 index 000000000..beabd873e --- /dev/null +++ b/_unittests/ut_npy/test_onnx_variable_ort.py @@ -0,0 +1,499 @@ +# -*- coding: utf-8 -*- +""" +@brief test log(time=3s) +""" +import unittest +from typing import Any +import numpy +from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument # pylint: disable=E0611 +from pyquickhelper.pycode import ExtTestCase, ignore_warnings +from mlprodict.npy import onnxnumpy, onnxnumpy_np +import mlprodict.npy.numpy_onnx_impl as nxnp +from mlprodict.npy import ( + OnnxNumpyCompiler as ONC, NDArray, NDArraySameTypeSameShape) + + +@ignore_warnings(DeprecationWarning) +def get_bool(unused): + try: + return numpy.bool + except AttributeError: + return bool + + +numpy_bool = get_bool(None) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy abs" + return nxnp.abs(x) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_abs(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy abs abs" + return nxnp.abs(nxnp.abs(x)) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_add(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.abs(x) + x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_add4(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + x2 = x * x + return x2 * x2 + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_addm(x1: NDArray[Any, numpy.float32], + x2: NDArray[Any, numpy.float32] + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.abs(x1) + x2 + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_add2(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.abs(x) + numpy.float32(2) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_sub(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.abs(x) - x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_mul(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.abs(x) * x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_pow(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy power" + return nxnp.abs(x) ** numpy.float32(2) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_mod(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.int64]: + "onnx numpy modulo" + return nxnp.abs(x).astype(numpy.int64) % numpy.int64(2) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_matmul(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy addition" + return nxnp.abs(x) @ x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_div(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy division" + return nxnp.abs(x) / x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_idiv(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.int64]: + "onnx numpy int division" + return nxnp.abs(x).astype(numpy.int64) // x.astype(numpy.int64) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_equal(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy_bool]: + "onnx numpy equality" + return nxnp.abs(x) == x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_not_equal(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy_bool]: + "onnx numpy inequality" + return nxnp.abs(x) != x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_greater(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy_bool]: + "onnx numpy greater" + return nxnp.abs(x) > x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_less(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy_bool]: + "onnx numpy less" + return nxnp.abs(x) < x + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_and(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy_bool]: + "onnx numpy and" + return (nxnp.abs(x) < x) and (nxnp.abs(x) < numpy.float32(0)) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_or(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy_bool]: + "onnx numpy or" + return (nxnp.abs(x) < x) or (nxnp.abs(x) < numpy.float32(0)) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_sum1(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy sum" + return nxnp.sum(nxnp.abs(x), axis=0) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_sum2(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy sum" + return nxnp.sum(nxnp.abs(x), axis=1, keepdims=1) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_transpose_t(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy transpose T" + return nxnp.abs(x).T + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_cast(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.int64]: + "onnx numpy cast" + return nxnp.abs(x).astype(numpy.int64) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_reshape(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy reshape" + return nxnp.abs(x).reshape((-1, 1)) + + +@onnxnumpy(op_version=11) +def test_abs_reshape_11(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy reshape with opset 11" + return nxnp.abs(x).reshape((-1, 1)) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_slice(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy slice 1" + return nxnp.abs(x)[:, 1] + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_slice2(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy slice 2" + return nxnp.abs(x)[:1, 1] + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_slice23(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy slice 23" + return nxnp.abs(x)[::2, ::3] + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_neg(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy neg" + return - nxnp.abs(x) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_not(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.bool]: + "onnx numpy not" + temp = nxnp.abs(x) > numpy.float32(0) + return temp.not_() + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_filter(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy filter" + return nxnp.abs(x)[x[:, 0] > numpy.float32(15)] + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_set2(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy set" + temp = nxnp.abs(x).copy() + temp[:2, 0] = numpy.float32(-1) + return temp + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_set3(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy set" + temp = nxnp.abs(x).copy() + temp[:2, :1] = numpy.array([[-1.5, -1.5]], dtype=numpy.float32).T + return temp + + +@onnxnumpy(runtime='onnxruntime1') +def test_log(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy log" + return nxnp.log(x) + + +@onnxnumpy_np(signature=NDArraySameTypeSameShape("floats"), + runtime='onnxruntime1') +def test_abs_log_multi(x): + "onnx numpy log multiple type" + return nxnp.log(nxnp.abs(x)) + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_shape(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.int64]: + "onnx numpy shape" + return nxnp.abs(x).shape + + +@onnxnumpy(runtime='onnxruntime1') +def test_abs_size(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.int64]: + "onnx numpy size" + return nxnp.abs(x).size + + +class TestOnnxVariableOrt(ExtTestCase): + + def test_ort_abs(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs(x) + self.assertEqualArray(y, numpy.abs(x)) + self.assertEqual(test_abs.__doc__, "onnx numpy abs") + self.assertTrue(hasattr(test_abs, 'compiled')) + self.assertIsInstance(test_abs.compiled, ONC) + + def test_ort_abs_add(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_add(x) + self.assertEqualArray(y, numpy.abs(x) + x) + + def test_ort_abs_addm(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_addm(x, x) + self.assertEqualArray(y, numpy.abs(x) + x) + + def test_ort_abs_add_cst(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_add2(x) + self.assertEqualArray(y, numpy.abs(x) + 2) + + def test_ort_abs_add4(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_add4(x) + text = str(test_abs_add4.compiled.onnx_).split('op_type: "Mul"') + self.assertEqual(len(text), 3) + self.assertEqualArray(y, (x * x) * (x * x)) + + def test_ort_abs_sub(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_sub(x) + self.assertEqualArray(y, numpy.abs(x) - x) + + def test_ort_abs_mul(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_mul(x) + self.assertEqualArray(y, numpy.abs(x) * x) + + def test_ort_abs_mod(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_mod(x) + self.assertEqualArray(y, numpy.abs(x).astype(numpy.int64) % 2) + + def test_ort_abs_pox(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_pow(x) + self.assertEqualArray(y, numpy.abs(x) ** 2) + + def test_ort_abs_matmul(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_matmul(x) + self.assertEqualArray(y, numpy.abs(x) @ x) + + def test_ort_abs_div(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_div(x) + self.assertEqualArray(y, numpy.abs(x) / x) + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.int64) + self.assertRaise(lambda: test_abs_div(x), InvalidArgument) + + def test_ort_abs_idiv(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_idiv(x) + self.assertEqualArray(y, numpy.abs(x) // x) + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.int64) + self.assertRaise(lambda: test_abs_idiv(x), InvalidArgument) + + @ignore_warnings(DeprecationWarning) + def test_ort_abs_equal(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_equal(x) + self.assertEqualArray(y, numpy.abs(x) == x) + + @ignore_warnings(DeprecationWarning) + def test_ort_abs_not_equal(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_not_equal(x) + self.assertEqualArray(y, numpy.abs(x) != x) + + @ignore_warnings(DeprecationWarning) + def test_ort_abs_greater(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_greater(x) + self.assertEqualArray(y, numpy.abs(x) > x) + + @ignore_warnings(DeprecationWarning) + def test_ort_abs_less(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_less(x) + self.assertEqualArray(y, numpy.abs(x) < x) + + @ignore_warnings(DeprecationWarning) + def test_ort_abs_and(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_and(x) + self.assertEqualArray( + y, (numpy.abs(x) < x) & (numpy.abs(x) < 0)) + + @ignore_warnings(DeprecationWarning) + def test_ort_abs_or(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_or(x) + self.assertEqualArray( + y, (numpy.abs(x) < x) | (numpy.abs(x) < 0)) + + def test_ort_abs_sum1(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_sum1(x) + self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=0)) + + def test_ort_abs_sum2(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_sum2(x) + self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=1, keepdims=1)) + + def test_ort_transpose_t(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_transpose_t(x) + self.assertEqualArray(y, numpy.abs(x).T) + + def test_ort_abs_cast(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_cast(x) + self.assertEqualArray(y, numpy.abs(x).astype(numpy.int64)) + + def test_ort_abs_reshape(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_reshape(x) + self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1))) + + def test_ort_abs_reshape_11(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_reshape(x) + self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1))) + compiled = test_abs_reshape.compiled + self.assertNotIn("version: 11", str(compiled.onnx_)) + y = test_abs_reshape_11(x) + self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1))) + compiled = test_abs_reshape_11.compiled + self.assertIn("version: 11", str(compiled.onnx_)) + + def test_ort_abs_slice(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_slice(x) + self.assertEqualArray(y, numpy.abs(x)[:, 1:]) + + def test_ort_abs_slice23(self): + x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) + y = test_abs_slice23(x) + self.assertEqualArray(y, numpy.abs(x)[::2, ::3]) + + def test_ort_abs_neg(self): + x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) + y = test_abs_neg(x) + self.assertEqualArray(y, -numpy.abs(x)) + + def test_ort_abs_not(self): + x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) + y = test_abs_not(x) + self.assertEqualArray(y, numpy.abs(x) <= 0) + + def test_ort_abs_filter(self): + x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32) + y = test_abs_filter(x) + self.assertEqualArray(y, numpy.abs(x)[x[:, 0] > 15]) + + def test_ort_abs_set(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_set2(x) + temp = numpy.abs(x) + temp[:, 0] = -1 + self.assertEqualArray(y, temp) + + def test_ort_abs_set3(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_set3(x) + temp = numpy.abs(x) + temp[:, 0] = -1.5 + self.assertEqualArray(y, temp) + + def test_ort_log(self): + x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32) + y = test_log(x) + self.assertEqualArray(y, numpy.log(x)) + + def test_ort_abs_log_multi(self): + x = numpy.array([[6.1, -5], [-3.5, 7.8]], dtype=numpy.float32) + y = test_abs_log_multi(x) + self.assertEqualArray(y, numpy.log(numpy.abs(x))) + + def test_ort_abs_shape(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_shape(x) + self.assertEqualArray(y, numpy.abs(x).shape) + + def test_ort_abs_size(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_size(x) + self.assertEqualArray(y, numpy.abs(x).size) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/npy/onnx_numpy_annotation.py b/mlprodict/npy/onnx_numpy_annotation.py index 61cad3a89..cdb22f1b7 100644 --- a/mlprodict/npy/onnx_numpy_annotation.py +++ b/mlprodict/npy/onnx_numpy_annotation.py @@ -95,6 +95,8 @@ def __repr__(self): def _to_onnx_dtype(self, dtype, shape): from skl2onnx.common.data_types import _guess_numpy_type + if dtype == numpy.bool_: + dtype = numpy.bool return _guess_numpy_type(dtype, shape) def get_inputs_outputs(self, args, version): diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py index 42a979974..714e4510f 100644 --- a/mlprodict/npy/onnx_variable.py +++ b/mlprodict/npy/onnx_variable.py @@ -122,7 +122,7 @@ def _custom_op_filter(self, *args, op_version=None, runtime=None, **kwargs): "Custom op 'filter' expects no arguments but got %r." % kwargs) mat, index = args cast = OnnxVar(index.astype(numpy.int64), op=OnnxSqueeze) - n1 = OnnxVar(cast, op=OnnxReduceSum, keepdims=0) + n1 = OnnxVar(cast, op=OnnxReduceSum, keepdims=1) indices = OnnxVar(cast, n1, op=OnnxTopK, select_output=1) return OnnxVar(mat, indices, op=OnnxGather)