From bcd2ecef6298e43568a1eb80a1a6753b38c11567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 27 Jul 2021 10:58:53 +0200 Subject: [PATCH 1/5] Increases code coverage --- _unittests/ut_cli/test_cli_asv2csv.py | 1 + .../test_grammar_sklearn_cov.py | 20 +++++++++ .../ut_onnxrt/test_onnxrt_python_runtime_.py | 12 +++++- ...test_onnxrt_python_runtime_control_loop.py | 17 +++++++- .../ut_tools/test_onnx_micro_runtime.py | 41 ++++++++++++++++++- .../grammar_sklearn/grammar/api_extension.py | 4 +- .../onnx_tools/optim/graph_schema_helper.py | 2 +- mlprodict/onnxrt/onnx_inference.py | 7 +++- mlprodict/onnxrt/ops_cpu/_op.py | 8 ++-- mlprodict/onnxrt/ops_cpu/op_loop.py | 23 +++++++++-- mlprodict/onnxrt/type_object.py | 13 ++++++ mlprodict/tools/ort_wrapper.py | 6 +-- 12 files changed, 137 insertions(+), 17 deletions(-) create mode 100644 _unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py create mode 100644 mlprodict/onnxrt/type_object.py diff --git a/_unittests/ut_cli/test_cli_asv2csv.py b/_unittests/ut_cli/test_cli_asv2csv.py index 269ad8518..c5893d80b 100644 --- a/_unittests/ut_cli/test_cli_asv2csv.py +++ b/_unittests/ut_cli/test_cli_asv2csv.py @@ -34,6 +34,7 @@ def test_cli_asv2csv(self): self.assertEqual(df.shape, (168, 66)) out = os.path.join(temp, "data.csv") main(args=["asv2csv", "-f", data, "-o", out], fLOG=st.fprint) + main(args=["asv2csv", "-f", data], fLOG=st.fprint) if __name__ == "__main__": diff --git a/_unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py b/_unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py new file mode 100644 index 000000000..bff70199c --- /dev/null +++ b/_unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py @@ -0,0 +1,20 @@ +""" +@brief test log(time=3s) +""" +import unittest +import platform +from pyquickhelper.pycode import ExtTestCase +from mlprodict.grammar_sklearn.grammar.api_extension import AutoType + + +class TestGrammarSklearnCov(ExtTestCase): + + def test_auto_type(self): + at = AutoType() + self.assertRaise(lambda: at.format_value(3), NotImplementedError) + at._format_value_json = lambda v: str(v) + self.assertRaise(lambda: at.format_value(3), TypeError) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index f72d1399e..84dcb396a 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -34,6 +34,7 @@ OnnxCompress, OnnxConcat, OnnxConv, OnnxConvTranspose, OnnxConstant, OnnxConstant_9, OnnxConstant_11, + OnnxConstant_12, OnnxConstant_13, OnnxConstantOfShape, OnnxCos, OnnxCosh, OnnxCumSum, @@ -108,6 +109,7 @@ QuantizedTensor, QuantizedBiasTensor, test_qlinear_conv) from mlprodict.onnxrt.ops_cpu.op_qlinear_conv_ import ( # pylint: disable=W0611,E0611,E0401 test_qgemm0, test_qgemm1) +from mlprodict.onnxrt.ops_cpu.op_constant import Constant_12, Constant_11, Constant_9 try: numpy_str = numpy.str_ @@ -3908,7 +3910,13 @@ def test_make_constant(self): opset_tests = [ (get_opset_number_from_onnx(), OnnxConstant), - (11, OnnxConstant_11)] + (13, OnnxConstant_13), + (12, OnnxConstant_12), + (11, OnnxConstant_11), + (9, OnnxConstant_9)] + + expected_type = {14: Constant_12, 12: Constant_12, 13: Constant_12, + 11: Constant_11, 9: Constant_9} if (not sys.platform.startswith('win') or compare_module_version(onnx_version, (1, 8, 0)) != 0): @@ -3935,6 +3943,8 @@ def test_make_constant(self): except RuntimeError as e: raise AssertionError( 
"Unable to load the model:\n{}".format(model_def)) from e + ope = oinf.sequence_[0].ops_ + self.assertIsInstance(ope, expected_type[opset]) got = oinf.run({'X': X}) if opset >= 11: self.assertEqual(list(sorted(got)), [ diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py index c1b3cc58a..a559d0a28 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_control_loop.py @@ -11,6 +11,7 @@ from onnx import TensorProto from pyquickhelper.pycode import ExtTestCase, ignore_warnings from mlprodict.onnxrt import OnnxInference +from mlprodict.onnxrt.type_object import SequenceType from mlprodict.tools import get_opset_number_from_onnx @@ -122,10 +123,22 @@ def test_loop(self): for rt in ['onnxruntime1', 'python']: with self.subTest(rt=rt): oinf = OnnxInference(model_def, runtime=rt) - got = oinf.run({ + inputs = { 'trip_count': trip_count, 'cond': cond, - 'seq_empty': seq_empty}) + 'seq_empty': seq_empty} + got = oinf.run(inputs) self.assertEqualArray(expected, got['res']) + if rt == 'python': + siz = oinf.infer_sizes(inputs) + self.assertIsInstance(siz, dict) + typ = oinf.infer_types() + self.assertEqual(typ["trip_count"], numpy.int64) + if 'cond' in typ: + self.assertEqual(typ["cond"], numpy.bool_) + for k, v in typ.items(): + if k in {'trip_count', 'cond'}: + continue + self.assertIsInstance(v, SequenceType) def sequence_insert_reference_implementation( self, sequence, tensor, position=None): diff --git a/_unittests/ut_tools/test_onnx_micro_runtime.py b/_unittests/ut_tools/test_onnx_micro_runtime.py index 3a5de20bd..2c3af593b 100644 --- a/_unittests/ut_tools/test_onnx_micro_runtime.py +++ b/_unittests/ut_tools/test_onnx_micro_runtime.py @@ -5,7 +5,8 @@ import numpy from pyquickhelper.pycode import ExtTestCase from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 - OnnxAdd, OnnxTranspose, OnnxShape, OnnxPow, OnnxMatMul) + OnnxAdd, OnnxTranspose, OnnxShape, OnnxPow, OnnxMatMul, OnnxGemm, + OnnxSqueeze, OnnxUnsqueeze) from mlprodict.tools.onnx_micro_runtime import OnnxMicroRuntime @@ -75,6 +76,44 @@ def test_onnx_micro_runtime_matmul(self): out = rt.run({'X': x}) self.assertEqual(numpy.matmul(x, x), out['Y']) + def test_onnx_micro_runtime_squeeze(self): + opset = 14 # opset=13, 14, ... + x = numpy.array([1, 2, 4, 5]).astype( + numpy.float32).reshape((2, 2, 1)) + cop = OnnxSqueeze('X', numpy.array([2], dtype=numpy.int64), + op_version=opset, output_names=['Y']) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertEqual(numpy.squeeze(x), out['Y']) + + def test_onnx_micro_runtime_unsqueeze(self): + opset = 14 # opset=13, 14, ... + x = numpy.array([1, 2, 4, 5]).astype( + numpy.float32).reshape((2, 2)) + cop = OnnxUnsqueeze('X', numpy.array([2], dtype=numpy.int64), + op_version=opset, output_names=['Y']) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + self.assertEqual(x.reshape(2, 2, 1), out['Y']) + + def test_onnx_micro_runtime_gemm(self): + opset = 14 # opset=13, 14, ... 
+ x = numpy.array([1, 2, 4, 5]).astype( + numpy.float32).reshape((2, 2)) + for ta in [0, 1]: + for tb in [0, 1]: + cop = OnnxGemm( + 'X', 'X', 'X', op_version=opset, alpha=1., beta=1., + output_names=['Y'], transA=ta, transB=tb) + model_def = cop.to_onnx({'X': x}, target_opset=opset) + rt = OnnxMicroRuntime(model_def) + out = rt.run({'X': x}) + xa = x.T if ta else x + xb = x.T if tb else x + self.assertEqual(numpy.matmul(xa, xb) + x, out['Y']) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/grammar_sklearn/grammar/api_extension.py b/mlprodict/grammar_sklearn/grammar/api_extension.py index b5f8b18f3..56ac70330 100644 --- a/mlprodict/grammar_sklearn/grammar/api_extension.py +++ b/mlprodict/grammar_sklearn/grammar/api_extension.py @@ -88,10 +88,10 @@ def format_value(self, value, lang="json", hook=None): if hasattr(self, name): try: return getattr(self, name)(value, hook=hook) - except TypeError as e: # pragma: no cover + except TypeError as e: raise TypeError( "Singature of '{0}' is wrong for type '{1}'".format(name, type(self))) from e else: - raise NotImplementedError( # pragma: no cover + raise NotImplementedError( "No formatting is implemented for lang='{0}' and type='{1}'".format( lang, type(self))) diff --git a/mlprodict/onnx_tools/optim/graph_schema_helper.py b/mlprodict/onnx_tools/optim/graph_schema_helper.py index a099e52af..e30db0a78 100644 --- a/mlprodict/onnx_tools/optim/graph_schema_helper.py +++ b/mlprodict/onnx_tools/optim/graph_schema_helper.py @@ -102,7 +102,7 @@ def get_defined_outputs(outputs, onnx_node, typed_inputs=None, variables=None, if schema is None: ft = DoubleTensorType if dtype == numpy.float64 else FloatTensorType elif len(schema) != 1: - raise ValueError( + raise ValueError( # pragma: no cover "schema should only contain one output not {}.".format(schema)) else: if isinstance(schema, DataType): diff --git a/mlprodict/onnxrt/onnx_inference.py b/mlprodict/onnxrt/onnx_inference.py index 6de26cba7..b4b66b6a6 100644 --- a/mlprodict/onnxrt/onnx_inference.py +++ b/mlprodict/onnxrt/onnx_inference.py @@ -23,6 +23,7 @@ from .onnx_inference_node import OnnxInferenceNode from .onnx_inference_exports import OnnxInferenceExport from .shape_object import ShapeObject +from .type_object import SequenceType class OnnxInference: @@ -1021,7 +1022,11 @@ def _set_type_inference_runtime(self): for k, v in self.inputs_.items(): # The function assumes the first dimension is unknown # and is the batch size. 
- values[k] = guess_numpy_type_from_string(v['type']['elem']) + if isinstance(v['type']['elem'], dict): + # sequence + values[k] = SequenceType() + else: + values[k] = guess_numpy_type_from_string(v['type']['elem']) for k, v in self.inits_.items(): values[k] = v['value'].dtype last = None diff --git a/mlprodict/onnxrt/ops_cpu/_op.py b/mlprodict/onnxrt/ops_cpu/_op.py index 6512baa4e..125c028d0 100644 --- a/mlprodict/onnxrt/ops_cpu/_op.py +++ b/mlprodict/onnxrt/ops_cpu/_op.py @@ -8,6 +8,7 @@ import onnx import onnx.defs from ..shape_object import ShapeObject +from ..type_object import SequenceType from ._new_ops import OperatorSchema @@ -240,13 +241,14 @@ def infer_types(self, *args, **kwargs): "res must be tuple not {} (operator '{}')".format( type(res), self.__class__.__name__)) for a in res: - if not isinstance(a, numpy.dtype) and a not in { + if not isinstance(a, (numpy.dtype, SequenceType)) and a not in { numpy.int8, numpy.uint8, numpy.float16, numpy.float32, numpy.float64, numpy.int32, numpy.int64, numpy.int16, numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_, - numpy.uint64, bool, str, }: + numpy.uint64, bool, str}: raise TypeError( # pragma: no cover - "Type ({}, {}) is not a numpy type (operator '{}')".format( + "Type ({}, {}) is not a numpy type or a sequence type " + "(operator '{}')".format( a, type(a), self.__class__.__name__)) return res diff --git a/mlprodict/onnxrt/ops_cpu/op_loop.py b/mlprodict/onnxrt/ops_cpu/op_loop.py index ba9b0b7da..53987574b 100644 --- a/mlprodict/onnxrt/ops_cpu/op_loop.py +++ b/mlprodict/onnxrt/ops_cpu/op_loop.py @@ -19,14 +19,15 @@ def __init__(self, onnx_node, desc=None, **options): expected_attributes=Loop.atts, **options) if not hasattr(self.body, 'run'): - raise RuntimeError("Parameter 'body' must have a method 'run', " - "type {}.".format(type(self.body))) + raise RuntimeError( # pragma: no cover + "Parameter 'body' must have a method 'run', " + "type {}.".format(type(self.body))) self._run_meth = (self.body.run_in_scan if hasattr(self.body, 'run_in_scan') else self.body.run) - def _run(self, M, cond, v_initial, *args): # pylint: disable=W0221 + def _run(self, M, cond, v_initial, *args, callback=None): # pylint: disable=W0221 inputs = {name: None for name in self.body.input_names} inputs[self.body.input_names[2]] = v_initial cond_name = self.body.output_names[1] @@ -43,6 +44,8 @@ def _run(self, M, cond, v_initial, *args): # pylint: disable=W0221 for i, o in zip(self.body.input_names[2:], self.body.output_names[1:]): inputs[i] = outputs[o] + if callback is not None: + callback(inputs) it += 1 if it == 0: outputs = {self.body.output_names[1]: cond} @@ -61,3 +64,17 @@ def _infer_shapes(self, M, cond, v_initial, *args): # pylint: disable=W0221 def _infer_types(self, M, cond, v_initial, *args): # pylint: disable=W0221 res = self.body._set_type_inference_runtime() return tuple([res[name] for name in self.body.output_names[1:]]) + + def _infer_sizes(self, M, cond, v_initial, *args): # pylint: disable=W0221 + store = [] + + def callback_(inputs): + res = self.body.infer_sizes(inputs) + store.append(res) + + res = self._run(M, cond, v_initial, *args, callback=callback_) + temp = 0 + for v in store: + for vv in v.values(): + temp += sum(vv.values()) + return (dict(temp=temp), ) + res diff --git a/mlprodict/onnxrt/type_object.py b/mlprodict/onnxrt/type_object.py new file mode 100644 index 000000000..ff57e0ee0 --- /dev/null +++ b/mlprodict/onnxrt/type_object.py @@ -0,0 +1,13 @@ +""" +@file +@brief Type object. 
+""" + + +class SequenceType: + """ + Represents a sequence type. + Used in @see methd infer_types. + """ + pass + diff --git a/mlprodict/tools/ort_wrapper.py b/mlprodict/tools/ort_wrapper.py index 9f090fe65..e061ab08b 100644 --- a/mlprodict/tools/ort_wrapper.py +++ b/mlprodict/tools/ort_wrapper.py @@ -13,7 +13,7 @@ InferenceSession as OrtInferenceSession, __version__ as onnxrt_version, GraphOptimizationLevel) -except ImportError: +except ImportError: # pragma: no cover SessionOptions = None RunOptions = None OrtInferenceSession = None @@ -27,7 +27,7 @@ InvalidArgument as OrtInvalidArgument, InvalidGraph as OrtInvalidGraph, RuntimeException as OrtRuntimeException) -except ImportError: +except ImportError: # pragma: no cover SessionOptions = None RunOptions = None InferenceSession = None @@ -105,7 +105,7 @@ def prepare_c_profiling(model_onnx, inputs, dest=None): if dest is None: dest = "." if not os.path.exists(dest): - os.makedirs(dest) + os.makedirs(dest) # pragma: no cover dest = os.path.abspath(dest) name = "model.onnx" model_bytes = model_onnx.SerializeToString() From 10dd14a3a6b2dd46d29c8134bc8585d66bf244b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 27 Jul 2021 13:35:56 +0200 Subject: [PATCH 2/5] lint --- _doc/examples/plot_time_tree_ensemble.py | 3 ++- _unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py | 3 +-- _unittests/ut_tools/test_onnx_micro_runtime.py | 2 +- mlprodict/onnxrt/type_object.py | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/_doc/examples/plot_time_tree_ensemble.py b/_doc/examples/plot_time_tree_ensemble.py index 214a148dc..ad75edb32 100644 --- a/_doc/examples/plot_time_tree_ensemble.py +++ b/_doc/examples/plot_time_tree_ensemble.py @@ -203,7 +203,8 @@ def measure_onnx_runtime(model, xt, repeat=REPEAT, number=NUMBER, ########################################### # Graphs. 
ax = piv.T.plot(kind="bar") -ax.set_title("Comparison for %d observations and %d features" % X_test.shape) +ax.set_title("Computation time ratio for %d observations and %d features\n" + "lower is better for onnx runtimes" % X_test.shape) plt.savefig('%s.png' % name) ########################################### diff --git a/_unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py b/_unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py index bff70199c..a64b0fe7d 100644 --- a/_unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py +++ b/_unittests/ut_grammar_sklearn/test_grammar_sklearn_cov.py @@ -2,7 +2,6 @@ @brief test log(time=3s) """ import unittest -import platform from pyquickhelper.pycode import ExtTestCase from mlprodict.grammar_sklearn.grammar.api_extension import AutoType @@ -12,7 +11,7 @@ class TestGrammarSklearnCov(ExtTestCase): def test_auto_type(self): at = AutoType() self.assertRaise(lambda: at.format_value(3), NotImplementedError) - at._format_value_json = lambda v: str(v) + at._format_value_json = lambda v: str(v) # pylint: disable=W0212 self.assertRaise(lambda: at.format_value(3), TypeError) diff --git a/_unittests/ut_tools/test_onnx_micro_runtime.py b/_unittests/ut_tools/test_onnx_micro_runtime.py index 2c3af593b..b29ef15cf 100644 --- a/_unittests/ut_tools/test_onnx_micro_runtime.py +++ b/_unittests/ut_tools/test_onnx_micro_runtime.py @@ -96,7 +96,7 @@ def test_onnx_micro_runtime_unsqueeze(self): model_def = cop.to_onnx({'X': x}, target_opset=opset) rt = OnnxMicroRuntime(model_def) out = rt.run({'X': x}) - self.assertEqual(x.reshape(2, 2, 1), out['Y']) + self.assertEqual(x.reshape((2, 2, 1)), out['Y']) def test_onnx_micro_runtime_gemm(self): opset = 14 # opset=13, 14, ... diff --git a/mlprodict/onnxrt/type_object.py b/mlprodict/onnxrt/type_object.py index ff57e0ee0..135544b80 100644 --- a/mlprodict/onnxrt/type_object.py +++ b/mlprodict/onnxrt/type_object.py @@ -10,4 +10,3 @@ class SequenceType: Used in @see methd infer_types. 
""" pass - From 555f283415332af747db23513c5f3170377f6fb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 27 Jul 2021 19:16:56 +0200 Subject: [PATCH 3/5] Fix issue with constant --- _unittests/ut_onnxrt/test_onnxrt_python_runtime_.py | 2 +- mlprodict/onnxrt/ops_cpu/_op_list.py | 2 +- mlprodict/onnxrt/ops_cpu/op_constant.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index 84dcb396a..4548cbd95 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -3971,5 +3971,5 @@ def test_op_constant(self): if __name__ == "__main__": # Working - # TestOnnxrtPythonRuntime().test_onnxt_runtime_abs() + # TestOnnxrtPythonRuntime().test_make_constant() unittest.main() diff --git a/mlprodict/onnxrt/ops_cpu/_op_list.py b/mlprodict/onnxrt/ops_cpu/_op_list.py index b91e1e4ee..6bceac841 100644 --- a/mlprodict/onnxrt/ops_cpu/_op_list.py +++ b/mlprodict/onnxrt/ops_cpu/_op_list.py @@ -30,7 +30,7 @@ from .op_concat_from_sequence import ConcatFromSequence from .op_conv import Conv from .op_conv_transpose import ConvTranspose -from .op_constant import Constant +from .op_constant import Constant, Constant_12, Constant_11, Constant_9 from .op_constant_of_shape import ConstantOfShape from .op_cos import Cos from .op_cosh import Cosh diff --git a/mlprodict/onnxrt/ops_cpu/op_constant.py b/mlprodict/onnxrt/ops_cpu/op_constant.py index 60c3f2956..16119b31f 100644 --- a/mlprodict/onnxrt/ops_cpu/op_constant.py +++ b/mlprodict/onnxrt/ops_cpu/op_constant.py @@ -28,7 +28,7 @@ class Constant_9(OpRun): def __init__(self, onnx_node, desc=None, **options): OpRun.__init__(self, onnx_node, desc=desc, - expected_attributes=Constant.atts, + expected_attributes=Constant_9.atts, **options) self.cst = self.value _check_dtype(self.cst) @@ -56,9 +56,9 @@ class Constant_11(OpRun): def __init__(self, onnx_node, desc=None, **options): OpRun.__init__(self, onnx_node, desc=desc, - expected_attributes=Constant.atts, + expected_attributes=Constant_11.atts, **options) - if self.sparse_value is not None: + if getattr(self, 'sparse_value', None) is not None: self.cst = self.sparse_value else: self.cst = self.value @@ -94,7 +94,7 @@ class Constant_12(OpRun): def __init__(self, onnx_node, desc=None, **options): OpRun.__init__(self, onnx_node, desc=desc, - expected_attributes=Constant.atts, + expected_attributes=Constant_12.atts, **options) if hasattr(self, 'sparse_value') and self.sparse_value is not None: self.cst = self.sparse_value From 99d1d9aa12d4269ed05b603185cd2529ba590a86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 28 Jul 2021 10:03:53 +0200 Subject: [PATCH 4/5] Update test_cli_validate.py --- _unittests/ut_cli/test_cli_validate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/_unittests/ut_cli/test_cli_validate.py b/_unittests/ut_cli/test_cli_validate.py index 4877dc9e6..b9bdb1a4c 100644 --- a/_unittests/ut_cli/test_cli_validate.py +++ b/_unittests/ut_cli/test_cli_validate.py @@ -2,6 +2,7 @@ @brief test tree node (time=15s) """ import os +import sys import unittest from pyquickhelper.loghelper import BufferedPrint from pyquickhelper.pycode import ExtTestCase, get_temp_folder, skipif_circleci @@ -84,6 +85,7 @@ def test_cli_validate_model_csv_bug(self): self.assertNotExists(out2) @skipif_circleci('too long') + @unittest.skipIf(sys.platform == 'darwin', reason='stuck') def 
test_cli_validate_model_lightgbm(self): temp = get_temp_folder(__file__, "temp_validate_model_lgbm_csv") out1 = os.path.join(temp, "raw.csv") From ddd7e0e875c89cab2ab246b68ceeb11542d982e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 28 Jul 2021 12:18:52 +0200 Subject: [PATCH 5/5] minor fixes --- mlprodict/onnxrt/ops_cpu/op_label_encoder.py | 5 +++++ mlprodict/onnxrt/ops_cpu/op_where.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/mlprodict/onnxrt/ops_cpu/op_label_encoder.py b/mlprodict/onnxrt/ops_cpu/op_label_encoder.py index bf0fc778a..0b0d73142 100644 --- a/mlprodict/onnxrt/ops_cpu/op_label_encoder.py +++ b/mlprodict/onnxrt/ops_cpu/op_label_encoder.py @@ -61,6 +61,11 @@ def __init__(self, onnx_node, desc=None, **options): self.keys_floats, self.values_strings)} self.default_ = self.default_string self.dtype_ = numpy.array(self.classes_.values).dtype + elif len(self.keys_int64s) > 0 and len(self.values_strings) > 0: + self.classes_ = {k: v.decode('utf-8') for k, v in zip( + self.keys_int64s, self.values_strings)} + self.default_ = self.default_string + self.dtype_ = numpy.array(self.classes_.values).dtype elif hasattr(self, 'classes_strings'): raise RuntimeError( # pragma: no cover "This runtime does not implement version 1 of " diff --git a/mlprodict/onnxrt/ops_cpu/op_where.py b/mlprodict/onnxrt/ops_cpu/op_where.py index 1a0f9568e..518f52c61 100644 --- a/mlprodict/onnxrt/ops_cpu/op_where.py +++ b/mlprodict/onnxrt/ops_cpu/op_where.py @@ -15,7 +15,7 @@ def __init__(self, onnx_node, desc=None, **options): **options) def _run(self, condition, x, y): # pylint: disable=W0221 - if x.dtype != y.dtype: + if x.dtype != y.dtype and x.dtype not in (numpy.object_, ): raise RuntimeError( # pragma: no cover "x and y should share the same dtype {} != {}".format( x.dtype, y.dtype))
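
The final hunk relaxes the dtype check in the Where runtime so that object arrays (for instance the decoded strings produced by the new LabelEncoder int64-to-string branch) no longer trigger the mismatch error. A minimal standalone sketch of that rule, outside the runtime; the helper name check_where_dtypes is illustrative only and not part of the patch:

import numpy

def check_where_dtypes(x, y):
    # Same rule as the patched Where._run: a dtype mismatch is only an
    # error when x is not an object array.
    if x.dtype != y.dtype and x.dtype not in (numpy.object_, ):
        raise RuntimeError(
            "x and y should share the same dtype {} != {}".format(
                x.dtype, y.dtype))

cond = numpy.array([True, False])
x = numpy.array(["a", "b"], dtype=numpy.object_)
y = numpy.array([0., 1.], dtype=numpy.float64)
check_where_dtypes(x, y)            # passes: x has object dtype
print(numpy.where(cond, x, y))      # ['a' 1.0]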