From 8280b0eaa7173b9cb1c2a3dcb8119d537edc8e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 22 Feb 2021 19:56:32 +0100 Subject: [PATCH 1/3] First sketch to enable FunctionTransformer --- _doc/sphinxdoc/source/api/npy.rst | 32 ++++++--- .../ut_npy/test_function_transformer.py | 23 +++++- _unittests/ut_npy/test_numpy_onnx_pyrt.py | 33 +++++++++ _unittests/ut_npy/test_onnx_variable.py | 14 +++- _unittests/ut_npy/test_onnxpy.py | 2 +- mlprodict/npy/__init__.py | 3 +- .../npy/{numpy_impl.py => numpy_onnx_impl.py} | 9 ++- mlprodict/npy/numpy_onnx_pyrt.py | 33 +++++++++ mlprodict/npy/onnx_numpy_annotation.py | 70 +++++++++++++++++++ mlprodict/npy/onnx_numpy_compiler.py | 45 ++++++------ mlprodict/npy/onnx_numpy_wrapper.py | 11 +-- .../function_transformer_converters.py | 18 ++--- 12 files changed, 241 insertions(+), 52 deletions(-) create mode 100644 _unittests/ut_npy/test_numpy_onnx_pyrt.py rename mlprodict/npy/{numpy_impl.py => numpy_onnx_impl.py} (79%) create mode 100644 mlprodict/npy/numpy_onnx_pyrt.py create mode 100644 mlprodict/npy/onnx_numpy_annotation.py diff --git a/_doc/sphinxdoc/source/api/npy.rst b/_doc/sphinxdoc/source/api/npy.rst index 8505c194d..fc8e2e0b5 100644 --- a/_doc/sphinxdoc/source/api/npy.rst +++ b/_doc/sphinxdoc/source/api/npy.rst @@ -4,6 +4,12 @@ Numpy revisited with ONNX ========================= +.. contents:: + :local: + +Introduction +++++++++++++ + Converting custom code into :epkg:`ONNX` is not necessarily easy. One big obstacle is :epkg:`ONNX` does not represent all numpy functions with a single operator. One possible option is to provide a @@ -22,7 +28,7 @@ is called. import numpy from typing import Any from mlprodict.npy import onnxnumpy_default, NDArray - import mlprodict.npy.numpy_impl as nxnp + import mlprodict.npy.numpy_onnx_impl as nxnp @onnxnumpy_default def custom_fct(x: NDArray[Any, numpy.float32], @@ -40,13 +46,10 @@ as opposed to numpy. This approach is similar to what :epkg:`tensorflow` with `autograph `_. -.. contents:: - :local: - NDArray +++++++ -.. autosignature:: mlprodict.npy.onnx_numpy_compiler.NDArray +.. autosignature:: mlprodict.npy.onnx_numpy_annotation.NDArray :members: onnxnumpy @@ -68,9 +71,20 @@ OnnxVar .. autosignature:: mlprodict.npy.onnx_variable.OnnxVar :members: -Available numpy functions -+++++++++++++++++++++++++ +Available numpy functions implemented with ONNX operators ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +.. autosignature:: mlprodict.npy.numpy_onnx_impl.abs + +.. autosignature:: mlprodict.npy.numpy_onnx_impl.log + +.. autosignature:: mlprodict.npy.numpy_onnx_impl.sum + +ONNX functions executed python ONNX runtime ++++++++++++++++++++++++++++++++++++++++++++ + +.. autosignature:: mlprodict.npy.numpy_onnx_pyrt.abs -.. autosignature:: mlprodict.npy.numpy_impl.abs +.. autosignature:: mlprodict.npy.numpy_onnx_pyrt.log -.. autosignature:: mlprodict.npy.numpy_impl.sum +.. autosignature:: mlprodict.npy.numpy_onnx_pyrt.sum diff --git a/_unittests/ut_npy/test_function_transformer.py b/_unittests/ut_npy/test_function_transformer.py index 72ddd5318..770286a7c 100644 --- a/_unittests/ut_npy/test_function_transformer.py +++ b/_unittests/ut_npy/test_function_transformer.py @@ -7,12 +7,13 @@ from logging import getLogger from typing import Any import numpy -from sklearn.preprocessing import FunctionTransformer +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import FunctionTransformer, StandardScaler from pyquickhelper.pycode import ExtTestCase, ignore_warnings from mlprodict.onnx_conv import register_rewritten_operators, to_onnx from mlprodict.onnxrt import OnnxInference from mlprodict.npy import onnxnumpy_default -import mlprodict.npy.numpy_impl as nxnp +import mlprodict.npy.numpy_onnx_impl as nxnp from mlprodict.npy import NDArray @@ -50,6 +51,24 @@ def test_function_transformer(self): y_onx = oinf.run({'X': x}) self.assertEqualArray(y_exp, y_onx['variable']) + @ignore_warnings(DeprecationWarning) + def test_function_transformer_numpy_log(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + tr = make_pipeline(FunctionTransformer(numpy.log), StandardScaler()) + tr.fit(x) + self.assertRaise(lambda: to_onnx(tr, x), TypeError) + + @ignore_warnings(DeprecationWarning) + def test_function_transformer_nxnp_log(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + tr = make_pipeline(FunctionTransformer(nxnp.log), StandardScaler()) + tr.fit(x) + y_exp = tr.transform(x) + onnx_model = to_onnx(tr, x) + oinf = OnnxInference(onnx_model) + y_onx = oinf.run({'X': x}) + self.assertEqualArray(y_exp, y_onx['variable']) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_npy/test_numpy_onnx_pyrt.py b/_unittests/ut_npy/test_numpy_onnx_pyrt.py new file mode 100644 index 000000000..5109f5166 --- /dev/null +++ b/_unittests/ut_npy/test_numpy_onnx_pyrt.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +""" +@brief test log(time=3s) +""" +import unittest +import numpy +from pyquickhelper.pycode import ExtTestCase +import mlprodict.npy.numpy_onnx_pyrt as nxnpy + + +class TestNumpyOnnxFunction(ExtTestCase): + + def common_test1(self, x, npfct, nxfct, dtype): + xt = x.astype(dtype) + expected = npfct(xt) + got = nxfct(xt) + self.assertEqualArray(expected, got) + + def test_abs_float32(self): + x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) + self.common_test1(x, numpy.abs, nxnpy.abs, numpy.float32) + + def test_log_float32(self): + x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) + self.common_test1(x, numpy.log, nxnpy.log, numpy.float32) + + def test_sum_float32(self): + x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) + self.common_test1(x, numpy.sum, nxnpy.sum, numpy.float32) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_npy/test_onnx_variable.py b/_unittests/ut_npy/test_onnx_variable.py index 535f11628..2946963ec 100644 --- a/_unittests/ut_npy/test_onnx_variable.py +++ b/_unittests/ut_npy/test_onnx_variable.py @@ -7,7 +7,7 @@ import numpy from pyquickhelper.pycode import ExtTestCase, ignore_warnings from mlprodict.npy import onnxnumpy, onnxnumpy_default -import mlprodict.npy.numpy_impl as nxnp +import mlprodict.npy.numpy_onnx_impl as nxnp from mlprodict.npy import OnnxNumpyCompiler as ONC, NDArray @@ -260,6 +260,13 @@ def test_abs_set3(x: NDArray[Any, numpy.float32], return temp +@onnxnumpy_default +def test_log(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy log" + return nxnp.log(x) + + class TestOnnxVariable(ExtTestCase): def test_onnx_variable_abs(self): @@ -446,6 +453,11 @@ def test_onnx_variable_abs_set3(self): temp[:, 0] = -1.5 self.assertEqualArray(y, temp) + def test_onnx_variable_log(self): + x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32) + y = test_log(x) + self.assertEqualArray(y, numpy.log(x)) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_npy/test_onnxpy.py b/_unittests/ut_npy/test_onnxpy.py index 313bd5184..65b36d700 100644 --- a/_unittests/ut_npy/test_onnxpy.py +++ b/_unittests/ut_npy/test_onnxpy.py @@ -21,7 +21,7 @@ def onnx_abs(x: NDArray[Any, numpy.float32], def test_annotation(self): cl = ONC(TestOnnxPy.onnx_abs, op_version=12) - ann = cl._parse_annotation() # pylint: disable=W0212 + ann = cl._parse_annotation(None) # pylint: disable=W0212 inputs, outputs = ann self.assertIsInstance(inputs, list) self.assertIsInstance(outputs, list) diff --git a/mlprodict/npy/__init__.py b/mlprodict/npy/__init__.py index b8d5f6b09..05e3f80fc 100644 --- a/mlprodict/npy/__init__.py +++ b/mlprodict/npy/__init__.py @@ -5,5 +5,6 @@ .. versionadded:: 0.6 """ -from .onnx_numpy_compiler import OnnxNumpyCompiler, NDArray +from .onnx_numpy_annotation import NDArray +from .onnx_numpy_compiler import OnnxNumpyCompiler from .onnx_numpy_wrapper import onnxnumpy, onnxnumpy_default diff --git a/mlprodict/npy/numpy_impl.py b/mlprodict/npy/numpy_onnx_impl.py similarity index 79% rename from mlprodict/npy/numpy_impl.py rename to mlprodict/npy/numpy_onnx_impl.py index ba663e861..bb76c922c 100644 --- a/mlprodict/npy/numpy_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -6,7 +6,9 @@ """ import numpy from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 - OnnxAbs, OnnxReduceSum) + OnnxAbs, + OnnxLog, + OnnxReduceSum) from .onnx_variable import OnnxVar @@ -15,6 +17,11 @@ def abs(x): return OnnxVar(x, op=OnnxAbs) +def log(x): + "See :epkg:`numpy:log`." + return OnnxVar(x, op=OnnxLog) + + def sum(x, axis=0, keepdims=0): "See :epkg:`numpy:sum`." return OnnxVar(x, numpy.array([axis], dtype=numpy.int64), diff --git a/mlprodict/npy/numpy_onnx_pyrt.py b/mlprodict/npy/numpy_onnx_pyrt.py new file mode 100644 index 000000000..8563a8506 --- /dev/null +++ b/mlprodict/npy/numpy_onnx_pyrt.py @@ -0,0 +1,33 @@ +""" +@file +@brief :epkg:`numpy` functions implemented with :epkg:`onnx` +and compiled with this python runtime. + +.. versionadded:: 0.6 +""" +from .onnx_numpy_annotation import ( + NDArraySameType, + NDArraySameTypeSameShape) +from .numpy_onnx_impl import ( + abs as nx_abs, + log as nx_log, + sum as nx_sum) +from .onnx_numpy_wrapper import onnxnumpy + + +@onnxnumpy(signature=NDArraySameTypeSameShape("all")) +def abs(x): + "abs" + return nx_abs(x) + + +@onnxnumpy(signature=NDArraySameTypeSameShape("floats")) +def log(x): + "log" + return nx_log(x) + + +@onnxnumpy(signature=NDArraySameType("all")) +def sum(x, axis=0, keepdims=0): + "sum" + return nx_sum(x, axis=axis, keepdims=keepdims) diff --git a/mlprodict/npy/onnx_numpy_annotation.py b/mlprodict/npy/onnx_numpy_annotation.py new file mode 100644 index 000000000..892d1f28d --- /dev/null +++ b/mlprodict/npy/onnx_numpy_annotation.py @@ -0,0 +1,70 @@ +""" +@file +@brief :epkg:`numpy` annotations. + +.. versionadded:: 0.6 +""" +from typing import TypeVar, Generic +import numpy + +Shape = TypeVar("Shape") +DType = TypeVar("DType") + + +all_dtypes = (numpy.float32, numpy.float64, + numpy.int32, numpy.int64, + numpy.uint32, numpy.uint64) + + +class NDArray(numpy.ndarray, Generic[Shape, DType]): + """ + Used to annotation ONNX numpy functions. + + .. versionadded:: 0.6 + """ + pass + + +class _NDArrayAlias: + def __init__(self, dtypes=None): + self.dtypes = dtypes + if isinstance(self.dtypes, str): + if self.dtypes == "all": + self.dtypes = all_dtypes + elif self.dtypes == "floats": + self.dtypes = (numpy.float32, numpy.float64) + elif self.dtypes == "ints": + self.dtypes = (numpy.int32, numpy.int64) + else: + raise ValueError( + "Unexpected shortcut for dtype %r." % self.dtypes) + elif isinstance(self.dtypes, (tuple, list)): + for dt in self.dtypes: + if dt not in all_dtypes: + raise TypeError( + "Unexpected type error for annotation " + "%r." % self) + + def __repr__(self): + "usual" + return "%s(%r)" % (self.__class__.__name__, self.dtypes) + + +class NDArraySameType(_NDArrayAlias): + """ + Shortcut to simplify signature description. + + :param + + .. versionadded:: 0.6 + """ + pass + + +class NDArraySameTypeSameShape(NDArraySameType): + """ + Shortcut to simplify signature description. + + .. versionadded:: 0.6 + """ + pass diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index ec015f5ba..48cce2b51 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -5,23 +5,10 @@ .. versionadded:: 0.6 """ import inspect -from typing import Any, TypeVar, Generic -import numpy +from typing import Any from ..onnxrt import OnnxInference from .onnx_variable import OnnxVar -Shape = TypeVar("Shape") -DType = TypeVar("DType") - - -class NDArray(numpy.ndarray, Generic[Shape, DType]): - """ - Used to annotation ONNX numpy functions. - - .. versionadded:: 0.6 - """ - pass - class OnnxNumpyFunction: """ @@ -96,11 +83,12 @@ class OnnxNumpyCompiler: for the latest one :param runtime: runtime to choose to execute the onnx graph, `python`, `onnxruntime`, `onnxruntime1` + :param signature: used when the function is not annotated .. versionadded:: 0.6 """ - def __init__(self, fct, op_version=None, runtime=None): + def __init__(self, fct, op_version=None, runtime=None, signature=None): if op_version is None: from skl2onnx import __max_supported_opset__ op_version = __max_supported_opset__ @@ -114,9 +102,11 @@ def __init__(self, fct, op_version=None, runtime=None): "Unexpected type for fct=%r, it must be " "function." % type(fct)) self.onnx_ = None - self.onnx_ = self._to_onnx(op_version=op_version) - self.runtime_ = self._build_runtime(op_version=op_version, - runtime=runtime) + self.onnx_ = self._to_onnx(op_version=op_version, + signature=signature) + self.runtime_ = self._build_runtime( + op_version=op_version, runtime=runtime, + signature=signature) def __repr__(self): "usual" @@ -136,10 +126,14 @@ def _to_onnx_dtype(self, dtype, shape): from skl2onnx.common.data_types import _guess_numpy_type return _guess_numpy_type(dtype, shape) - def _parse_annotation(self): + def _parse_annotation(self, signature): """ Returns the annotations for function `fct_`. """ + if signature is not None: + raise RuntimeError( + "Unexpected signature %r." % signature) + args = self.fct_.__code__.co_varnames[:self.fct_.__code__.co_argcount] annotations = self.fct_.__annotations__ inputs = [] @@ -149,7 +143,9 @@ def _parse_annotation(self): continue if a not in annotations: raise RuntimeError( - "Unable to find annotation for argument %r." % a) + "Unable to find annotation for argument %r. " + "You should annotate the arguments and the results " + "or specify a signature." % a) ann = annotations[a] shape, dtype = ann.__args__ shape = self._to_onnx_shape(shape) @@ -162,12 +158,12 @@ def _parse_annotation(self): outputs.append(('y', dtype)) return inputs, outputs - def _to_onnx(self, op_version=None): + def _to_onnx(self, op_version=None, signature=None): """ Returns the onnx graph produced by function `fct_`. """ if self.onnx_ is None and self.fct_ is not None: - inputs, outputs = self._parse_annotation() + inputs, outputs = self._parse_annotation(signature) names_in = [oi[0] for oi in inputs] names_out = [oi[0] for oi in outputs] names_var = [OnnxVar(n) for n in names_in] @@ -199,7 +195,7 @@ def _to_onnx(self, op_version=None): "Unable to get the ONNX graph.") return self.onnx_ - def _build_runtime(self, op_version=None, runtime=None): + def _build_runtime(self, op_version=None, runtime=None, signature=None): """ Creates the runtime for the :epkg:`ONNX` graph. @@ -207,9 +203,10 @@ def _build_runtime(self, op_version=None, runtime=None): for the latest one :param runtime: runtime to choose to execute the onnx graph, `python`, `onnxruntime`, `onnxruntime1` + :param signature: used when the function is not annotated """ onx = self._to_onnx(op_version=op_version) - inputs, outputs = self._parse_annotation() + inputs, outputs = self._parse_annotation(signature) if runtime != 'onnxruntime': rt = OnnxInference(onx, runtime=runtime) self.rt_fct_ = OnnxNumpyFunctionOnnxInference( diff --git a/mlprodict/npy/onnx_numpy_wrapper.py b/mlprodict/npy/onnx_numpy_wrapper.py index 491794a3d..562533b57 100644 --- a/mlprodict/npy/onnx_numpy_wrapper.py +++ b/mlprodict/npy/onnx_numpy_wrapper.py @@ -27,7 +27,7 @@ def __call__(self, *args): return self.compiled(*args) -def onnxnumpy(op_version=None, runtime=None): +def onnxnumpy(op_version=None, runtime=None, signature=None): """ Decorator to declare a function implemented using :epkg:`numpy` syntax but executed with :epkg:`ONNX` @@ -35,15 +35,17 @@ def onnxnumpy(op_version=None, runtime=None): :param op_version: :epkg:`ONNX` opset version :param runtime: see @see fct + :param signature: it should be used when the function + is not annoatated. Equivalent to `onnxnumpy(arg)(foo)`. - The decorator must be called with `onnxnumpy()`. .. versionadded:: 0.6 """ def decorator_fct(fct): - compiled = OnnxNumpyCompiler(fct, op_version=op_version, - runtime=runtime) + compiled = OnnxNumpyCompiler( + fct, op_version=op_version, runtime=runtime, + signature=signature) newclass = type( "onnxnumpy_%s" % fct.__name__, (wrapper_onnxnumpy,), {'__doc__': fct.__doc__}) @@ -59,7 +61,6 @@ def onnxnumpy_default(fct): operators. :param fct: function to wrap - :param runtime: see @see fct .. versionadded:: 0.6 """ diff --git a/mlprodict/onnx_conv/sklconv/function_transformer_converters.py b/mlprodict/onnx_conv/sklconv/function_transformer_converters.py index 4b2dddb24..c3dc377bc 100644 --- a/mlprodict/onnx_conv/sklconv/function_transformer_converters.py +++ b/mlprodict/onnx_conv/sklconv/function_transformer_converters.py @@ -39,10 +39,11 @@ def new_calculate_sklearn_function_transformer_output_shapes(operator): return if operator.raw_operator.func is not None: - raise RuntimeError("FunctionTransformer is not supported unless the " - "transform function is None (= identity). " - "You may raise an issue at " - "https://github.com/onnx/sklearn-onnx/issues.") + raise TypeError("FunctionTransformer is not supported unless the " + "transform function is None (= identity) or " + "wrapped with onxnumpy. " + "You may raise an issue at " + "https://github.com/onnx/sklearn-onnx/issues.") N = operator.inputs[0].type.shape[0] C = 0 for variable in operator.inputs: @@ -114,10 +115,11 @@ def new_convert_sklearn_function_transformer(scope, operator, container): return if op.func is not None: - raise RuntimeError("FunctionTransformer is not supported unless the " - "transform function is None (= identity). " - "You may raise an issue at " - "https://github.com/onnx/sklearn-onnx/issues.") + raise TypeError("FunctionTransformer is not supported unless the " + "transform function is None (= identity) or " + "wrapped with onxnumpy. " + "You may raise an issue at " + "https://github.com/onnx/sklearn-onnx/issues.") if len(operator.inputs) == 1: apply_identity(scope, operator.inputs[0].full_name, operator.outputs[0].full_name, container) From 271e3d244d438801127d8715d69649d2f82bd46d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 23 Feb 2021 01:29:43 +0100 Subject: [PATCH 2/3] improves the design --- .../ut_npy/test_function_transformer.py | 8 +- _unittests/ut_npy/test_numpy_onnx_pyrt.py | 8 +- _unittests/ut_npy/test_onnxpy.py | 4 +- mlprodict/npy/numpy_onnx_pyrt.py | 8 +- mlprodict/npy/onnx_numpy_annotation.py | 46 ++++++++++++ mlprodict/npy/onnx_numpy_compiler.py | 75 +++++++++++++++---- mlprodict/npy/onnx_numpy_wrapper.py | 75 ++++++++++++++++++- .../function_transformer_converters.py | 44 ++++++----- 8 files changed, 224 insertions(+), 44 deletions(-) diff --git a/_unittests/ut_npy/test_function_transformer.py b/_unittests/ut_npy/test_function_transformer.py index 770286a7c..caddd373c 100644 --- a/_unittests/ut_npy/test_function_transformer.py +++ b/_unittests/ut_npy/test_function_transformer.py @@ -14,6 +14,7 @@ from mlprodict.onnxrt import OnnxInference from mlprodict.npy import onnxnumpy_default import mlprodict.npy.numpy_onnx_impl as nxnp +import mlprodict.npy.numpy_onnx_pyrt as nxnpy from mlprodict.npy import NDArray @@ -60,14 +61,15 @@ def test_function_transformer_numpy_log(self): @ignore_warnings(DeprecationWarning) def test_function_transformer_nxnp_log(self): - x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) - tr = make_pipeline(FunctionTransformer(nxnp.log), StandardScaler()) + x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32) + self.assertIsInstance(nxnpy.log(x), numpy.ndarray) + tr = make_pipeline(FunctionTransformer(nxnpy.log), StandardScaler()) tr.fit(x) y_exp = tr.transform(x) onnx_model = to_onnx(tr, x) oinf = OnnxInference(onnx_model) y_onx = oinf.run({'X': x}) - self.assertEqualArray(y_exp, y_onx['variable']) + self.assertEqualArray(y_exp, y_onx['variable'], decimal=5) if __name__ == "__main__": diff --git a/_unittests/ut_npy/test_numpy_onnx_pyrt.py b/_unittests/ut_npy/test_numpy_onnx_pyrt.py index 5109f5166..20c2a6cc9 100644 --- a/_unittests/ut_npy/test_numpy_onnx_pyrt.py +++ b/_unittests/ut_npy/test_numpy_onnx_pyrt.py @@ -13,7 +13,7 @@ class TestNumpyOnnxFunction(ExtTestCase): def common_test1(self, x, npfct, nxfct, dtype): xt = x.astype(dtype) expected = npfct(xt) - got = nxfct(xt) + got = nxfct[numpy.float32](xt) self.assertEqualArray(expected, got) def test_abs_float32(self): @@ -21,9 +21,13 @@ def test_abs_float32(self): self.common_test1(x, numpy.abs, nxnpy.abs, numpy.float32) def test_log_float32(self): - x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) + x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32) self.common_test1(x, numpy.log, nxnpy.log, numpy.float32) + def test_log_float64(self): + x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float64) + self.common_test1(x, numpy.log, nxnpy.log, numpy.float64) + def test_sum_float32(self): x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) self.common_test1(x, numpy.sum, nxnpy.sum, numpy.float32) diff --git a/_unittests/ut_npy/test_onnxpy.py b/_unittests/ut_npy/test_onnxpy.py index 65b36d700..59680dee9 100644 --- a/_unittests/ut_npy/test_onnxpy.py +++ b/_unittests/ut_npy/test_onnxpy.py @@ -21,8 +21,8 @@ def onnx_abs(x: NDArray[Any, numpy.float32], def test_annotation(self): cl = ONC(TestOnnxPy.onnx_abs, op_version=12) - ann = cl._parse_annotation(None) # pylint: disable=W0212 - inputs, outputs = ann + ann = cl._parse_annotation(None, None) # pylint: disable=W0212 + inputs, outputs, _ = ann self.assertIsInstance(inputs, list) self.assertIsInstance(outputs, list) self.assertEqual(len(inputs), 1) diff --git a/mlprodict/npy/numpy_onnx_pyrt.py b/mlprodict/npy/numpy_onnx_pyrt.py index 8563a8506..c36e0a1f0 100644 --- a/mlprodict/npy/numpy_onnx_pyrt.py +++ b/mlprodict/npy/numpy_onnx_pyrt.py @@ -12,22 +12,22 @@ abs as nx_abs, log as nx_log, sum as nx_sum) -from .onnx_numpy_wrapper import onnxnumpy +from .onnx_numpy_wrapper import onnxnumpy_np -@onnxnumpy(signature=NDArraySameTypeSameShape("all")) +@onnxnumpy_np(signature=NDArraySameTypeSameShape("all")) def abs(x): "abs" return nx_abs(x) -@onnxnumpy(signature=NDArraySameTypeSameShape("floats")) +@onnxnumpy_np(signature=NDArraySameTypeSameShape("floats")) def log(x): "log" return nx_log(x) -@onnxnumpy(signature=NDArraySameType("all")) +@onnxnumpy_np(signature=NDArraySameType("all")) def sum(x, axis=0, keepdims=0): "sum" return nx_sum(x, axis=axis, keepdims=keepdims) diff --git a/mlprodict/npy/onnx_numpy_annotation.py b/mlprodict/npy/onnx_numpy_annotation.py index 892d1f28d..4363ba084 100644 --- a/mlprodict/npy/onnx_numpy_annotation.py +++ b/mlprodict/npy/onnx_numpy_annotation.py @@ -49,6 +49,52 @@ def __repr__(self): "usual" return "%s(%r)" % (self.__class__.__name__, self.dtypes) + def _to_onnx_dtype(self, dtype, shape): + from skl2onnx.common.data_types import _guess_numpy_type + return _guess_numpy_type(dtype, shape) + + def get_inputs_outputs(self, args, version): + """ + Returns the list of inputs, outputs. + + :param args: list of arguments + :param version: required version + :return: *tuple(inputs, outputs)*, each of them + is a list of tuple with the name and the dtype + """ + def _possible_names(): + yield 'y' + yield 'z' + yield 'o' + for i in range(0, 10000): + yield 'o%d' % i + + if version not in self.dtypes: + raise TypeError( + "Unexpected dtype %r, it should be in %r." % ( + version, self.dtypes)) + onnx_type = self._to_onnx_dtype(version, None) + inputs = [(a, onnx_type) for a in args] + names_in = set(inp[0] for inp in inputs) + name_out = None + for name in _possible_names(): + if name not in names_in: + name_out = name + break + outputs = [(name_out, onnx_type)] + return inputs, outputs + + def shape_calculator(self, dims): + """ + Returns expected dimensions given the input dimensions. + """ + if len(dims) == 0: + return None + res = [dims[0]] + for _ in dims[1:]: + res.append(None) + return res + class NDArraySameType(_NDArrayAlias): """ diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 48cce2b51..d7bbb404a 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -84,11 +84,15 @@ class OnnxNumpyCompiler: :param runtime: runtime to choose to execute the onnx graph, `python`, `onnxruntime`, `onnxruntime1` :param signature: used when the function is not annotated + :param version: the same function can be instantiated with + different type, this parameter is None or a numpy type + if the signature allows multiple types .. versionadded:: 0.6 """ - def __init__(self, fct, op_version=None, runtime=None, signature=None): + def __init__(self, fct, op_version=None, runtime=None, signature=None, + version=None): if op_version is None: from skl2onnx import __max_supported_opset__ op_version = __max_supported_opset__ @@ -102,11 +106,18 @@ def __init__(self, fct, op_version=None, runtime=None, signature=None): "Unexpected type for fct=%r, it must be " "function." % type(fct)) self.onnx_ = None - self.onnx_ = self._to_onnx(op_version=op_version, - signature=signature) + self.onnx_ = self._to_onnx( + op_version=op_version, signature=signature, + version=version) self.runtime_ = self._build_runtime( op_version=op_version, runtime=runtime, - signature=signature) + signature=signature, version=version) + inputs, outputs, kwargs = self._parse_annotation( + signature=signature, version=version) + self.meta_ = dict(op_version=op_version, runtime=runtime, + signature=signature, version=version, + inputs=inputs, outputs=outputs, + kwargs=kwargs) def __repr__(self): "usual" @@ -126,15 +137,36 @@ def _to_onnx_dtype(self, dtype, shape): from skl2onnx.common.data_types import _guess_numpy_type return _guess_numpy_type(dtype, shape) - def _parse_annotation(self, signature): + def _parse_annotation(self, signature, version): """ Returns the annotations for function `fct_`. + + :param signature: needed if the annotation is missing, + then version might be needed to specify which type + to use if the signature allows many + :param version: version inside the many signatures possible + :return: *tuple(inputs, outputs, kwargs)*, each of them + is a list of tuple with the name and the dtype, + *kwargs* is the list of additional parameters """ + params = inspect.signature(self.fct_).parameters + args = [name for name, p in params.items() + if p.default == inspect.Parameter.empty] + kwargs = {name: p.default for name, p in params.items() + if (p.default != inspect.Parameter.empty and + name != 'op_version')} + if signature is not None: - raise RuntimeError( - "Unexpected signature %r." % signature) + inputs, outputs = signature.get_inputs_outputs(args, version) + return inputs, outputs, kwargs + + def _possible_names(): + yield 'y' + yield 'z' + yield 'o' + for i in range(0, 10000): + yield 'o%d' % i - args = self.fct_.__code__.co_varnames[:self.fct_.__code__.co_argcount] annotations = self.fct_.__annotations__ inputs = [] outputs = [] @@ -155,15 +187,22 @@ def _parse_annotation(self, signature): shape, dtype = ret.__args__ shape = self._to_onnx_shape(shape) dtype = self._to_onnx_dtype(dtype, shape) - outputs.append(('y', dtype)) - return inputs, outputs + names_in = set(inp[0] for inp in inputs) + name_out = None + for name in _possible_names(): + if name not in names_in: + name_out = name + break + outputs.append((name_out, dtype)) + return inputs, outputs, kwargs - def _to_onnx(self, op_version=None, signature=None): + def _to_onnx(self, op_version=None, signature=None, version=None): """ Returns the onnx graph produced by function `fct_`. """ if self.onnx_ is None and self.fct_ is not None: - inputs, outputs = self._parse_annotation(signature) + inputs, outputs, _ = self._parse_annotation( + signature=signature, version=version) names_in = [oi[0] for oi in inputs] names_out = [oi[0] for oi in outputs] names_var = [OnnxVar(n) for n in names_in] @@ -195,7 +234,8 @@ def _to_onnx(self, op_version=None, signature=None): "Unable to get the ONNX graph.") return self.onnx_ - def _build_runtime(self, op_version=None, runtime=None, signature=None): + def _build_runtime(self, op_version=None, runtime=None, + signature=None, version=None): """ Creates the runtime for the :epkg:`ONNX` graph. @@ -205,8 +245,13 @@ def _build_runtime(self, op_version=None, runtime=None, signature=None): `python`, `onnxruntime`, `onnxruntime1` :param signature: used when the function is not annotated """ - onx = self._to_onnx(op_version=op_version) - inputs, outputs = self._parse_annotation(signature) + onx = self._to_onnx(op_version=op_version, signature=signature, + version=version) + inputs, outputs, kwargs = self._parse_annotation( + signature=signature, version=version) + if len(kwargs) > 0: + raise NotImplementedError( + "Unable to handle additional parameters %r." % kwargs) if runtime != 'onnxruntime': rt = OnnxInference(onx, runtime=runtime) self.rt_fct_ = OnnxNumpyFunctionOnnxInference( diff --git a/mlprodict/npy/onnx_numpy_wrapper.py b/mlprodict/npy/onnx_numpy_wrapper.py index 562533b57..d21918acf 100644 --- a/mlprodict/npy/onnx_numpy_wrapper.py +++ b/mlprodict/npy/onnx_numpy_wrapper.py @@ -47,7 +47,7 @@ def decorator_fct(fct): fct, op_version=op_version, runtime=runtime, signature=signature) newclass = type( - "onnxnumpy_%s" % fct.__name__, + "onnxnumpy_%s_%s_%s" % (fct.__name__, str(op_version), runtime), (wrapper_onnxnumpy,), {'__doc__': fct.__doc__}) return newclass(compiled) @@ -65,3 +65,76 @@ def onnxnumpy_default(fct): .. versionadded:: 0.6 """ return onnxnumpy()(fct) + + +class wrapper_onnxnumpy_np: + """ + Intermediate wrapper to store a pointer + on the compiler (type: @see cl OnnxNumpyCompiler) + supporting multiple signatures. + + .. versionadded:: 0.6 + """ + + def __init__(self, **kwargs): + self.data = kwargs + self.signed_compiled = {} + + def __getitem__(self, dtype): + """ + Returns the instance of @see cl wrapper_onnxnumpy + mapped to *dtype*. + + :param dtype: numpy dtype corresponding to the input dtype + of the function + :return: instance of @see cl wrapper_onnxnumpy + """ + if dtype not in self.signed_compiled: + self._populate(dtype) + return self.signed_compiled[dtype] + + def __call__(self, *args): + """ + Calls the compiled function assuming the type of the first + tensor in *args* defines the templated version of the function + to convert into *ONNX*. + """ + return self[args[0].dtype](*args) + + def _populate(self, version): + """ + Creates the appropriate runtime for function *fct* + """ + compiled = OnnxNumpyCompiler( + fct=self.data["fct"], op_version=self.data["op_version"], + runtime=self.data["runtime"], signature=self.data["signature"], + version=version) + newclass = type( + "onnxnumpy_np_%s_%s_%s_%s" % ( + self.data["fct"].__name__, str(self.data["op_version"]), + self.data["runtime"], str(version).split('.')[-1]), + (wrapper_onnxnumpy,), {'__doc__': self.data["fct"].__doc__}) + + self.signed_compiled[version] = newclass(compiled) + + +def onnxnumpy_np(op_version=None, runtime=None, signature=None): + """ + Decorator to declare a function implemented using + :epkg:`numpy` syntax but executed with :epkg:`ONNX` + operators. + + :param op_version: :epkg:`ONNX` opset version + :param runtime: see @see fct + :param signature: it should be used when the function + is not annoatated. + + Equivalent to `onnxnumpy(arg)(foo)`. + + .. versionadded:: 0.6 + """ + def decorator_fct(fct): + return wrapper_onnxnumpy_np( + fct=fct, op_version=op_version, runtime=runtime, + signature=signature) + return decorator_fct diff --git a/mlprodict/onnx_conv/sklconv/function_transformer_converters.py b/mlprodict/onnx_conv/sklconv/function_transformer_converters.py index c3dc377bc..9b8d42eee 100644 --- a/mlprodict/onnx_conv/sklconv/function_transformer_converters.py +++ b/mlprodict/onnx_conv/sklconv/function_transformer_converters.py @@ -4,6 +4,7 @@ :epkg:`sklearn-onnx`. """ import copy +from skl2onnx.common.data_types import guess_numpy_type from skl2onnx.common._apply_operation import apply_concat, apply_identity @@ -13,8 +14,12 @@ def new_calculate_sklearn_function_transformer_output_shapes(operator): :epkg:`sklearn-onnx` to support custom functions implemented with :ref:`l-numpy-onnxpy`. """ - if hasattr(operator.raw_operator.func, 'compiled'): - compiled = operator.raw_operator.func.compiled + fct = operator.raw_operator.func + if hasattr(fct, 'signed_compiled'): + dtype = guess_numpy_type(operator.inputs[0].type) + fct = fct[dtype] + if hasattr(fct, 'compiled'): + compiled = fct.compiled if not hasattr(compiled, 'onnx_'): raise RuntimeError( # pragma: no cover "Attribute 'onnx_' is missing, function was not " @@ -30,20 +35,23 @@ def new_calculate_sklearn_function_transformer_output_shapes(operator): raise RuntimeError( # pragma: no cover "Only one output is allowed not %d." % len(outputs)) input_type = operator.inputs[0].type.__class__ - N = operator.inputs[0].type.shape[0] - dims = [N] - out = outputs[0] - if hasattr(out, 'dims'): - dims.extend(out.dims[1:]) + if compiled.meta_.get('signature', None): + dims = compiled.meta_['signature'].shape_calculator( + operator.inputs[0].type.shape) + else: + N = operator.inputs[0].type.shape[0] + dims = [N] + out = outputs[0] + if hasattr(out, 'dims'): + dims.extend(out.dims[1:]) operator.outputs[0].type = input_type(dims) return if operator.raw_operator.func is not None: raise TypeError("FunctionTransformer is not supported unless the " - "transform function is None (= identity) or " - "wrapped with onxnumpy. " - "You may raise an issue at " - "https://github.com/onnx/sklearn-onnx/issues.") + "transform function is of type %r " + "wrapped with onnxnumpy." % type( + operator.raw_operator.func)) N = operator.inputs[0].type.shape[0] C = 0 for variable in operator.inputs: @@ -63,8 +71,12 @@ def new_convert_sklearn_function_transformer(scope, operator, container): implemented with :ref:`l-numpy-onnxpy`. """ op = operator.raw_operator - if hasattr(op.func, 'compiled'): - compiled = operator.raw_operator.func.compiled + fct = op.func + if hasattr(fct, 'signed_compiled'): + dtype = guess_numpy_type(operator.inputs[0].type) + fct = fct[dtype] + if hasattr(fct, 'compiled'): + compiled = fct.compiled if not hasattr(compiled, 'onnx_'): raise RuntimeError( # pragma: no cover "Attribute 'onnx_' is missing, function was not " @@ -116,10 +128,8 @@ def new_convert_sklearn_function_transformer(scope, operator, container): if op.func is not None: raise TypeError("FunctionTransformer is not supported unless the " - "transform function is None (= identity) or " - "wrapped with onxnumpy. " - "You may raise an issue at " - "https://github.com/onnx/sklearn-onnx/issues.") + "transform function is of type %r or " + "wrapped with onnxnumpy." % type(op.func)) if len(operator.inputs) == 1: apply_identity(scope, operator.inputs[0].full_name, operator.outputs[0].full_name, container) From c5821b61aa57f1954bfd5f946031caaf73a2d7c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 23 Feb 2021 14:47:12 +0100 Subject: [PATCH 3/3] Fix functions with arguments --- _unittests/ut_npy/test_numpy_onnx_pyrt.py | 18 ++++++++--- _unittests/ut_npy/test_onnx_variable.py | 16 ++++++++-- mlprodict/npy/__init__.py | 5 +-- mlprodict/npy/numpy_onnx_impl.py | 4 ++- mlprodict/npy/numpy_onnx_pyrt.py | 2 +- mlprodict/npy/onnx_numpy_annotation.py | 32 +++++++++++++++++--- mlprodict/npy/onnx_numpy_compiler.py | 37 +++++++++++++++-------- mlprodict/npy/onnx_numpy_wrapper.py | 18 +++++++++-- 8 files changed, 102 insertions(+), 30 deletions(-) diff --git a/_unittests/ut_npy/test_numpy_onnx_pyrt.py b/_unittests/ut_npy/test_numpy_onnx_pyrt.py index 20c2a6cc9..5f358e045 100644 --- a/_unittests/ut_npy/test_numpy_onnx_pyrt.py +++ b/_unittests/ut_npy/test_numpy_onnx_pyrt.py @@ -10,10 +10,15 @@ class TestNumpyOnnxFunction(ExtTestCase): - def common_test1(self, x, npfct, nxfct, dtype): + def common_test1(self, x, npfct, nxfct, dtype, **kwargs): xt = x.astype(dtype) - expected = npfct(xt) - got = nxfct[numpy.float32](xt) + if kwargs is None or len(kwargs) == 0: + expected = npfct(xt) + got = nxfct[dtype](xt) + else: + expected = npfct(xt, **kwargs) + kwargs['dtype_onnx'] = dtype + got = nxfct[kwargs](xt) self.assertEqualArray(expected, got) def test_abs_float32(self): @@ -29,8 +34,11 @@ def test_log_float64(self): self.common_test1(x, numpy.log, nxnpy.log, numpy.float64) def test_sum_float32(self): - x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) - self.common_test1(x, numpy.sum, nxnpy.sum, numpy.float32) + kwargs = [{'axis': 0}, {}, {'axis': 1}] + for kw in kwargs: + with self.subTest(kw=kw): + x = numpy.array([[-6.1, 5], [-3.5, 7.8]], dtype=numpy.float32) + self.common_test1(x, numpy.sum, nxnpy.sum, numpy.float32, **kw) if __name__ == "__main__": diff --git a/_unittests/ut_npy/test_onnx_variable.py b/_unittests/ut_npy/test_onnx_variable.py index 2946963ec..a764e1646 100644 --- a/_unittests/ut_npy/test_onnx_variable.py +++ b/_unittests/ut_npy/test_onnx_variable.py @@ -6,9 +6,10 @@ from typing import Any import numpy from pyquickhelper.pycode import ExtTestCase, ignore_warnings -from mlprodict.npy import onnxnumpy, onnxnumpy_default +from mlprodict.npy import onnxnumpy, onnxnumpy_default, onnxnumpy_np import mlprodict.npy.numpy_onnx_impl as nxnp -from mlprodict.npy import OnnxNumpyCompiler as ONC, NDArray +from mlprodict.npy import ( + OnnxNumpyCompiler as ONC, NDArray, NDArraySameTypeSameShape) @ignore_warnings(DeprecationWarning) @@ -267,6 +268,12 @@ def test_log(x: NDArray[Any, numpy.float32], return nxnp.log(x) +@onnxnumpy_np(signature=NDArraySameTypeSameShape("floats")) +def test_abs_log_multi(x): + "onnx numpy log multiple type" + return nxnp.log(nxnp.abs(x)) + + class TestOnnxVariable(ExtTestCase): def test_onnx_variable_abs(self): @@ -458,6 +465,11 @@ def test_onnx_variable_log(self): y = test_log(x) self.assertEqualArray(y, numpy.log(x)) + def test_onnx_variable_abs_log_multi(self): + x = numpy.array([[6.1, -5], [-3.5, 7.8]], dtype=numpy.float32) + y = test_abs_log_multi(x) + self.assertEqualArray(y, numpy.log(numpy.abs(x))) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/npy/__init__.py b/mlprodict/npy/__init__.py index 05e3f80fc..7cac3a2a6 100644 --- a/mlprodict/npy/__init__.py +++ b/mlprodict/npy/__init__.py @@ -5,6 +5,7 @@ .. versionadded:: 0.6 """ -from .onnx_numpy_annotation import NDArray +from .onnx_numpy_annotation import ( + NDArray, NDArraySameType, NDArraySameTypeSameShape) from .onnx_numpy_compiler import OnnxNumpyCompiler -from .onnx_numpy_wrapper import onnxnumpy, onnxnumpy_default +from .onnx_numpy_wrapper import onnxnumpy, onnxnumpy_default, onnxnumpy_np diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py index bb76c922c..fb08dca7a 100644 --- a/mlprodict/npy/numpy_onnx_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -22,7 +22,9 @@ def log(x): return OnnxVar(x, op=OnnxLog) -def sum(x, axis=0, keepdims=0): +def sum(x, axis=None, keepdims=0): "See :epkg:`numpy:sum`." + if axis is None: + return OnnxVar(x, op=OnnxReduceSum, keepdims=keepdims) return OnnxVar(x, numpy.array([axis], dtype=numpy.int64), op=OnnxReduceSum, keepdims=keepdims) diff --git a/mlprodict/npy/numpy_onnx_pyrt.py b/mlprodict/npy/numpy_onnx_pyrt.py index c36e0a1f0..48c942fd7 100644 --- a/mlprodict/npy/numpy_onnx_pyrt.py +++ b/mlprodict/npy/numpy_onnx_pyrt.py @@ -28,6 +28,6 @@ def log(x): @onnxnumpy_np(signature=NDArraySameType("all")) -def sum(x, axis=0, keepdims=0): +def sum(x, axis=None, keepdims=0): "sum" return nx_sum(x, axis=axis, keepdims=keepdims) diff --git a/mlprodict/npy/onnx_numpy_annotation.py b/mlprodict/npy/onnx_numpy_annotation.py index 4363ba084..c9e7c1b5e 100644 --- a/mlprodict/npy/onnx_numpy_annotation.py +++ b/mlprodict/npy/onnx_numpy_annotation.py @@ -4,6 +4,8 @@ .. versionadded:: 0.6 """ +import inspect +from collections import OrderedDict from typing import TypeVar, Generic import numpy @@ -16,6 +18,22 @@ numpy.uint32, numpy.uint64) +def get_args_kwargs(fct): + """ + Extracts arguments and optional parameters of a function. + + :param fct: function + :return: arguments, OrderedDict + """ + params = inspect.signature(fct).parameters + args = [name for name, p in params.items() + if p.default == inspect.Parameter.empty] + kwargs = OrderedDict((name, p.default) for name, p in params.items() + if (p.default != inspect.Parameter.empty and + name != 'op_version')) + return args, kwargs + + class NDArray(numpy.ndarray, Generic[Shape, DType]): """ Used to annotation ONNX numpy functions. @@ -47,7 +65,8 @@ def __init__(self, dtypes=None): def __repr__(self): "usual" - return "%s(%r)" % (self.__class__.__name__, self.dtypes) + return "%s(%r)" % ( + self.__class__.__name__, self.dtypes) def _to_onnx_dtype(self, dtype, shape): from skl2onnx.common.data_types import _guess_numpy_type @@ -69,11 +88,16 @@ def _possible_names(): for i in range(0, 10000): yield 'o%d' % i - if version not in self.dtypes: + if isinstance(version, tuple): + dtype = version[0] + else: + dtype = version + + if dtype not in self.dtypes: raise TypeError( - "Unexpected dtype %r, it should be in %r." % ( + "Unexpected version %r, it should be in %r." % ( version, self.dtypes)) - onnx_type = self._to_onnx_dtype(version, None) + onnx_type = self._to_onnx_dtype(dtype, None) inputs = [(a, onnx_type) for a in args] names_in = set(inp[0] for inp in inputs) name_out = None diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index d7bbb404a..e10d22898 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -7,6 +7,7 @@ import inspect from typing import Any from ..onnxrt import OnnxInference +from .onnx_numpy_annotation import get_args_kwargs from .onnx_variable import OnnxVar @@ -149,12 +150,17 @@ def _parse_annotation(self, signature, version): is a list of tuple with the name and the dtype, *kwargs* is the list of additional parameters """ - params = inspect.signature(self.fct_).parameters - args = [name for name, p in params.items() - if p.default == inspect.Parameter.empty] - kwargs = {name: p.default for name, p in params.items() - if (p.default != inspect.Parameter.empty and - name != 'op_version')} + args, kwargs = get_args_kwargs(self.fct_) + if isinstance(version, tuple): + if len(version) - 1 != len(kwargs): + raise RuntimeError( + "Mismatch between version=%r and kwargs=%r for " + "function %r." % (version, kwargs, self.fct_)) + up = {} + for k, v in zip(kwargs, version[1:]): + up[k] = v + kwargs = kwargs.copy() + kwargs.update(up) if signature is not None: inputs, outputs = signature.get_inputs_outputs(args, version) @@ -201,15 +207,23 @@ def _to_onnx(self, op_version=None, signature=None, version=None): Returns the onnx graph produced by function `fct_`. """ if self.onnx_ is None and self.fct_ is not None: - inputs, outputs, _ = self._parse_annotation( + inputs, outputs, kwargs = self._parse_annotation( signature=signature, version=version) + if (isinstance(version, tuple) and + len(kwargs) + 1 != len(version)): + raise NotImplementedError( + "Mismatch between additional parameters %r and " + "version %r for function %r from %r." + "" % (kwargs, version, self.fct_, + getattr(self.fct_, '__module__', None))) names_in = [oi[0] for oi in inputs] names_out = [oi[0] for oi in outputs] names_var = [OnnxVar(n) for n in names_in] if 'op_version' in self.fct_.__code__.co_varnames: - onx_algebra = self.fct_(*names_in, op_version=op_version) + onx_algebra = self.fct_( + *names_in, op_version=op_version, **kwargs) else: - onx_var = self.fct_(*names_var) + onx_var = self.fct_(*names_var, **kwargs) if not hasattr(onx_var, 'to_algebra'): raise TypeError( "The function %r to convert must return an instance of " @@ -247,11 +261,8 @@ def _build_runtime(self, op_version=None, runtime=None, """ onx = self._to_onnx(op_version=op_version, signature=signature, version=version) - inputs, outputs, kwargs = self._parse_annotation( + inputs, outputs, _ = self._parse_annotation( signature=signature, version=version) - if len(kwargs) > 0: - raise NotImplementedError( - "Unable to handle additional parameters %r." % kwargs) if runtime != 'onnxruntime': rt = OnnxInference(onx, runtime=runtime) self.rt_fct_ = OnnxNumpyFunctionOnnxInference( diff --git a/mlprodict/npy/onnx_numpy_wrapper.py b/mlprodict/npy/onnx_numpy_wrapper.py index d21918acf..faaa592bb 100644 --- a/mlprodict/npy/onnx_numpy_wrapper.py +++ b/mlprodict/npy/onnx_numpy_wrapper.py @@ -4,6 +4,7 @@ .. versionadded:: 0.6 """ +from .onnx_numpy_annotation import get_args_kwargs from .onnx_numpy_compiler import OnnxNumpyCompiler @@ -77,6 +78,9 @@ class wrapper_onnxnumpy_np: """ def __init__(self, **kwargs): + self.fct = kwargs['fct'] + self.signature = kwargs['signature'] + self.args, self.kwargs = get_args_kwargs(self.fct) self.data = kwargs self.signed_compiled = {} @@ -89,9 +93,19 @@ def __getitem__(self, dtype): of the function :return: instance of @see cl wrapper_onnxnumpy """ - if dtype not in self.signed_compiled: + if isinstance(dtype, dict): + if len(self.args) == 0: + raise RuntimeError( + "Signature does not have any arguments, use directly dtypes.") + others = tuple(dtype.get(k, self.kwargs[k]) for k in self.kwargs) + key = (dtype['dtype_onnx'], ) + others + self._populate(key) + elif dtype not in self.signed_compiled: self._populate(dtype) - return self.signed_compiled[dtype] + key = dtype + else: + key = dtype + return self.signed_compiled[key] def __call__(self, *args): """