From 4b82a43ff23fb4b6c4bf00d5616e44c0267f4708 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 7 Nov 2020 12:36:36 +0100 Subject: [PATCH] Fixes #183, fix missing parameter black_op in OnnxPipeline --- _unittests/ut_sklapi/test_onnx_pipeline.py | 9 +++++++++ mlprodict/sklapi/onnx_pipeline.py | 6 ++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_sklapi/test_onnx_pipeline.py b/_unittests/ut_sklapi/test_onnx_pipeline.py index a4e688b02..a7256853c 100644 --- a/_unittests/ut_sklapi/test_onnx_pipeline.py +++ b/_unittests/ut_sklapi/test_onnx_pipeline.py @@ -10,6 +10,7 @@ from sklearn.datasets import load_iris from sklearn.linear_model import LogisticRegression from sklearn.mixture import GaussianMixture +from sklearn.tree import DecisionTreeRegressor from pyquickhelper.pycode import ExtTestCase, ignore_warnings from mlinsights.mlmodel import TransferTransformer from mlprodict.onnx_conv import to_onnx @@ -46,6 +47,13 @@ def test_pipeline_iris(self): self.assertEqualArray(res["label"], pipe.predict(X)) self.assertEqualArray(res["probabilities"], pipe.predict_proba(X)) + def test_pipeline_none_params(self): + model_onx = OnnxPipeline([ + ('scaler', StandardScaler()), + ('dt', DecisionTreeRegressor(max_depth=2)) + ]) + self.assertNotEmpty(model_onx) + def test_pipeline_iris_enfore_false(self): iris = load_iris() X, y = iris.data, iris.target @@ -235,4 +243,5 @@ def cache(self, obj): if __name__ == '__main__': + TestOnnxPipeline().test_pipeline_none_params() unittest.main() diff --git a/mlprodict/sklapi/onnx_pipeline.py b/mlprodict/sklapi/onnx_pipeline.py index 5ddcb9248..92be41d6a 100644 --- a/mlprodict/sklapi/onnx_pipeline.py +++ b/mlprodict/sklapi/onnx_pipeline.py @@ -66,8 +66,6 @@ def __init__(self, steps, *, memory=None, verbose=False, runtime='python', options=None, white_op=None, black_op=None, final_types=None, op_version=None): - Pipeline.__init__( - self, steps, memory=memory, verbose=verbose) self.output_name = output_name self.enforce_float32 = enforce_float32 self.runtime = runtime @@ -77,6 +75,10 @@ def __init__(self, steps, *, memory=None, verbose=False, self.black_op = black_op self.final_types = final_types self.op_version = op_version + # The constructor calls _validate_step and it checks the value + # of black_op. + Pipeline.__init__( + self, steps, memory=memory, verbose=verbose) def fit(self, X, y=None, **fit_params): """