diff --git a/_doc/sphinxdoc/source/api/npy.rst b/_doc/sphinxdoc/source/api/npy.rst index 63cf09a07..9ced55726 100644 --- a/_doc/sphinxdoc/source/api/npy.rst +++ b/_doc/sphinxdoc/source/api/npy.rst @@ -53,19 +53,35 @@ as opposed to numpy. This approach is similar to what :epkg:`tensorflow` with `autograph `_. -NDArray -+++++++ +Signatures +++++++++++ .. autosignature:: mlprodict.npy.onnx_numpy_annotation.NDArray :members: -onnxnumpy -+++++++++ +.. autosignature:: mlprodict.npy.onnx_numpy_annotation.NDArraySameType + :members: + +.. autosignature:: mlprodict.npy.onnx_numpy_annotation.NDArraySameTypeSameShape + :members: + +.. autosignature:: mlprodict.npy.onnx_numpy_annotation.NDArrayType + :members: + +.. autosignature:: mlprodict.npy.onnx_numpy_annotation.NDArrayTypeSameShape + :members: + +Decorators +++++++++++ .. autosignature:: mlprodict.npy.onnx_numpy_wrapper.onnxnumpy .. autosignature:: mlprodict.npy.onnx_numpy_wrapper.onnxnumpy_default +.. autosignature:: mlprodict.npy.onnx_numpy_wrapper.onnxnumpy_np + +.. autosignature:: mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_transformer + OnnxNumpyCompiler +++++++++++++++++ @@ -78,6 +94,11 @@ OnnxVar .. autosignature:: mlprodict.npy.onnx_variable.OnnxVar :members: +Registration +++++++++++++ + +.. autosignature:: mlprodict.npy.onnx_sklearn_wrapper.update_registered_converter_npy + .. _l-numpy-onnxpy-list-fct: Available numpy functions implemented with ONNX operators diff --git a/_unittests/ut_npy/test_complex_scenario.py b/_unittests/ut_npy/test_complex_scenario.py index e6607646a..db700aa5a 100644 --- a/_unittests/ut_npy/test_complex_scenario.py +++ b/_unittests/ut_npy/test_complex_scenario.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -@brief test log(time=3s) +@brief test log(time=21s) """ import unittest import warnings diff --git a/_unittests/ut_npy/test_custom_transformer.py b/_unittests/ut_npy/test_custom_transformer.py new file mode 100644 index 000000000..f2b57f89f --- /dev/null +++ b/_unittests/ut_npy/test_custom_transformer.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +""" +@brief test log(time=3s) +""" +import unittest +import warnings +from logging import getLogger +import numpy +from sklearn.base import TransformerMixin, BaseEstimator +from sklearn.decomposition import PCA +from pyquickhelper.pycode import ExtTestCase, ignore_warnings +from skl2onnx import update_registered_converter +from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 + OnnxIdentity, OnnxMatMul, OnnxSub) +from skl2onnx.algebra.onnx_operator import OnnxSubOperator +from skl2onnx.common.data_types import guess_numpy_type +from mlprodict.onnx_conv import to_onnx +from mlprodict.onnxrt import OnnxInference +from mlprodict.npy import onnxsklearn_transformer + + +class DecorrelateTransformer(TransformerMixin, BaseEstimator): + def __init__(self, alpha=0.): + BaseEstimator.__init__(self) + TransformerMixin.__init__(self) + self.alpha = alpha + + def fit(self, X, y=None, sample_weights=None): + self.pca_ = PCA(X.shape[1]) # pylint: disable=W0201 + self.pca_.fit(X) + return self + + def transform(self, X): + return self.pca_.transform(X) + + +def decorrelate_transformer_shape_calculator(operator): + op = operator.raw_operator + input_type = operator.inputs[0].type.__class__ + input_dim = operator.inputs[0].type.shape[0] + output_type = input_type([input_dim, op.pca_.components_.shape[1]]) + operator.outputs[0].type = output_type + + +def decorrelate_transformer_converter(scope, operator, container): + op = operator.raw_operator + opv = container.target_opset + out = operator.outputs + X = operator.inputs[0] + subop = OnnxSubOperator(op.pca_, X, op_version=opv) + Y = OnnxIdentity(subop, op_version=opv, output_names=out[:1]) + Y.add_to(scope, container) + + +class DecorrelateTransformer2(DecorrelateTransformer): + pass + + +def decorrelate_transformer_converter2(scope, operator, container): + op = operator.raw_operator + opv = container.target_opset + out = operator.outputs + X = operator.inputs[0] + dtype = guess_numpy_type(X.type) + m = OnnxMatMul( + OnnxSub(X, op.pca_.mean_.astype(dtype), op_version=opv), + op.pca_.components_.T.astype(dtype), op_version=opv) + Y = OnnxIdentity(m, op_version=opv, output_names=out[:1]) + Y.add_to(scope, container) + + +class DecorrelateTransformer3(DecorrelateTransformer): + pass + + +@onnxsklearn_transformer(register_class=DecorrelateTransformer3) +def decorrelate_transformer_converter3(X, op=None): + if X.dtype is None: + raise AssertionError("X.dtype cannot be None.") + mean = op.pca_.mean_.astype(X.dtype) + cmp = op.pca_.components_.T.astype(X.dtype) + return (X - mean) @ cmp + + +class TestCustomTransformer(ExtTestCase): + + def setUp(self): + logger = getLogger('skl2onnx') + logger.disabled = True + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ResourceWarning) + update_registered_converter( + DecorrelateTransformer, "SklearnDecorrelateTransformer", + decorrelate_transformer_shape_calculator, + decorrelate_transformer_converter) + update_registered_converter( + DecorrelateTransformer2, "SklearnDecorrelateTransformer2", + decorrelate_transformer_shape_calculator, + decorrelate_transformer_converter2) + + @ignore_warnings((DeprecationWarning, RuntimeWarning)) + def test_function_transformer(self): + X = numpy.random.randn(20, 2).astype(numpy.float32) + dec = DecorrelateTransformer() + dec.fit(X) + onx = to_onnx(dec, X.astype(numpy.float32)) + oinf = OnnxInference(onx) + exp = dec.transform(X) + got = oinf.run({'X': X}) + self.assertEqualArray(exp, got['variable']) + + @ignore_warnings((DeprecationWarning, RuntimeWarning)) + def test_function_transformer2(self): + X = numpy.random.randn(20, 2).astype(numpy.float32) + dec = DecorrelateTransformer2() + dec.fit(X) + onx = to_onnx(dec, X.astype(numpy.float32)) + oinf = OnnxInference(onx) + exp = dec.transform(X) + got = oinf.run({'X': X}) + self.assertEqualArray(exp, got['variable']) + + @ignore_warnings((DeprecationWarning, RuntimeWarning)) + def test_function_transformer3_float32(self): + X = numpy.random.randn(20, 2).astype(numpy.float32) + dec = DecorrelateTransformer3() + dec.fit(X) + onx = to_onnx(dec, X.astype(numpy.float32)) + oinf = OnnxInference(onx) + exp = dec.transform(X) + got = oinf.run({'X': X}) + self.assertEqualArray(exp, got['variable']) + X2 = decorrelate_transformer_converter3(X, op=dec) + self.assertEqualArray(X2, got['variable']) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_npy/test_onnxpy.py b/_unittests/ut_npy/test_onnxpy.py index 56fa8940d..5dd5b8254 100644 --- a/_unittests/ut_npy/test_onnxpy.py +++ b/_unittests/ut_npy/test_onnxpy.py @@ -8,6 +8,7 @@ from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument # pylint: disable=E0611 from pyquickhelper.pycode import ExtTestCase from mlprodict.npy import OnnxNumpyCompiler as ONC, NDArray +from mlprodict.npy.onnx_variable import OnnxVar from mlprodict.npy.onnx_numpy_annotation import _NDArrayAlias from skl2onnx.algebra.onnx_ops import OnnxAbs # pylint: disable=E0611 from skl2onnx.common.data_types import FloatTensorType @@ -25,6 +26,17 @@ def onnx_abs_shape(x: NDArray[(Any, Any), numpy.float32], op_version=None) -> NDArray[(Any, Any), numpy.float32]: return OnnxAbs(x, op_version=op_version) + def test_onnx_var(self): + ov = OnnxVar('X') + rp = repr(ov) + self.assertEqual("OnnxVar('X')", rp) + ov = OnnxVar('X', op=OnnxAbs) + rp = repr(ov) + self.assertEqual("OnnxVar('X', op=OnnxAbs)", rp) + ov = OnnxVar('X', op='filter') + rp = repr(ov) + self.assertEqual("OnnxVar('X', op='filter')", rp) + def test_process_dtype(self): for name in ['all', 'int', 'ints', 'floats', 'T']: res = _NDArrayAlias._process_type( # pylint: disable=W0212 diff --git a/_unittests/ut_onnxrt/test_onnxrt_runtime_empty.py b/_unittests/ut_onnxrt/test_onnxrt_runtime_empty.py index fba772852..29262e7b3 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_runtime_empty.py +++ b/_unittests/ut_onnxrt/test_onnxrt_runtime_empty.py @@ -50,7 +50,8 @@ def test_onnxt_runtime_empty_unknown(self): Z = helper.make_tensor_value_info( 'Z', TensorProto.FLOAT, [None, 2]) # pylint: disable=E1101 node_def = helper.make_node('Add', ['X', 'Y'], ['Zt'], name='Zt') - node_def2 = helper.make_node('AddUnknown', ['X', 'Zt'], ['Z'], name='Z') + node_def2 = helper.make_node( + 'AddUnknown', ['X', 'Zt'], ['Z'], name='Z') graph_def = helper.make_graph( [node_def, node_def2], 'test-model', [X, Y], [Z]) model_def = helper.make_model(graph_def, producer_name='onnx-example') diff --git a/mlprodict/npy/__init__.py b/mlprodict/npy/__init__.py index 14acc19e4..524f01d86 100644 --- a/mlprodict/npy/__init__.py +++ b/mlprodict/npy/__init__.py @@ -10,3 +10,5 @@ Shape, DType) from .onnx_numpy_compiler import OnnxNumpyCompiler from .onnx_numpy_wrapper import onnxnumpy, onnxnumpy_default, onnxnumpy_np +from .onnx_sklearn_wrapper import ( + onnxsklearn_transformer, update_registered_converter_npy) diff --git a/mlprodict/npy/onnx_sklearn_wrapper.py b/mlprodict/npy/onnx_sklearn_wrapper.py new file mode 100644 index 000000000..e1cdec6a8 --- /dev/null +++ b/mlprodict/npy/onnx_sklearn_wrapper.py @@ -0,0 +1,178 @@ +""" +@file +@brief Helpers to use numpy API to easily write converters +for :epkg:`scikit-learn` classes for :epkg:`onnx`. + +.. versionadded:: 0.6 +""" +from skl2onnx import update_registered_converter +from skl2onnx.algebra.onnx_ops import OnnxIdentity # pylint: disable=E0611 +from .onnx_variable import OnnxVar +from .onnx_numpy_wrapper import _created_classes_inst, wrapper_onnxnumpy_np +from .onnx_numpy_annotation import NDArraySameType + + +def _shape_calculator_transformer(operator): + """ + Default shape calculator for a transformer with one input + and one output of the same type. + + .. versionadded:: 0.6 + """ + if not hasattr(operator, 'onnx_numpy_fct_'): + raise AttributeError( + "operator must have attribute 'onnx_numpy_fct_'.") + X = operator.inputs + if len(X) != 1: + raise RuntimeError( + "This function only supports one input not %r." % len(X)) + if len(operator.outputs) != 1: + raise RuntimeError( + "This function only supports one output not %r." % len(operator.outputs)) + cl = X[0].type.__class__ + dim = [X[0].type.shape[0], None] + operator.outputs[0] = cl(dim) + + +def _converter_transformer(scope, operator, container): + """ + Default converter for a transformer with one input + and one output of the same type. It assumes instance *operator* + has an attribute *onnx_numpy_fct_* from a function + wrapped with decoarator :func:`onnxsklearn_transformer + `. + + .. versionadded:: 0.6 + """ + if not hasattr(operator, 'onnx_numpy_fct_'): + raise AttributeError( + "operator must have attribute 'onnx_numpy_fct_'.") + X = operator.inputs + if len(X) != 1: + raise RuntimeError( + "This function only supports one input not %r." % len(X)) + if len(operator.outputs) != 1: + raise RuntimeError( + "This function only supports one output not %r." % len(operator.outputs)) + + xvar = OnnxVar(X[0]) + fct_cl = operator.onnx_numpy_fct_ + + opv = container.target_opset + inst = fct_cl.fct(xvar, op=operator.raw_operator) + onx = inst.to_algebra(op_version=opv) + final = OnnxIdentity(onx, op_version=opv, + output_names=[operator.outputs[0].full_name]) + final.add_to(scope, container) + + +def update_registered_converter_npy( + model, alias, convert_fct, shape_fct=None, overwrite=True, + parser=None, options=None): + """ + Registers or updates a converter for a new model so that + it can be converted when inserted in a *scikit-learn* pipeline. + This function assumes the converter is written as a function + decoarated with :func:`onnxsklearn_transformer + `. + + :param model: model class + :param alias: alias used to register the model + :param shape_fct: function which checks or modifies the expected + outputs, this function should be fast so that the whole graph + can be computed followed by the conversion of each model, + parallelized or not + :param convert_fct: function which converts a model + :param overwrite: False to raise exception if a converter + already exists + :param parser: overwrites the parser as well if not empty + :param options: registered options for this converter + + The alias is usually the library name followed by the model name. + + .. versionadded:: 0.6 + """ + if (hasattr(convert_fct, "compiled") or + hasattr(convert_fct, 'signed_compiled')): + # type is wrapper_onnxnumpy or wrapper_onnxnumpy_np + obj = convert_fct + else: + raise AttributeError( + "Class %r must have attribute 'compiled' or 'signed_compiled' " + "(object=%r)." % (type(convert_fct), convert_fct)) + + def addattr(operator, obj): + operator.onnx_numpy_fct_ = obj + return operator + + if shape_fct is None: + local_shape_fct = ( + lambda operator: + _shape_calculator_transformer( + addattr(operator, obj))) + else: + local_shape_fct = shape_fct + + local_convert_fct = ( + lambda scope, operator, container: + _converter_transformer( + scope, addattr(operator, obj), container)) + + update_registered_converter( + model, alias, convert_fct=local_convert_fct, + shape_fct=local_shape_fct, overwrite=overwrite, + parser=parser, options=options) + + +def onnxsklearn_transformer(op_version=None, runtime=None, signature=None, + register_class=None): + """ + Decorator to declare a converter for a transformer implemented using + :epkg:`numpy` syntax but executed with :epkg:`ONNX` + operators. + + :param op_version: :epkg:`ONNX` opset version + :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference + :param signature: if None, the signature is replace by a standard signature + for transformer ``NDArraySameType("all")`` + :param register_class: automatically register this converter + for this class to :epkg:`sklearn-onnx` + + Equivalent to `onnxnumpy(arg)(foo)`. + + .. versionadded:: 0.6 + """ + if signature is None: + signature = NDArraySameType("all") + + def default_shape_calculator(operator): + op = operator.raw_operator + if len(operator.inputs) != 1: + raise NotImplementedError( + "Default shape calculator only supports one input not %r (type=%r)" + "." % (len(operator.inputs), type(op))) + input_type = operator.inputs[0].type.__class__ + input_dim = operator.inputs[0].type.shape[0] + output_type = input_type([input_dim, None]) + operator.outputs[0].type = output_type + + def decorator_fct(fct): + name = "onnxsklearn_parser_%s_%s_%s" % ( + fct.__name__, str(op_version), runtime) + newclass = type( + name, (wrapper_onnxnumpy_np,), { + '__doc__': fct.__doc__, + '__getstate__': wrapper_onnxnumpy_np.__getstate__, + '__setstate__': wrapper_onnxnumpy_np.__setstate__}) + _created_classes_inst.append(name, newclass) + res = newclass( + fct=fct, op_version=op_version, runtime=runtime, + signature=signature) + if register_class is not None: + update_registered_converter_npy( + register_class, "Sklearn%s" % getattr( + register_class, "__name__", "noname"), + res, shape_fct=default_shape_calculator, overwrite=False) + return res + + return decorator_fct diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py index ba3cb8315..c2e1ea26b 100644 --- a/mlprodict/npy/onnx_variable.py +++ b/mlprodict/npy/onnx_variable.py @@ -6,6 +6,8 @@ """ import numpy from onnx.helper import make_tensor +from skl2onnx.common.data_types import guess_numpy_type +from skl2onnx.common._topology import Variable # pylint: disable=E0611,E0001 from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 OnnxAdd, OnnxAnd, OnnxCast, OnnxConstantOfShape, @@ -60,7 +62,6 @@ def __init__(self, *inputs, op=None, select_output=None, self.onnx_op = op self.alg_ = None self.onnx_op_kwargs = kwargs - self.dtype = dtype if dtype is not None and (op is not None or len(inputs) != 1): raise RuntimeError( "dtype can only be used if op is None or len(inputs) == 1.") @@ -68,6 +69,61 @@ def __init__(self, *inputs, op=None, select_output=None, if isinstance(inp, type): raise TypeError( "Unexpected type for input %d - %r." % (i, inp)) + self.dtype = self._guess_dtype(dtype) + + def _guess_dtype(self, dtype): + "Guesses dtype when not specified." + if dtype is not None: + return dtype + dtypes = [] + for i, inp in enumerate(self.inputs): + if isinstance(inp, str): + return None + if isinstance(inp, numpy.ndarray): + dtypes.append(inp.dtype) + elif isinstance(inp, Variable): + dt = guess_numpy_type(inp.type) + dtypes.append(dt) + elif isinstance(inp, OnnxVar): + dtypes.append(inp.dtype) + elif isinstance(inp, (numpy.float32, numpy.float64, numpy.int32, + numpy.int64)): + dtypes.append(inp.dtype) + elif isinstance(inp, numpy_str): + dtypes.append(numpy_str) + elif isinstance(inp, numpy_bool): + dtypes.append(numpy_bool) + elif isinstance(inp, int): + dtypes.append(numpy.int64) + elif isinstance(inp, float): + dtypes.append(numpy.float64) + else: + raise TypeError( + "Unexpected type for input %i type=%r." % (i, type(inp))) + dtypes = [_ for _ in dtypes if _ is not None] + unique = set(dtypes) + if len(unique) != 1: + return None + return dtypes[0] + + def __repr__(self): + "usual" + args = [] + for inp in self.inputs: + args.append(repr(inp)) + if self.onnx_op is not None: + if isinstance(self.onnx_op, str): + args.append("op=%r" % self.onnx_op) + else: + args.append("op=%s" % self.onnx_op.__name__) + if self.select_output is not None: + args.append("select_output=%r" % self.select_output) + if self.dtype is not None and self.dtype != self._guess_dtype(None): + args.append("dtype=%r" % self.dtype) + for k, v in sorted(self.onnx_op_kwargs.items()): + args.append("%s=%r" % (k, v)) + res = "%s(%s)" % (self.__class__.__name__, ", ".join(args)) + return res def to_algebra(self, op_version=None): """