diff --git a/_doc/sphinxdoc/source/api/tools.rst b/_doc/sphinxdoc/source/api/tools.rst
index de4bc88b7..ac51a43a6 100644
--- a/_doc/sphinxdoc/source/api/tools.rst
+++ b/_doc/sphinxdoc/source/api/tools.rst
@@ -126,7 +126,7 @@ the possibility later to only show a part of a graph.
 
 **benchmark**
 
-.. autosignature:: mlprodict.plotting.validate_graph.plot_validate_benchmark
+.. autosignature:: mlprodict.plotting.plot_validate_benchmark
 
 Others
 ======
diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
index 06bbbf60e..0f0936619 100644
--- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
+++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
@@ -31,7 +31,7 @@
     OnnxBatchNormalization,
     OnnxAcos, OnnxAcosh, OnnxAsin, OnnxAsinh,
     OnnxAtan, OnnxAtanh, OnnxAveragePool,
-    OnnxCast, OnnxCeil, OnnxClip,
+    OnnxCast, OnnxCastLike, OnnxCeil, OnnxClip,
     OnnxCompress,
     OnnxConcat, OnnxConv, OnnxConvTranspose,
     OnnxConstant, OnnxConstant_9, OnnxConstant_11,
@@ -1083,6 +1083,25 @@ def test_onnxt_runtime_cast_in(self):
 
         python_tested.append(OnnxCast)
 
+    @wraplog()
+    def test_onnxt_runtime_cast_like(self):
+        x = numpy.array([1.5, 2.1, 3.1, 4.1]).astype(
+            numpy.float32)  # pylint: disable=E1101
+        y = numpy.array([1.]).astype(numpy.int64)  # pylint: disable=E1101
+
+        for opset in range(15, get_opset_number_from_onnx() + 1):
+            with self.subTest(opset=opset):
+                onx = OnnxCastLike('X', 'Y', output_names=['Z'],
+                                   op_version=opset)
+                model_def = onx.to_onnx(
+                    {'X': x, 'Y': y},
+                    outputs=[('Z', Int64TensorType())],
+                    target_opset=opset)
+                got = OnnxInference(model_def).run({'X': x, 'Y': y})
+                self.assertEqualArray(x.astype(numpy.int64), got['Z'])
+
+        python_tested.append(OnnxCastLike)
+
     @wraplog()
     def test_onnxt_runtime_ceil(self):
         self.common_test_onnxt_runtime_unary(OnnxCeil, numpy.ceil)
diff --git a/mlprodict/onnxrt/ops_cpu/_op_list.py b/mlprodict/onnxrt/ops_cpu/_op_list.py
index ff85fe9f2..f2b4d19f2 100644
--- a/mlprodict/onnxrt/ops_cpu/_op_list.py
+++ b/mlprodict/onnxrt/ops_cpu/_op_list.py
@@ -22,7 +22,7 @@
 from .op_batch_normalization import BatchNormalization, BatchNormalization_14
 from .op_binarizer import Binarizer
 from .op_broadcast_gradient_args import BroadcastGradientArgs
-from .op_cast import Cast
+from .op_cast import Cast, CastLike
 from .op_cdist import CDist
 from .op_ceil import Ceil
 from .op_celu import Celu
diff --git a/mlprodict/onnxrt/ops_cpu/op_cast.py b/mlprodict/onnxrt/ops_cpu/op_cast.py
index 2566302bc..0ce03ab9a 100644
--- a/mlprodict/onnxrt/ops_cpu/op_cast.py
+++ b/mlprodict/onnxrt/ops_cpu/op_cast.py
@@ -73,3 +73,29 @@ def _infer_types(self, x):  # pylint: disable=W0221
     def _infer_sizes(self, *args, **kwargs):
         res = self.run(*args, **kwargs)
         return (dict(temp=0), ) + res
+
+
+class CastLike(OpRun):
+
+    def __init__(self, onnx_node, desc=None, **options):
+        OpRun.__init__(self, onnx_node, desc=desc, **options)
+
+    def _run(self, x, y):  # pylint: disable=W0221
+        if self.inplaces.get(0, False):
+            return self._run_inplace(x, y)
+        return (x.astype(y.dtype), )
+
+    def _run_inplace(self, x, y):
+        if x.dtype == y.dtype:
+            return (x, )
+        return (x.astype(y.dtype), )
+
+    def _infer_shapes(self, x, y):  # pylint: disable=W0221
+        return (x.copy(dtype=y._dtype), )
+
+    def _infer_types(self, x, y):  # pylint: disable=W0221
+        return (y._dtype, )
+
+    def _infer_sizes(self, *args, **kwargs):
+        res = self.run(*args, **kwargs)
+        return (dict(temp=0), ) + res
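
A short usage sketch (not part of the patch) mirroring the new unit test: it builds a single CastLike node with skl2onnx's operator wrappers and runs it through mlprodict's OnnxInference, which dispatches to the CastLike class added above. The import paths follow the conventions visible in the test file; opset 15 is used because CastLike first appears in that opset.

# Usage sketch, assuming skl2onnx exposes OnnxCastLike for opset >= 15
# (same assumption as the unit test added in this patch).
import numpy
from skl2onnx.algebra.onnx_ops import OnnxCastLike  # pylint: disable=E0611
from skl2onnx.common.data_types import Int64TensorType
from mlprodict.onnxrt import OnnxInference

x = numpy.array([1.5, 2.1, 3.1, 4.1], dtype=numpy.float32)
y = numpy.array([1], dtype=numpy.int64)

# Single-node graph: Z = CastLike(X, Y), i.e. cast X to the element type of Y.
onx = OnnxCastLike('X', 'Y', output_names=['Z'], op_version=15)
model_def = onx.to_onnx({'X': x, 'Y': y},
                        outputs=[('Z', Int64TensorType())],
                        target_opset=15)

# The default runtime is the pure-python one implemented by ops_cpu.
oinf = OnnxInference(model_def)
print(oinf.run({'X': x, 'Y': y})['Z'])  # expected: [1 2 3 4] with dtype int64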