diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml.py index 3fda34185..ee8ddfe7b 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml.py @@ -282,8 +282,11 @@ def test_dict_vectorizer(self): data = [{"amy": 1.0, "chin": 200.0}, {"nice": 3.0, "amy": 1.0}] model.fit_transform(data) exp = model.transform(data) - model_def = convert_sklearn(model, "dictionary vectorizer", - [("input", DictionaryType(StringTensorType([1]), FloatTensorType([1])))]) + model_def = convert_sklearn( + model, "dictionary vectorizer", + [("input", DictionaryType( + StringTensorType([1]), + FloatTensorType([1])))]) oinf = OnnxInference(model_def) array_data = numpy.array(data) got = oinf.run({'input': array_data}) diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py index d999a62d9..55abe52da 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py @@ -18,7 +18,8 @@ from skl2onnx.common.data_types import ( StringTensorType, FloatTensorType, Int64TensorType) from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 - OnnxStringNormalizer, OnnxTfIdfVectorizer, OnnxLabelEncoder) + OnnxStringNormalizer, OnnxTfIdfVectorizer, OnnxLabelEncoder, + OnnxCategoryMapper) from mlprodict.onnx_conv import to_onnx from mlprodict.onnx_conv.onnx_ops import OnnxTokenizer from mlprodict.onnxrt import OnnxInference @@ -473,6 +474,35 @@ def test_multi_output_classifier(self): for e, g in zip(expected_proba, got['probabilities']): self.assertEqualArray(e, g, decimal=5) + def test_onnxrt_category_mapper_intstr(self): + + op = OnnxCategoryMapper( + 'cat', op_version=TARGET_OPSET, + cats_int64s=[1, 2], cats_strings=["cat1", "cat2"], + output_names=['out']) + onx = op.to_onnx( + inputs=[('cat', Int64TensorType())], + 
outputs=[('out', StringTensorType())]) + oinf = OnnxInference(onx) + res = oinf.run({'cat': numpy.array([1, 2, 1, 5], dtype=numpy.int64)}) + self.assertEqual( + res['out'].tolist(), ["cat1", "cat2", "cat1", ""]) + + def test_onnxrt_category_mapper_strint(self): + + op = OnnxCategoryMapper( + 'cat', op_version=TARGET_OPSET, + cats_int64s=[1, 2], cats_strings=["cat1", "cat2"], + output_names=['out']) + onx = op.to_onnx( + inputs=[('cat', StringTensorType())], + outputs=[('out', Int64TensorType())]) + oinf = OnnxInference(onx) + res = oinf.run({'cat': numpy.array(["cat1", "cat2", "cat1", "R"], + dtype=numpy.str_)}) + self.assertEqualArray( + res['out'], numpy.array([1, 2, 1, -1], dtype=numpy.int64)) + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2) diff --git a/mlprodict/onnxrt/ops_cpu/_op_list.py b/mlprodict/onnxrt/ops_cpu/_op_list.py index f2b4d19f2..f8ce868fd 100644 --- a/mlprodict/onnxrt/ops_cpu/_op_list.py +++ b/mlprodict/onnxrt/ops_cpu/_op_list.py @@ -27,6 +27,7 @@ from .op_ceil import Ceil from .op_celu import Celu from .op_clip import Clip_6, Clip_11, Clip +from .op_category_mapper import CategoryMapper from .op_complex_abs import ComplexAbs from .op_compress import Compress from .op_concat import Concat diff --git a/mlprodict/onnxrt/ops_cpu/op_category_mapper.py b/mlprodict/onnxrt/ops_cpu/op_category_mapper.py new file mode 100644 index 000000000..cece6ad49 --- /dev/null +++ b/mlprodict/onnxrt/ops_cpu/op_category_mapper.py @@ -0,0 +1,60 @@ +# -*- encoding: utf-8 -*- +# pylint: disable=E0203,E1101,C0111 +""" +@file +@brief Runtime operator. 
+"""
+import numpy
+from ._op import OpRun
+
+
+class CategoryMapper(OpRun):
+    """
+    Runtime for operator *CategoryMapper*. An int64 input is mapped
+    to strings, any other input is looked up as strings and mapped
+    to int64. Pairs are defined by attributes *cats_int64s* and
+    *cats_strings* (same length, same order). A value missing from
+    the mapping is replaced by *default_string* or *default_int64*.
+    """
+
+    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
+            'cats_strings': numpy.empty(0, dtype=numpy.str_),
+            'default_int64': -1,
+            'default_string': b'',
+            }
+
+    def __init__(self, onnx_node, desc=None, **options):
+        """
+        Builds the two lookup tables used by the operator.
+
+        @param      onnx_node       node forwarded to *OpRun*
+        @param      desc            forwarded to *OpRun*
+        @param      options         forwarded to *OpRun*
+
+        Raises *RuntimeError* if *cats_int64s* and *cats_strings*
+        do not have the same length.
+        """
+        OpRun.__init__(self, onnx_node, desc=desc,
+                       expected_attributes=CategoryMapper.atts,
+                       **options)
+        if len(self.cats_int64s) != len(self.cats_strings):
+            raise RuntimeError(
+                "Lengths mismatch between cats_int64s (%d) and "
+                "cats_strings (%d)." % (
+                    len(self.cats_int64s), len(self.cats_strings)))
+        # The default string is decoded once so that every string
+        # produced by the operator is a *str*: mixing the bytes default
+        # with decoded categories gives an array of bytes as soon as
+        # all the inputs are unknown.
+        self.default_str_ = CategoryMapper._as_str(self.default_string)
+        self.int2str_ = {}
+        self.str2int_ = {}
+        for a, b in zip(self.cats_int64s, self.cats_strings):
+            be = CategoryMapper._as_str(b)
+            self.int2str_[a] = be
+            self.str2int_[be] = a
+
+    @staticmethod
+    def _as_str(value):
+        "Decodes utf-8 bytes into *str*, returns other values unchanged."
+        if isinstance(value, bytes):
+            return value.decode('utf-8')
+        return value
+
+    def _run(self, x):  # pylint: disable=W0221
+        """
+        Maps every element of *x* and keeps the shape of *x*.
+
+        @param      x       array, int64 to get strings,
+                            strings to get int64
+        @return             tuple with one array
+        """
+        xf = x.ravel()
+        if x.dtype == numpy.int64:
+            # dtype is explicit: without it, an empty input would
+            # produce an array of floats.
+            res = numpy.array(
+                [self.int2str_.get(v, self.default_str_) for v in xf],
+                dtype=numpy.str_)
+            return (res.reshape(x.shape), )
+
+        res = numpy.array(
+            [self.str2int_.get(v, self.default_int64) for v in xf],
+            dtype=numpy.int64)
+        return (res.reshape(x.shape), )
+
+    def _infer_shapes(self, x):  # pylint: disable=W0221
+        """
+        Returns a copy of *x* with the output dtype. *x* is expected
+        to expose *dtype* and *copy(dtype=...)*.
+        """
+        if x.dtype == numpy.int64:
+            return (x.copy(dtype=numpy.str_), )
+        return (x.copy(dtype=numpy.int64), )
+
+    def _infer_types(self, x):  # pylint: disable=W0221
+        "Returns the output type: strings for int64, int64 otherwise."
+        if x.dtype == numpy.int64:
+            return (numpy.str_, )
+        return (numpy.int64, )
+
+    def _infer_sizes(self, *args, **kwargs):
+        "Runs the operator, returns ``dict(temp=0)`` and the outputs."
+        res = self.run(*args, **kwargs)
+        return (dict(temp=0), ) + res