Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions _doc/sphinxdoc/source/tutorial/numpy_api_onnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -840,3 +840,62 @@ converter after the call. Another way is to call function
:func:`register_rewritten_operators
<mlprodict.onnx_conv.register_rewritten_converters.register_rewritten_operators>`
but changes are permanent.

Issue when an estimator is called by another one
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A new class is created and the method *transform* is implemented
with the numpy API for ONNX. This function must produce an ONNX
graph including the embedded the embedded model. It must call
the converter for this estimator to get that graph.
That what instruction ``nxnpskl.transformer(X, model=self.estimator_)``
does. However it produces the following error.


.. runpython::
:showcode:

import numpy
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.preprocessing import StandardScaler
from mlprodict.onnx_conv import to_onnx
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy import onnxsklearn_class
import mlprodict.npy.numpy_onnx_impl_skl as nxnpskl


@onnxsklearn_class("onnx_graph")
class CustomTransformerOnnx(TransformerMixin, BaseEstimator):

def __init__(self, base_estimator):
TransformerMixin.__init__(self)
BaseEstimator.__init__(self)
self.base_estimator = base_estimator

def fit(self, X, y, sample_weights=None):
if sample_weights is not None:
raise NotImplementedError(
"weighted sample not implemented in this example.")

self.estimator_ = self.base_estimator.fit( # pylint: disable=W0201
X, y, sample_weights)
return self

def onnx_graph(self, X):
return nxnpskl.transformer(X, model=self.estimator_)


X = numpy.random.randn(20, 2).astype(numpy.float32)
y = ((X.sum(axis=1) + numpy.random.randn(
X.shape[0]).astype(numpy.float32)) >= 0).astype(numpy.int64)
dec = CustomTransformerOnnx(StandardScaler())
dec.fit(X, y)
onx = to_onnx(dec, X.astype(numpy.float32))
oinf = OnnxInference(onx)
tr = dec.transform(X) # pylint: disable=E1101
got = oinf.run({'X': X})
print(got)

To fix it, instruction ``return nxnpskl.transformer(X, model=self.estimator_)``
should be replaced by
``return nxnpskl.transformer(X, model=self.estimator_).copy()``.
2 changes: 1 addition & 1 deletion _unittests/ut__skl2onnx/test_sklearn_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
@brief test tree node (time=10s)
@brief test tree node (time=3s)
"""
import unittest
import warnings
Expand Down
220 changes: 220 additions & 0 deletions _unittests/ut_npy/test_custom_embedded_any_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-
"""
@brief test log(time=3s)
"""
import unittest
from logging import getLogger
import numpy
from sklearn.base import (
ClassifierMixin, RegressorMixin, ClusterMixin,
TransformerMixin, BaseEstimator)
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from pyquickhelper.pycode import ExtTestCase, ignore_warnings
from mlprodict.onnx_conv import to_onnx, register_rewritten_operators
from mlprodict.onnxrt import OnnxInference
from mlprodict.npy import onnxsklearn_class
from mlprodict.npy.onnx_variable import MultiOnnxVar
# import mlprodict.npy.numpy_onnx_impl as nxnp
import mlprodict.npy.numpy_onnx_impl_skl as nxnpskl


@onnxsklearn_class("onnx_graph")
class AnyCustomClassifierOnnx(ClassifierMixin, BaseEstimator):

def __init__(self, base_estimator):
ClassifierMixin.__init__(self)
BaseEstimator.__init__(self)
self.base_estimator = base_estimator

def fit(self, X, y, sample_weights=None):
if sample_weights is not None:
raise NotImplementedError(
"weighted sample not implemented in this example.")

self.estimator_ = self.base_estimator.fit( # pylint: disable=W0201
X, y, sample_weights)
self.classes_ = self.estimator_.classes_ # pylint: disable=W0201
return self

def onnx_graph(self, X):
res_model = nxnpskl.classifier(X, model=self.estimator_)
label = res_model[0].copy()
prob = res_model[1].copy()
return MultiOnnxVar(label, prob)


@onnxsklearn_class("onnx_graph")
class AnyCustomRegressorOnnx(RegressorMixin, BaseEstimator):

def __init__(self, base_estimator):
RegressorMixin.__init__(self)
BaseEstimator.__init__(self)
self.base_estimator = base_estimator

def fit(self, X, y, sample_weights=None):
if sample_weights is not None:
raise NotImplementedError(
"weighted sample not implemented in this example.")

self.estimator_ = self.base_estimator.fit( # pylint: disable=W0201
X, y, sample_weights)
return self

def onnx_graph(self, X):
return nxnpskl.regressor(X, model=self.estimator_).copy()


@onnxsklearn_class("onnx_graph")
class AnyCustomClusterOnnx(ClusterMixin, BaseEstimator):

def __init__(self, base_estimator):
ClusterMixin.__init__(self)
BaseEstimator.__init__(self)
self.base_estimator = base_estimator

def fit(self, X, y, sample_weights=None):
if sample_weights is not None:
raise NotImplementedError(
"weighted sample not implemented in this example.")

self.estimator_ = self.base_estimator.fit( # pylint: disable=W0201
X, y, sample_weights)
return self

def onnx_graph(self, X):
res_model = nxnpskl.cluster(X, model=self.estimator_)
label = res_model[0].copy()
prob = res_model[1].copy()
return MultiOnnxVar(label, prob)


@onnxsklearn_class("onnx_graph")
class AnyCustomTransformerOnnx(TransformerMixin, BaseEstimator):

def __init__(self, base_estimator):
TransformerMixin.__init__(self)
BaseEstimator.__init__(self)
self.base_estimator = base_estimator

def fit(self, X, y, sample_weights=None):
if sample_weights is not None:
raise NotImplementedError(
"weighted sample not implemented in this example.")

self.estimator_ = self.base_estimator.fit( # pylint: disable=W0201
X, y, sample_weights)
return self

def onnx_graph(self, X):
return nxnpskl.transformer(X, model=self.estimator_).copy()


class TestCustomEmbeddedModels(ExtTestCase):

def setUp(self):
logger = getLogger('skl2onnx')
logger.disabled = True
register_rewritten_operators()

def common_test_function_classifier_embedded(self, dtype, est):
X = numpy.random.randn(20, 2).astype(dtype)
y = ((X.sum(axis=1) + numpy.random.randn(
X.shape[0]).astype(numpy.float32)) >= 0).astype(numpy.int64)
dec = AnyCustomClassifierOnnx(est)
dec.fit(X, y)
onx = to_onnx(dec, X.astype(dtype),
options={id(dec): {'zipmap': False}})
oinf = OnnxInference(onx)
exp = dec.predict(X) # pylint: disable=E1101
prob = dec.predict_proba(X) # pylint: disable=E1101
got = oinf.run({'X': X})
self.assertEqual(dtype, prob.dtype)
self.assertEqualArray(exp, got['label'].ravel())
self.assertEqualArray(prob, got['probabilities'])

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_classifier_embedded_float32(self):
self.common_test_function_classifier_embedded(
numpy.float32, DecisionTreeClassifier(max_depth=3))

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_classifier_embedded_float64(self):
self.common_test_function_classifier_embedded(
numpy.float64, DecisionTreeClassifier(max_depth=3))

def common_test_function_regressor_embedded(self, dtype, est):
X = numpy.random.randn(40, 2).astype(dtype)
y = (X.sum(axis=1) + numpy.random.randn(
X.shape[0])).astype(numpy.float32)
dec = AnyCustomRegressorOnnx(est)
dec.fit(X, y)
onx = to_onnx(dec, X.astype(dtype))
oinf = OnnxInference(onx)
exp = dec.predict(X) # pylint: disable=E1101
got = oinf.run({'X': X})
self.assertEqual(dtype, exp.dtype)
self.assertEqualArray(exp, got['variable'])

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_regressor_embedded_float32(self):
self.common_test_function_regressor_embedded(
numpy.float32, DecisionTreeRegressor(max_depth=3))

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_regressor_embedded_float64(self):
self.common_test_function_regressor_embedded(
numpy.float64, DecisionTreeRegressor(max_depth=3))

def common_test_function_cluster_embedded(self, dtype, est):
X = numpy.random.randn(20, 2).astype(dtype)
y = ((X.sum(axis=1) + numpy.random.randn(
X.shape[0]).astype(numpy.float32)) >= 0).astype(numpy.int64)
dec = AnyCustomClusterOnnx(est)
dec.fit(X, y)
onx = to_onnx(dec, X.astype(dtype))
oinf = OnnxInference(onx)
exp = dec.predict(X) # pylint: disable=E1101
prob = dec.transform(X) # pylint: disable=E1101
got = oinf.run({'X': X})
self.assertEqual(dtype, prob.dtype)
self.assertEqualArray(exp, got['label'].ravel())
self.assertEqualArray(prob, got['scores'])

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_cluster_embedded_float32(self):
self.common_test_function_cluster_embedded(
numpy.float32, KMeans(n_clusters=2))

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_cluster_embedded_float64(self):
self.common_test_function_cluster_embedded(
numpy.float64, KMeans(n_clusters=2))

def common_test_function_transformer_embedded(self, dtype, est):
X = numpy.random.randn(20, 2).astype(dtype)
y = ((X.sum(axis=1) + numpy.random.randn(
X.shape[0]).astype(numpy.float32)) >= 0).astype(numpy.int64)
dec = AnyCustomTransformerOnnx(est)
dec.fit(X, y)
onx = to_onnx(dec, X.astype(dtype))
oinf = OnnxInference(onx)
tr = dec.transform(X) # pylint: disable=E1101
got = oinf.run({'X': X})
self.assertEqual(dtype, tr.dtype)
self.assertEqualArray(tr, got['variable'])

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_transformer_embedded_float32(self):
self.common_test_function_transformer_embedded(
numpy.float32, StandardScaler())

@ignore_warnings((DeprecationWarning, RuntimeWarning))
def test_function_transformer_embedded_float64(self):
self.common_test_function_transformer_embedded(
numpy.float64, StandardScaler())


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def onnx_graph(self, X):
return pred


class TestCustomEmbeddedModels(ExtTestCase):
class TestCustomEmbeddedLinearModels(ExtTestCase):

def setUp(self):
logger = getLogger('skl2onnx')
Expand Down
5 changes: 4 additions & 1 deletion _unittests/ut_onnxrt/test_bugs_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from skl2onnx import convert_sklearn
from skl2onnx import __version__ as skl2onnx_version
from skl2onnx.common.data_types import FloatTensorType, StringTensorType
from mlprodict.onnxrt import OnnxInference
from mlprodict.onnxrt.validate.data import load_audit
Expand Down Expand Up @@ -83,7 +84,9 @@ def convert_dataframe_schema(df, drop=None):
try:
model_onnx = convert_sklearn(predictor, model_name, inputs)
except Exception as e:
raise e
raise AssertionError(
"Unable to convert model %r (version=%r)." % (
predictor, skl2onnx_version)) from e

data = {col[0]: x_test[col[0]].values.reshape(x_test.shape[0], 1)
for col in inputs}
Expand Down
63 changes: 63 additions & 0 deletions mlprodict/npy/numpy_onnx_impl_skl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,66 @@ def logistic_regression(x, *, model=None):
"""
return MultiOnnxVar(model, x, op=OnnxSubEstimator,
options={'zipmap': False})


def classifier(x, *, model=None):
"""
Returns any classifier from :epkg:`scikit-learn`
converted into ONNX assuming a converter is registered
with :epkg:`sklearn-onnx`. Option *zipmap* is set to false.

:param x: array, variable name, instance of :class:`OnnxVar
<mlprodict.npy.onnx_variable.OnnxVar>`
:param model: instance of a classifier
:return: instance of :class:`MultiOnnxVar
<mlprodict.npy.onnx_variable.MultiOnnxVar>`, first
output is labels, second one is the probabilities
"""
return MultiOnnxVar(model, x, op=OnnxSubEstimator,
options={'zipmap': False})


def cluster(x, *, model=None):
"""
Returns any cluster from :epkg:`scikit-learn`
converted into ONNX assuming a converter is registered
with :epkg:`sklearn-onnx`. Option *zipmap* is set to false.

:param x: array, variable name, instance of :class:`OnnxVar
<mlprodict.npy.onnx_variable.OnnxVar>`
:param model: instance of a cluster
:return: instance of :class:`MultiOnnxVar
<mlprodict.npy.onnx_variable.MultiOnnxVar>`, first
output is labels, second one is the probabilities
"""
return MultiOnnxVar(model, x, op=OnnxSubEstimator)


def regressor(x, *, model=None):
"""
Returns any regressor from :epkg:`scikit-learn`
converted into ONNX assuming a converter is registered
with :epkg:`sklearn-onnx`.

:param x: array, variable name, instance of :class:`OnnxVar
<mlprodict.npy.onnx_variable.OnnxVar>`
:param model: instance of a regressor
:return: instance of :class:`OnnxVar
<mlprodict.npy.onnx_variable.OnnxVar>`
"""
return OnnxVar(model, x, op=OnnxSubEstimator)


def transformer(x, *, model=None):
"""
Returns any transformer from :epkg:`scikit-learn`
converted into ONNX assuming a converter is registered
with :epkg:`sklearn-onnx`.

:param x: array, variable name, instance of :class:`OnnxVar
<mlprodict.npy.onnx_variable.OnnxVar>`
:param model: instance of a transformer
:return: instance of :class:`OnnxVar
<mlprodict.npy.onnx_variable.OnnxVar>`
"""
return OnnxVar(model, x, op=OnnxSubEstimator)