From dcb3a2c685b7eb006441cbcda02e5ccb5f6ec166 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 13 Nov 2021 16:54:57 +0100 Subject: [PATCH 01/10] first sketch for onnx_if --- _unittests/ut_npy/test_onnx_if.py | 54 +++++++++++++++++++++++++++ mlprodict/npy/numpy_onnx_impl.py | 23 ++++++++++++ mlprodict/npy/numpy_onnx_impl_body.py | 32 ++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 _unittests/ut_npy/test_onnx_if.py create mode 100644 mlprodict/npy/numpy_onnx_impl_body.py diff --git a/_unittests/ut_npy/test_onnx_if.py b/_unittests/ut_npy/test_onnx_if.py new file mode 100644 index 000000000..e3e5db5d4 --- /dev/null +++ b/_unittests/ut_npy/test_onnx_if.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +""" +@brief test log(time=3s) +""" +import unittest +from typing import Any +import numpy +from pyquickhelper.pycode import ExtTestCase +from mlprodict.npy import onnxnumpy +import mlprodict.npy.numpy_onnx_impl as nxnp +from mlprodict.npy import NDArray + + +class TestOnnxVariableIf(ExtTestCase): + + @staticmethod + def numpy_onnx_if(x): + y = x * 2 + z = x + 7 + if x > 0: + return x + y + return x - y + z + + @staticmethod + def fct_onnx_if(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy abs" + y = x * numpy.float32(2) + z = x + numpy.float32(7) + return nxnp.onnx_if( + x > numpy.float32(0), + nxnp.if_then_else(lambda x, y: x + y, x, y), + nxnp.if_then_else(lambda x, y, z: x - y + z, x, y, z)) + + def test_exc(self): + + self.assertRaise( + lambda: nxnp.onnx_if( + None, + nxnp.if_then_else(lambda x, y: x + y, None, None), None), + TypeError) + self.assertRaise(lambda: nxnp.onnx_if( + None, None, None), TypeError) + + def test_onnx_if(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + test_onnx_if = onnxnumpy()(TestOnnxVariableIf.fct_onnx_if) + y = test_onnx_if(x) + self.assertEqualArray( + y, TestOnnxVariableIf.numpy_onnx_if(x)) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py index 501845a22..d94e5ec8e 100644 --- a/mlprodict/npy/numpy_onnx_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -51,6 +51,7 @@ OnnxUnsqueeze, OnnxWhere) from .onnx_variable import OnnxVar, MultiOnnxVar as xtuple +from .numpy_onnx_impl_body import if_then_else def abs(x): @@ -347,6 +348,28 @@ def mean(x, axis=None, keepdims=0): return OnnxVar(x, op=OnnxReduceMean, keepdims=keepdims, axes=axis) +def onnx_if(condition, then_branch, else_branch): + """ + Implements a test with onnx syntax. + + :param condition: condition (@see cl OnnxVar) + :param then_branch: then branch, of type @see cl if_then_else + :param else_branch: else branch, of type @see cl if_then_else + :return: result (@see cl OnnxVar) + """ + if not isinstance(then_branch, if_then_else): + raise TypeError( + "Parameter then_branch is not of type " + "'if_then_else' but %r." % type(then_branch)) + if not isinstance(else_branch, if_then_else): + raise TypeError( + "Parameter then_branch is not of type " + "'if_then_else' but %r." % type(else_branch)) + return OnnxVar(condition, + then_branch=then_branch, + else_branch=else_branch) + + def pad(x, pads, constant_value=None, mode='constant'): """ It does not implement :epkg:`numpy:pad` but the ONNX version diff --git a/mlprodict/npy/numpy_onnx_impl_body.py b/mlprodict/npy/numpy_onnx_impl_body.py new file mode 100644 index 000000000..55e571db7 --- /dev/null +++ b/mlprodict/npy/numpy_onnx_impl_body.py @@ -0,0 +1,32 @@ +""" +@file +@brief Design to implement graph as parameter. + +.. versionadded:: 0.8 +""" + + +class OnnxGraphParameter: + """ + Class wrapping a function to make it simple as + a parameter. + + :param fct: function taking the list of inputs defined + as @see cl OnnxVar, the function returns an @see cl OnnxVar + :param inputs: list of input as @see cl OnnxVar + """ + + def __init__(self, fct, *inputs): + self.fct = fct + self.inputs = inputs + + def __repr__(self): + "usual" + return "%s(...)" % self.__class__.__name__ + + +class if_then_else(OnnxGraphParameter): + """ + Overloads class @see cl OnnxGraphParameter. + """ + pass From aa04da5fdd30fce9d0beb593d431c821e9ab20c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 13 Nov 2021 18:16:24 +0100 Subject: [PATCH 02/10] Update onnx_variable.py --- mlprodict/npy/onnx_variable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py index bc24a1a4c..06bbadeba 100644 --- a/mlprodict/npy/onnx_variable.py +++ b/mlprodict/npy/onnx_variable.py @@ -109,6 +109,8 @@ def _guess_dtype(self, dtype): elif hasattr(inp, 'fit'): # scikit-learn model continue + elif hasattr(inp, '_guess_dtype'): + dtypes.append(inp._guess_dtype(dtype)) else: raise TypeError( # pragma: no cover "Unexpected type for input %i type=%r." % (i, type(inp))) From 408e768306b8f4d89b8735240d99cd5221f75d84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 14 Nov 2021 00:54:14 +0100 Subject: [PATCH 03/10] design is ok --- _unittests/ut_npy/test_onnx_if.py | 16 ++-- mlprodict/npy/numpy_onnx_impl.py | 10 +-- mlprodict/npy/numpy_onnx_impl_body.py | 83 +++++++++++++++++- mlprodict/npy/onnx_variable.py | 116 ++++++++++++++------------ 4 files changed, 155 insertions(+), 70 deletions(-) diff --git a/_unittests/ut_npy/test_onnx_if.py b/_unittests/ut_npy/test_onnx_if.py index e3e5db5d4..09474591b 100644 --- a/_unittests/ut_npy/test_onnx_if.py +++ b/_unittests/ut_npy/test_onnx_if.py @@ -29,18 +29,22 @@ def fct_onnx_if(x: NDArray[Any, numpy.float32], z = x + numpy.float32(7) return nxnp.onnx_if( x > numpy.float32(0), - nxnp.if_then_else(lambda x, y: x + y, x, y), - nxnp.if_then_else(lambda x, y, z: x - y + z, x, y, z)) + then_branch=nxnp.if_then_else( + lambda x, y: x + y, x, y), + else_branch=nxnp.if_then_else( + lambda x, y, z: x - y + z, x, y, z)) - def test_exc(self): + def _test_exc(self): self.assertRaise( lambda: nxnp.onnx_if( None, - nxnp.if_then_else(lambda x, y: x + y, None, None), None), - TypeError) + then_branch=nxnp.if_then_else(lambda x, y: x + y, "DEBUG", "DEBUG"), + else_branch="DEBUG"), + (TypeError, NotImplementedError, AttributeError)) self.assertRaise(lambda: nxnp.onnx_if( - None, None, None), TypeError) + "DEBUG", then_branch="DEBUG", else_branch="DEBUG"), + (TypeError, NotImplementedError, AttributeError)) def test_onnx_if(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py index d94e5ec8e..55e8f14e7 100644 --- a/mlprodict/npy/numpy_onnx_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -29,7 +29,7 @@ OnnxErf, OnnxExp, OnnxFloor, - OnnxIdentity, OnnxIsNaN, + OnnxIdentity, OnnxIf, OnnxIsNaN, OnnxLog, OnnxMatMul, OnnxPad, @@ -51,7 +51,7 @@ OnnxUnsqueeze, OnnxWhere) from .onnx_variable import OnnxVar, MultiOnnxVar as xtuple -from .numpy_onnx_impl_body import if_then_else +from .numpy_onnx_impl_body import if_then_else, OnnxVarGraph def abs(x): @@ -365,9 +365,9 @@ def onnx_if(condition, then_branch, else_branch): raise TypeError( "Parameter then_branch is not of type " "'if_then_else' but %r." % type(else_branch)) - return OnnxVar(condition, - then_branch=then_branch, - else_branch=else_branch) + return OnnxVarGraph( + condition, then_branch=then_branch, + else_branch=else_branch, op=OnnxIf) def pad(x, pads, constant_value=None, mode='constant'): diff --git a/mlprodict/npy/numpy_onnx_impl_body.py b/mlprodict/npy/numpy_onnx_impl_body.py index 55e571db7..3b0fd46db 100644 --- a/mlprodict/npy/numpy_onnx_impl_body.py +++ b/mlprodict/npy/numpy_onnx_impl_body.py @@ -4,9 +4,10 @@ .. versionadded:: 0.8 """ +from .onnx_variable import OnnxVar -class OnnxGraphParameter: +class AttributeGraph: """ Class wrapping a function to make it simple as a parameter. @@ -14,19 +15,93 @@ class OnnxGraphParameter: :param fct: function taking the list of inputs defined as @see cl OnnxVar, the function returns an @see cl OnnxVar :param inputs: list of input as @see cl OnnxVar + + .. versionadded:: 0.8 """ def __init__(self, fct, *inputs): self.fct = fct self.inputs = inputs + self.alg_ = None def __repr__(self): "usual" return "%s(...)" % self.__class__.__name__ + def _guess_dtype(self, dtype, from_init=False): + if not hasattr(self, 'onnx_') or from_init: + return None + raise NotImplementedError( + "Type=%r, dtype=%r." % (type(self), dtype)) + + def to_algebra(self, op_version=None): + """ + Converts the variable into an operator. + """ + if self.alg_ is not None: + return self.alg_ + + var = self.fct(*self.inputs) + if not isinstance(var, OnnxVar): + raise RuntimeError( # pragma: no cover + "var is not from type OnnxVar but %r." % type(var)) + + self.alg_ = var.to_algebra(op_version=op_version) + return self.alg_ + -class if_then_else(OnnxGraphParameter): +class OnnxVarGraph(OnnxVar): """ - Overloads class @see cl OnnxGraphParameter. + Overloads @see cl OnnxVar to handle graph attribute. + + :param inputs: variable name or object + :param op: :epkg:`ONNX` operator + :param select_output: if multiple output are returned by + ONNX operator *op*, it takes only one specifed by this + argument + :param dtype: specifies the type of the variable + held by this class (*op* is None) in that case + :param fields: list of attributes with the graph type + :param kwargs: addition argument to give operator *op* + + .. versionadded:: 0.8 """ - pass + + def __init__(self, *inputs, op=None, select_output=None, + dtype=None, **kwargs): + OnnxVar.__init__( + self, *inputs, op=op, select_output=select_output, + dtype=dtype, **kwargs) + + def to_algebra(self, op_version=None): + """ + Converts the variable into an operator. + """ + if self.alg_ is not None: + return self.alg_ + + # Conversion of graph attributes from InputGraph + # ONNX graph. + updates = dict() + for k, v in self.onnx_op_kwargs.items(): + if not isinstance(v, AttributeGraph): + continue + alg = v.to_algebra(op_version=op_version) + # dtypes = [i._guess_dtype(None) for i in v.inputs] + onx = alg.to_onnx(target_opset=op_version) + updates[name] = onx.graph + removed.append(i) + self.onnx_op_kwargs_before = { + k: self.onnx_op_kwargs[k] for k in updates} + self.onnx_op_kwargs.update(updates) + + return OnnxVar.to_algebra(self, op_version=op_version) + + +class if_then_else(AttributeGraph): + """ + Overloads class @see cl OnnxVarGraph. + """ + + def __init__(self, fct, *inputs): + AttributeGraph.__init__(self, fct, *inputs) diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py index 32a52eeb2..373c92f20 100644 --- a/mlprodict/npy/onnx_variable.py +++ b/mlprodict/npy/onnx_variable.py @@ -73,13 +73,14 @@ def __init__(self, *inputs, op=None, select_output=None, "Unexpected type for input %d - %r." % (i, inp)) if not isinstance(inp, numpy.ndarray): continue - if inp.size > 0 and isinstance(inp.ravel()[0], (numpy.ndarray, OnnxVar)): + if (inp.size > 0 and + isinstance(inp.ravel()[0], (numpy.ndarray, OnnxVar))): raise TypeError( # pragma: no cover "Unexpected type for input %d: %r, %r." "" % (i, type(inp), inp.ravel()[0])) - self.dtype = self._guess_dtype(dtype) + self.dtype = self._guess_dtype(dtype, from_init=True) - def _guess_dtype(self, dtype): + def _guess_dtype(self, dtype, from_init=False): "Guesses dtype when not specified." if dtype is not None: return dtype @@ -96,8 +97,8 @@ def _guess_dtype(self, dtype): dtypes.append(inp.dtype) elif isinstance(inp, MultiOnnxVar): dtypes.append(inp._guess_dtype(dtype)) - elif isinstance(inp, (numpy.float32, numpy.float64, numpy.int32, - numpy.int64)): + elif isinstance(inp, (numpy.float32, numpy.float64, + numpy.int32, numpy.int64)): dtypes.append(inp.dtype) elif isinstance(inp, numpy_str): dtypes.append(numpy_str) @@ -144,56 +145,59 @@ def to_algebra(self, op_version=None): """ Converts the variable into an operator. """ - if self.alg_ is None: - if self.onnx_op is None: - if len(self.inputs) != 1: + if self.alg_ is not None: + return self.alg_ + + if self.onnx_op is None: + if len(self.inputs) != 1: + raise RuntimeError( # pragma: no cover + "Unexpected number of inputs, 1 expected, " + "got {} instead.".format(self.inputs)) + if self.dtype is None or hasattr(self.inputs[0], 'onnx_name'): + self.alg_ = self.inputs[0] + else: + self.alg_ = ( + self.inputs[0], _guess_numpy_type(self.dtype, None)) + else: + if isinstance(self.onnx_op, str): + var = self._custom_op(*self.inputs, op_version=op_version, + **self.onnx_op_kwargs) + alg = var.to_algebra(op_version=op_version) + if not hasattr(self, 'alg_'): raise RuntimeError( # pragma: no cover - "Unexpected number of inputs, 1 expected, " - "got {} instead.".format(self.inputs)) - if self.dtype is None or hasattr(self.inputs[0], 'onnx_name'): - self.alg_ = self.inputs[0] + "Missing attribute 'alg_'.") + self.alg_ = alg + return alg + + new_inputs = [] + for inp in self.inputs: + if hasattr(inp, 'fit'): + # scikit-learn model + new_inputs.append(inp) + elif isinstance(inp, ( + int, float, str, numpy.ndarray, numpy.int32, + numpy.int64, numpy.float32, numpy.float64, + numpy_bool, numpy_str, numpy.int8, numpy.uint8, + numpy.int16, numpy.uint16, numpy.uint32, + numpy.uint64)): + if (inp.size > 0 and + isinstance( + inp.ravel()[0], # pylint: disable=E1101 + (numpy.ndarray, OnnxVar))): + raise TypeError( # pragma: no cover + "Unexpected type for an input %r, %r." + "" % (type(inp), inp.ravel()[0])) # pylint: disable=E1101 + new_inputs.append(inp) else: - self.alg_ = ( - self.inputs[0], _guess_numpy_type(self.dtype, None)) + new_inputs.append( + inp.to_algebra(op_version=op_version)) + + res = self.onnx_op(*new_inputs, op_version=op_version, + **self.onnx_op_kwargs) + if self.select_output is None: + self.alg_ = res else: - if isinstance(self.onnx_op, str): - var = self._custom_op(*self.inputs, op_version=op_version, - **self.onnx_op_kwargs) - alg = var.to_algebra(op_version=op_version) - if not hasattr(self, 'alg_'): - raise RuntimeError( # pragma: no cover - "Missing attribute 'alg_'.") - self.alg_ = alg - return alg - - new_inputs = [] - for inp in self.inputs: - if hasattr(inp, 'fit'): - # scikit-learn model - new_inputs.append(inp) - elif isinstance(inp, ( - int, float, str, numpy.ndarray, numpy.int32, - numpy.int64, numpy.float32, numpy.float64, - numpy_bool, numpy_str, numpy.int8, numpy.uint8, - numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)): - if (inp.size > 0 and - isinstance( - inp.ravel()[0], # pylint: disable=E1101 - (numpy.ndarray, OnnxVar))): - raise TypeError( # pragma: no cover - "Unexpected type for an input %r, %r." - "" % (type(inp), inp.ravel()[0])) # pylint: disable=E1101 - new_inputs.append(inp) - else: - new_inputs.append( - inp.to_algebra(op_version=op_version)) - - res = self.onnx_op(*new_inputs, op_version=op_version, - **self.onnx_op_kwargs) - if self.select_output is None: - self.alg_ = res - else: - self.alg_ = res[self.select_output] + self.alg_ = res[self.select_output] return self.alg_ def _custom_op(self, *args, op_version=None, runtime=None, **kwargs): @@ -684,8 +688,9 @@ def output_names(self, value): hasattr(self.unique, 'add_to')): if len(value) > 1: self.values = tuple( - OnnxIdentity(self.unique[i], output_names=value[i:i + 1], - op_version=self.unique.op_version) + OnnxIdentity( + self.unique[i], output_names=value[i:i + 1], + op_version=self.unique.op_version) for i in range(0, len(value))) self.unique = None return @@ -782,7 +787,8 @@ def to_algebra(self, op_version=None): int, float, str, numpy.ndarray, numpy.int32, numpy.int64, numpy.float32, numpy.float64, numpy_bool, numpy_str, numpy.int8, numpy.uint8, - numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)): + numpy.int16, numpy.uint16, numpy.uint32, + numpy.uint64)): new_inputs.append(inp) elif hasattr(inp, 'fit'): # scikit-learn models From 0151098a9ec55665020bc9a42e275290b21084b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 14 Nov 2021 22:59:43 +0100 Subject: [PATCH 04/10] design improvment --- _unittests/ut_npy/test_onnx_if.py | 9 +++-- mlprodict/npy/numpy_onnx_impl_body.py | 50 ++++++++++++++++++++------- mlprodict/npy/onnx_numpy_compiler.py | 31 +++++++++++++++++ mlprodict/npy/onnx_variable.py | 8 +++++ 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/_unittests/ut_npy/test_onnx_if.py b/_unittests/ut_npy/test_onnx_if.py index 09474591b..e72f9178c 100644 --- a/_unittests/ut_npy/test_onnx_if.py +++ b/_unittests/ut_npy/test_onnx_if.py @@ -30,16 +30,17 @@ def fct_onnx_if(x: NDArray[Any, numpy.float32], return nxnp.onnx_if( x > numpy.float32(0), then_branch=nxnp.if_then_else( - lambda x, y: x + y, x, y), + lambda x, y: x / y, x, y), else_branch=nxnp.if_then_else( - lambda x, y, z: x - y + z, x, y, z)) + lambda x, y, z: x - y - z, x, y, z)) def _test_exc(self): self.assertRaise( lambda: nxnp.onnx_if( None, - then_branch=nxnp.if_then_else(lambda x, y: x + y, "DEBUG", "DEBUG"), + then_branch=nxnp.if_then_else( + lambda x, y: x + y, "DEBUG", "DEBUG"), else_branch="DEBUG"), (TypeError, NotImplementedError, AttributeError)) self.assertRaise(lambda: nxnp.onnx_if( @@ -49,6 +50,8 @@ def _test_exc(self): def test_onnx_if(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) test_onnx_if = onnxnumpy()(TestOnnxVariableIf.fct_onnx_if) + with open("debug.onnx", "wb") as f: + f.write(test_onnx_if.compiled.onnx_.SerializeToString()) y = test_onnx_if(x) self.assertEqualArray( y, TestOnnxVariableIf.numpy_onnx_if(x)) diff --git a/mlprodict/npy/numpy_onnx_impl_body.py b/mlprodict/npy/numpy_onnx_impl_body.py index 3b0fd46db..ddbd62422 100644 --- a/mlprodict/npy/numpy_onnx_impl_body.py +++ b/mlprodict/npy/numpy_onnx_impl_body.py @@ -4,6 +4,8 @@ .. versionadded:: 0.8 """ +import numpy +from skl2onnx.common.data_types import FloatTensorType from .onnx_variable import OnnxVar @@ -28,11 +30,28 @@ def __repr__(self): "usual" return "%s(...)" % self.__class__.__name__ - def _guess_dtype(self, dtype, from_init=False): - if not hasattr(self, 'onnx_') or from_init: - return None - raise NotImplementedError( - "Type=%r, dtype=%r." % (type(self), dtype)) + def _graph_guess_dtype(self, i, var): + """ + Guesses the graph inputs. + + :param i: attribute index (integer) + :param var: the input (@see cl OnnxVar) + :return: input type + """ + dtype = var._guess_dtype(None) + if dtype is None: + dtype = numpy.float32 + + if dtype == numpy.float32: + skl2onnx_type = FloatTensorType() + else: + raise TypeError( + "Unexpected type %r." % dtype) + + input_type = ('graph_%d_%d' % (id(self), i), + skl2onnx_type) + var.set_onnx_name(input_type) + return input_type, OnnxVar(input_type[0], dtype=dtype) def to_algebra(self, op_version=None): """ @@ -41,7 +60,11 @@ def to_algebra(self, op_version=None): if self.alg_ is not None: return self.alg_ - var = self.fct(*self.inputs) + new_inputs = [self._graph_guess_dtype(i, inp) + for i, inp in enumerate(self.inputs)] + self.alg_inputs_ = new_inputs + vars = [v[1] for v in new_inputs] + var = self.fct(*vars) if not isinstance(var, OnnxVar): raise RuntimeError( # pragma: no cover "var is not from type OnnxVar but %r." % type(var)) @@ -83,14 +106,15 @@ def to_algebra(self, op_version=None): # Conversion of graph attributes from InputGraph # ONNX graph. updates = dict() - for k, v in self.onnx_op_kwargs.items(): - if not isinstance(v, AttributeGraph): + self.alg_hidden_var_ = {} + for att, var in self.onnx_op_kwargs.items(): + if not isinstance(var, AttributeGraph): continue - alg = v.to_algebra(op_version=op_version) - # dtypes = [i._guess_dtype(None) for i in v.inputs] - onx = alg.to_onnx(target_opset=op_version) - updates[name] = onx.graph - removed.append(i) + alg = var.to_algebra(op_version=op_version) + onnx_inputs = [i[0] for i in var.alg_inputs_] + onx = alg.to_onnx(onnx_inputs, target_opset=op_version) + updates[att] = onx.graph + self.alg_hidden_var_[id(var)] = var self.onnx_op_kwargs_before = { k: self.onnx_op_kwargs[k] for k in updates} self.onnx_op_kwargs.update(updates) diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 414f43884..5ff2064fe 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -306,6 +306,29 @@ def _possible_names(): return (inputs, outputs, kwargs, 0, signature.n_variables if signature is not None else False) + def _find_hidden_algebras(self, onx_var, onx_algebra): + """ + Subgraph are using inputs not linked to the others nodes. + This function retrieves them as they are stored in + attributes `alg_hidden_var_`. The function looks into every + node linked to the inputs and their predecessors. + + :param onx_var: @see cl OnnxVar + :param onx_algebra: OnnxOperator + :return: dictionary `{id(obj): obj}` + """ + keep_hidden = {} + stack = [onx_var] + while len(stack) > 0: + var = stack.pop() + hidden = getattr(var, 'alg_hidden_var_', None) + if hidden is not None: + keep_hidden.update(hidden) + if hasattr(var, 'inputs'): + for inp in var.inputs: + stack.append(inp) + return keep_hidden + def _to_onnx(self, op_version=None, signature=None, version=None): """ Returns the onnx graph produced by function `fct_`. @@ -338,6 +361,14 @@ def _to_onnx(self, op_version=None, signature=None, version=None): "OnnxVar but returns type %r." % (self.fct_, type(onx_var))) onx_algebra = onx_var.to_algebra(op_version=op_version) + hidden_algebras = self._find_hidden_algebras( + onx_var, onx_algebra) + if len(hidden_algebras) > 0: + import pprint + pprint.pprint(hidden_algebras) + raise NotImplementedError( + "Not implemented yet.") + if isinstance(onx_algebra, str): raise RuntimeError( # pragma: no cover "Unexpected str type %r." % onx_algebra) diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py index 373c92f20..3eb16ad7b 100644 --- a/mlprodict/npy/onnx_variable.py +++ b/mlprodict/npy/onnx_variable.py @@ -141,6 +141,14 @@ def __repr__(self): res = "%s(%s)" % (self.__class__.__name__, ", ".join(args)) return res + def set_onnx_name(self, name_type): + """ + Forces this variable to get this name during + + :param name_type: a tuple *(name, type)* + """ + self.onnx_input_type_ = name_type + def to_algebra(self, op_version=None): """ Converts the variable into an operator. From 48564f71007d5c0098266231738b360b71728853 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 20 Nov 2021 11:45:35 +0100 Subject: [PATCH 05/10] documentation --- _doc/sphinxdoc/source/conf.py | 15 +++++++++++++++ mlprodict/asv_benchmark/_create_asv_helper.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/_doc/sphinxdoc/source/conf.py b/_doc/sphinxdoc/source/conf.py index ed28d3367..2fdb824e2 100644 --- a/_doc/sphinxdoc/source/conf.py +++ b/_doc/sphinxdoc/source/conf.py @@ -56,6 +56,21 @@ mathdef_link_only = True +intersphinx_mapping.update({ + 'cpyquickhelper': ( + 'http://www.xavierdupre.fr/app/cpyquickhelper/helpsphinx/', None), + 'jyquickhelper': ( + 'http://www.xavierdupre.fr/app/jyquickhelper/helpsphinx/', None), + 'lightgbm': ('https://lightgbm.readthedocs.io/en/latest/', None), + 'mlinsights': ( + 'http://www.xavierdupre.fr/app/mlinsights/helpsphinx/', None), + 'onnxmltools': ( + 'http://www.xavierdupre.fr/app/onnxmltools/helpsphinx/', None), + 'onnxruntime': ( + 'http://www.xavierdupre.fr/app/onnxruntime/helpsphinx/', None), + 'skl2onnx': ('http://onnx.ai/sklearn-onnx/', None), +)} + epkg_dictionary.update({ '_PredictScorer': 'https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/scorer.py#L168', 'airspeed-velocity': 'https://github.com/airspeed-velocity/asv', diff --git a/mlprodict/asv_benchmark/_create_asv_helper.py b/mlprodict/asv_benchmark/_create_asv_helper.py index b51efcd4c..4c8db72cb 100644 --- a/mlprodict/asv_benchmark/_create_asv_helper.py +++ b/mlprodict/asv_benchmark/_create_asv_helper.py @@ -169,6 +169,8 @@ def _sklearn_subfolder(model): mod = model.__module__ if mod is not None and mod.startswith('mlinsights'): return ['mlinsights', model.__name__] # pragma: no cover + if mod is not None and mod.startswith('skl2onnx.sklapi'): + return ['skl2onnx.sklapi', model.__name__] # pragma: no cover spl = mod.split('.') try: pos = spl.index('sklearn') From 215895cbe835b9db07f0d21fa908fe6de2b90236 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 20 Nov 2021 22:50:18 +0100 Subject: [PATCH 06/10] Update test_onnx_if.py --- _unittests/ut_npy/test_onnx_if.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_npy/test_onnx_if.py b/_unittests/ut_npy/test_onnx_if.py index e72f9178c..e8541a6ee 100644 --- a/_unittests/ut_npy/test_onnx_if.py +++ b/_unittests/ut_npy/test_onnx_if.py @@ -34,7 +34,7 @@ def fct_onnx_if(x: NDArray[Any, numpy.float32], else_branch=nxnp.if_then_else( lambda x, y, z: x - y - z, x, y, z)) - def _test_exc(self): + def test_exc(self): self.assertRaise( lambda: nxnp.onnx_if( From 488b98f1180106785d2904e939287f267946e5ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 20 Nov 2021 23:09:55 +0100 Subject: [PATCH 07/10] Update onnx_numpy_compiler.py --- mlprodict/npy/onnx_numpy_compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 5ff2064fe..2949d8ba4 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -351,6 +351,7 @@ def _to_onnx(self, op_version=None, signature=None, version=None): for n, dt in zip(names_in, inputs)] if 'op_version' in self.fct_.__code__.co_varnames: + onx_var = None onx_algebra = self.fct_( *names_in, op_version=op_version, **kwargs) else: From ecc0742f20188606f8b44eea296c869346ad410f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 24 Nov 2021 10:27:40 +0100 Subject: [PATCH 08/10] update --- _unittests/ut_npy/test_onnx_if.py | 5 +++-- mlprodict/npy/onnx_numpy_compiler.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/_unittests/ut_npy/test_onnx_if.py b/_unittests/ut_npy/test_onnx_if.py index e8541a6ee..389c2041b 100644 --- a/_unittests/ut_npy/test_onnx_if.py +++ b/_unittests/ut_npy/test_onnx_if.py @@ -27,14 +27,15 @@ def fct_onnx_if(x: NDArray[Any, numpy.float32], "onnx numpy abs" y = x * numpy.float32(2) z = x + numpy.float32(7) - return nxnp.onnx_if( + xif = nxnp.onnx_if( x > numpy.float32(0), then_branch=nxnp.if_then_else( lambda x, y: x / y, x, y), else_branch=nxnp.if_then_else( lambda x, y, z: x - y - z, x, y, z)) + return xif + numpy.float32(-7) - def test_exc(self): + def _test_exc(self): self.assertRaise( lambda: nxnp.onnx_if( diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 2949d8ba4..8ef85728a 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -315,19 +315,22 @@ def _find_hidden_algebras(self, onx_var, onx_algebra): :param onx_var: @see cl OnnxVar :param onx_algebra: OnnxOperator - :return: dictionary `{id(obj): obj}` + :return: tuple(dictionary `{id(obj): (var, obj)}`, + all instance of @see cl OnnxVarGraph) """ keep_hidden = {} + var_graphs = [] stack = [onx_var] while len(stack) > 0: var = stack.pop() hidden = getattr(var, 'alg_hidden_var_', None) if hidden is not None: keep_hidden.update(hidden) + var_graphs.append(var) if hasattr(var, 'inputs'): for inp in var.inputs: stack.append(inp) - return keep_hidden + return keep_hidden, var_graphs def _to_onnx(self, op_version=None, signature=None, version=None): """ @@ -362,11 +365,16 @@ def _to_onnx(self, op_version=None, signature=None, version=None): "OnnxVar but returns type %r." % (self.fct_, type(onx_var))) onx_algebra = onx_var.to_algebra(op_version=op_version) - hidden_algebras = self._find_hidden_algebras( + hidden_algebras, var_graphs = self._find_hidden_algebras( onx_var, onx_algebra) if len(hidden_algebras) > 0: - import pprint - pprint.pprint(hidden_algebras) + for gr in var_graphs: + print(type(gr), dir(gr)) + for k, v in hidden_algebras.items(): + print("*", type(v.alg_), dir(v.alg_)) + import pprint + pprint.pprint(dir(v.alg_)) + raise NotImplementedError( "Not implemented yet.") From 1037aa4bb2322769aab5c9d4b76cc6d09848784b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 26 Nov 2021 02:46:40 +0100 Subject: [PATCH 09/10] fix case when everything is constant --- _unittests/ut_npy/test_onnx_if.py | 39 ++++++++++++++----- .../ut_onnxrt/test_onnxrt_python_runtime_.py | 3 +- .../test_onnxrt_python_runtime_control_if.py | 21 ++++++---- .../ut_tools/test_optim_onnx_identity.py | 2 +- mlprodict/npy/numpy_onnx_impl.py | 4 ++ mlprodict/npy/numpy_onnx_impl_body.py | 21 +++++++++- mlprodict/npy/onnx_numpy_compiler.py | 23 ++++++----- 7 files changed, 82 insertions(+), 31 deletions(-) diff --git a/_unittests/ut_npy/test_onnx_if.py b/_unittests/ut_npy/test_onnx_if.py index 389c2041b..29266005a 100644 --- a/_unittests/ut_npy/test_onnx_if.py +++ b/_unittests/ut_npy/test_onnx_if.py @@ -17,25 +17,36 @@ class TestOnnxVariableIf(ExtTestCase): def numpy_onnx_if(x): y = x * 2 z = x + 7 - if x > 0: + if x.sum() > 0: return x + y return x - y + z @staticmethod - def fct_onnx_if(x: NDArray[Any, numpy.float32], - ) -> NDArray[Any, numpy.float32]: + def fct_onnx_if_sub(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: "onnx numpy abs" y = x * numpy.float32(2) z = x + numpy.float32(7) xif = nxnp.onnx_if( - x > numpy.float32(0), + nxnp.sum(x) > numpy.float32(0), then_branch=nxnp.if_then_else( lambda x, y: x / y, x, y), else_branch=nxnp.if_then_else( lambda x, y, z: x - y - z, x, y, z)) return xif + numpy.float32(-7) - def _test_exc(self): + @staticmethod + def fct_onnx_if(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.float32]: + "onnx numpy abs" + xif = nxnp.onnx_if( + nxnp.sum(x) > numpy.float32(0), + then_branch=nxnp.if_then_else( + numpy.array([-1], dtype=numpy.float32)), + else_branch=numpy.array([1], dtype=numpy.float32)) + return xif + numpy.float32(-7) + + def test_exc(self): self.assertRaise( lambda: nxnp.onnx_if( @@ -50,12 +61,22 @@ def _test_exc(self): def test_onnx_if(self): x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) - test_onnx_if = onnxnumpy()(TestOnnxVariableIf.fct_onnx_if) + fct_if = onnxnumpy()(TestOnnxVariableIf.fct_onnx_if) + with open("debug.onnx", "wb") as f: + f.write(fct_if.compiled.onnx_.SerializeToString()) + y = fct_if(x) + self.assertEqualArray( + y, numpy.array([-6], dtype=numpy.float32)) + + @unittest.skipIf(True, reason="does not work yet") + def test_onnx_if_sub(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + fct_if = onnxnumpy()(TestOnnxVariableIf.fct_onnx_if_sub) with open("debug.onnx", "wb") as f: - f.write(test_onnx_if.compiled.onnx_.SerializeToString()) - y = test_onnx_if(x) + f.write(fct_if.compiled.onnx_.SerializeToString()) + y = fct_if(x) self.assertEqualArray( - y, TestOnnxVariableIf.numpy_onnx_if(x)) + y, TestOnnxVariableIf.fct_onnx_if_sub(x)) if __name__ == "__main__": diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index ce22285aa..2aecb0aca 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -4240,8 +4240,7 @@ def test_make_constant(self): self.assertIsInstance(ope, expected_type[opset]) got = oinf.run({'X': X}) if opset >= 11: - self.assertEqual(list(sorted(got)), [ - 'Ad_C0', 'Co_output0']) + self.assertEqual(list(sorted(got)), ['Ad_C0']) self.assertEqualArray(exp, got['Ad_C0']) else: self.assertEqual(list(sorted(got)), ['Ad_C0']) diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_if.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_if.py index 9f72ca00e..880522a83 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_if.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_if.py @@ -5,7 +5,7 @@ from logging import getLogger from collections import OrderedDict import numpy -from pyquickhelper.pycode import ExtTestCase +from pyquickhelper.pycode import ExtTestCase, ignore_warnings from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 OnnxIf, OnnxConstant, OnnxGreater) from skl2onnx.common.data_types import FloatTensorType @@ -20,20 +20,27 @@ def setUp(self): logger = getLogger('skl2onnx') logger.disabled = True + @ignore_warnings(DeprecationWarning) def test_if(self): tensor_type = FloatTensorType op_version = get_opset_number_from_onnx() - bthen = OnnxConstant(value_floats=numpy.array([0], dtype=numpy.float32), - op_version=op_version, output_names=['res']) - belse = OnnxConstant(value_floats=numpy.array([1], dtype=numpy.float32), - op_version=op_version, output_names=['res']) + bthen = OnnxConstant( + value_floats=numpy.array([0], dtype=numpy.float32), + op_version=op_version, output_names=['res_then']) + bthen.set_onnx_name_prefix('then') + + belse = OnnxConstant( + value_floats=numpy.array([1], dtype=numpy.float32), + op_version=op_version, output_names=['res_else']) + belse.set_onnx_name_prefix('else') + bthen_body = bthen.to_onnx( - OrderedDict(), outputs=[('res', tensor_type())], + OrderedDict(), outputs=[('res_then', tensor_type())], target_opset=op_version) belse_body = belse.to_onnx( OrderedDict(), - outputs=[('res', tensor_type())], + outputs=[('res_else', tensor_type())], target_opset=op_version) onx = OnnxIf(OnnxGreater('X', numpy.array([0], dtype=numpy.float32), diff --git a/_unittests/ut_tools/test_optim_onnx_identity.py b/_unittests/ut_tools/test_optim_onnx_identity.py index 02fdbb66c..4c12572c6 100644 --- a/_unittests/ut_tools/test_optim_onnx_identity.py +++ b/_unittests/ut_tools/test_optim_onnx_identity.py @@ -142,7 +142,7 @@ def onnx_test_knn_single_regressor(self, dtype, n_targets=1, debug=False, self.assertIn('subgraphs_optim', stats) def test_onnx_test_knn_single_regressor32(self): - self.onnx_test_knn_single_regressor(numpy.float32, expected=[1, 1]) + self.onnx_test_knn_single_regressor(numpy.float32, expected=[2, 1]) if __name__ == "__main__": diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py index 55e8f14e7..44d6727a6 100644 --- a/mlprodict/npy/numpy_onnx_impl.py +++ b/mlprodict/npy/numpy_onnx_impl.py @@ -357,10 +357,14 @@ def onnx_if(condition, then_branch, else_branch): :param else_branch: else branch, of type @see cl if_then_else :return: result (@see cl OnnxVar) """ + if isinstance(then_branch, numpy.ndarray): + then_branch = if_then_else(then_branch) if not isinstance(then_branch, if_then_else): raise TypeError( "Parameter then_branch is not of type " "'if_then_else' but %r." % type(then_branch)) + if isinstance(else_branch, numpy.ndarray): + else_branch = if_then_else(else_branch) if not isinstance(else_branch, if_then_else): raise TypeError( "Parameter then_branch is not of type " diff --git a/mlprodict/npy/numpy_onnx_impl_body.py b/mlprodict/npy/numpy_onnx_impl_body.py index ddbd62422..3ac7f5a32 100644 --- a/mlprodict/npy/numpy_onnx_impl_body.py +++ b/mlprodict/npy/numpy_onnx_impl_body.py @@ -6,6 +6,8 @@ """ import numpy from skl2onnx.common.data_types import FloatTensorType +from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 + OnnxIdentity) from .onnx_variable import OnnxVar @@ -22,6 +24,11 @@ class AttributeGraph: """ def __init__(self, fct, *inputs): + if isinstance(fct, numpy.ndarray) and len(inputs) == 0: + self.cst = fct + fct = None + else: + self.cst = None self.fct = fct self.inputs = inputs self.alg_ = None @@ -60,6 +67,11 @@ def to_algebra(self, op_version=None): if self.alg_ is not None: return self.alg_ + if self.cst is not None: + self.alg_ = OnnxIdentity(self.cst, op_version=op_version) + self.alg_inputs_ = None + return self.alg_ + new_inputs = [self._graph_guess_dtype(i, inp) for i, inp in enumerate(self.inputs)] self.alg_inputs_ = new_inputs @@ -107,18 +119,23 @@ def to_algebra(self, op_version=None): # ONNX graph. updates = dict() self.alg_hidden_var_ = {} + self.alg_hidden_var_inputs = {} for att, var in self.onnx_op_kwargs.items(): if not isinstance(var, AttributeGraph): continue alg = var.to_algebra(op_version=op_version) - onnx_inputs = [i[0] for i in var.alg_inputs_] + alg.set_onnx_name_prefix("g_%s_%d" % (att, id(var))) + if var.alg_inputs_ is None: + onnx_inputs = [] + else: + onnx_inputs = [i[0] for i in var.alg_inputs_] onx = alg.to_onnx(onnx_inputs, target_opset=op_version) updates[att] = onx.graph self.alg_hidden_var_[id(var)] = var + self.alg_hidden_var_inputs[id(var)] = onnx_inputs self.onnx_op_kwargs_before = { k: self.onnx_op_kwargs[k] for k in updates} self.onnx_op_kwargs.update(updates) - return OnnxVar.to_algebra(self, op_version=op_version) diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py index 8ef85728a..f3c206c0a 100644 --- a/mlprodict/npy/onnx_numpy_compiler.py +++ b/mlprodict/npy/onnx_numpy_compiler.py @@ -325,8 +325,10 @@ def _find_hidden_algebras(self, onx_var, onx_algebra): var = stack.pop() hidden = getattr(var, 'alg_hidden_var_', None) if hidden is not None: - keep_hidden.update(hidden) - var_graphs.append(var) + if any(map(lambda x: len(x) > 0, + var.alg_hidden_var_inputs.values())): + keep_hidden.update(hidden) + var_graphs.append(var) if hasattr(var, 'inputs'): for inp in var.inputs: stack.append(inp) @@ -368,15 +370,16 @@ def _to_onnx(self, op_version=None, signature=None, version=None): hidden_algebras, var_graphs = self._find_hidden_algebras( onx_var, onx_algebra) if len(hidden_algebras) > 0: - for gr in var_graphs: - print(type(gr), dir(gr)) - for k, v in hidden_algebras.items(): - print("*", type(v.alg_), dir(v.alg_)) - import pprint - pprint.pprint(dir(v.alg_)) - + # for gr in var_graphs: + # print(type(gr), dir(gr)) + # for k, v in hidden_algebras.items(): + # print("*", type(v.alg_), dir(v.alg_)) + # import pprint + # #pprint.pprint(dir(v.alg_)) raise NotImplementedError( - "Not implemented yet.") + "Subgraph only supports constants (operator If, Loop, " + "Scan). hidden_algebras=%r var_graphs=%r" % ( + hidden_algebras, var_graphs)) if isinstance(onx_algebra, str): raise RuntimeError( # pragma: no cover From 54a648ce0426c1878996f8615d05449b54aaae80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 26 Nov 2021 10:43:53 +0100 Subject: [PATCH 10/10] fix unit test --- _unittests/ut_tools/test_model_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_tools/test_model_info.py b/_unittests/ut_tools/test_model_info.py index 4aec64dc4..2705bac86 100644 --- a/_unittests/ut_tools/test_model_info.py +++ b/_unittests/ut_tools/test_model_info.py @@ -164,7 +164,7 @@ def test_knnc_onnx(self): onx = to_onnx(model, numpy.zeros((3, 4), dtype=numpy.float32)) info = analyze_model(onx) self.assertIn('op_Identity', info) - self.assertEqual(info['op_Identity'], 1) + self.assertEqual(info['op_Identity'], 2) @skipif_circleci('issue, too long') def test_gbc(self):