diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py index 1ad71698d..ffb62798e 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_ml_text.py @@ -58,6 +58,19 @@ def test_onnxrt_label_encoder_floats(self): self.assertEqualArray( res['out'], numpy.array([0.3, 0.4, 0.5, 0.4], dtype=numpy.float32)) + def test_onnxrt_label_encoder_string_floats(self): + + op = OnnxLabelEncoder( + 'text', op_version=get_opset_number_from_onnx(), + keys_strings=['AA', 'BB', 'CC'], + values_floats=[0.1, 0.2, 0.3], + output_names=['out']) + + onx = op.to_onnx(inputs=[('text', StringTensorType())]) + oinf = OnnxInference(onx) + res = oinf.run({'text': numpy.array(['AA', 'DD']).reshape((-1, 1))}) + self.assertEqualArray(res['out'], numpy.array([0.1, 0])) + def test_onnxrt_label_encoder_raise(self): self.assertRaise( @@ -68,15 +81,6 @@ def test_onnxrt_label_encoder_raise(self): output_names=['out']), TypeError) - op = OnnxLabelEncoder( - 'text', op_version=get_opset_number_from_onnx(), - keys_strings=['AA', 'BB', 'CC'], - values_floats=[0.1, 0.2, 0.3], - output_names=['out']) - - onx = op.to_onnx(inputs=[('text', StringTensorType())]) - self.assertRaise(lambda: OnnxInference(onx), RuntimeError) - op = OnnxLabelEncoder( 'text', op_version=get_opset_number_from_onnx(), keys_strings=['AA', 'BB', 'CC'], diff --git a/mlprodict/onnxrt/ops_cpu/op_label_encoder.py b/mlprodict/onnxrt/ops_cpu/op_label_encoder.py index 0b0d73142..537f29294 100644 --- a/mlprodict/onnxrt/ops_cpu/op_label_encoder.py +++ b/mlprodict/onnxrt/ops_cpu/op_label_encoder.py @@ -45,6 +45,11 @@ def __init__(self, onnx_node, desc=None, **options): self.keys_int64s, self.values_floats)} self.default_ = self.default_int64 self.dtype_ = numpy.float32 + elif len(self.keys_strings) > 0 and len(self.values_floats) > 0: + self.classes_ = {k.decode('utf-8'): v for k, v in zip( + self.keys_strings, self.values_floats)} + self.default_ = self.default_float + self.dtype_ = numpy.float32 elif len(self.keys_strings) > 0 and len(self.values_int64s) > 0: self.classes_ = {k.decode('utf-8'): v for k, v in zip( self.keys_strings, self.values_int64s)}