diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index 18725e61d..d46221a98 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -2,6 +2,7 @@ @brief test log(time=120s) """ import unittest +import pprint import warnings import sys from logging import getLogger @@ -118,6 +119,7 @@ test_qgemm0, test_qgemm1) from mlprodict.onnxrt.ops_cpu.op_constant import Constant_12, Constant_11, Constant_9 from mlprodict.onnxrt.ops_shape.shape_excs import ShapeInferenceException +from mlprodict.plotting.text_plot import onnx_simple_text_plot try: numpy_str = numpy.str_ @@ -165,7 +167,6 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): if __name__ == "__main__": - import pprint print('-----------') pprint.pprint(sparse_support) print('-----------') @@ -184,6 +185,33 @@ def test_opset_skl2onnx(self): opset_skl2onnx = __max_supported_opset__ self.assertGreater(opset_skl2onnx, opset_mlprodict) + def _check_shape_inference(self, onnx_cl, model_def): + if onnx_cl in {OnnxCastLike}: + try: + shapeinf = OnnxShapeInference(model_def) + except Exception as e: + raise AssertionError( + "Unable to infer shape for:\n%s" + "" % onnx_simple_text_plot(model_def)) from e + try: + shape_results = shapeinf.run() + except Exception as e: + raise AssertionError( + "Unable to infer shape %r in\n%r\n." % ( + e, model_def)) from e + shape = shape_results.get() + try: + self.assertIn('X', shape) + self.assertIn('Y', shape) + self.assertIn('Z', shape) + self.assertEqual(shape['X'].shape, shape['Z'].shape) + self.assertEqual(shape['Z'].dtype, shape['Y'].dtype) + except Exception as e: + raise AssertionError( + "Discrepancies in\n%s\n--ONNX--\n%s" % ( + pprint.pformat(shape), + onnx_simple_text_plot(model_def))) from e + def common_expected_shapes_types(self, oinf, inputs, got, onnx_cl, model_def, raise_shape=False): expected_types = oinf.infer_types() @@ -363,6 +391,33 @@ def common_test_onnxt_runtime_binary(self, onnx_cl, np_fct, oinfpy = OnnxInference(model_def, runtime="python", inplace=True) validate_python_inference(oinfpy, {'X': X.astype(dtype)}) + # shape + if onnx_cl not in {OnnxSum, OnnxMatMul}: + shapeinf = OnnxShapeInference(model_def) + try: + shape_results = shapeinf.run() + except Exception as e: + raise AssertionError( + "Unable to infer shape %r in\n%r\n." % ( + e, model_def)) from e + shape = shape_results.get() + self.assertIn('X', shape) + self.assertIn('Y', shape) + if onnx_cl in {OnnxSub, OnnxMul, OnnxDiv, OnnxAdd, OnnxAnd, + OnnxOr, OnnxMod, OnnxMax, OnnxMin, OnnxPow}: + self.assertEqual(shape['X'].dtype, shape['Y'].dtype) + self.assertIn(shape['Y'].shape[0], shape['X'].shape[0]) + self.assertEqual(shape['X'].shape[1], shape['Y'].shape[1]) + elif onnx_cl in {OnnxLessOrEqual, OnnxGreater, OnnxGreaterOrEqual, + OnnxLess, OnnxEqual}: + self.assertEqual(shape['X'].dtype, numpy.float32) + self.assertEqual(shape['Y'].dtype, numpy.bool_) + self.assertIn(shape['Y'].shape[0], shape['X'].shape[0]) + self.assertEqual(shape['X'].shape[1], shape['Y'].shape[1]) + else: + self.assertEqual(shape['X'].shape, shape['Y'].shape) + self.assertEqual(shape['X'].dtype, shape['Y'].dtype) + # sparse idi = make_coo_matrix(numpy.identity(2)).astype(numpy.float32) X = make_coo_matrix(numpy.array( @@ -444,6 +499,7 @@ def test_onnxt_runtime_argmax(self): model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) oinf = OnnxInference(model_def) + self._check_shape_inference(OnnxArgMax, model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) self.assertEqualArray(numpy.argmax( @@ -464,6 +520,7 @@ def test_onnxt_runtime_argmax(self): model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) oinf = OnnxInference(model_def) + self._check_shape_inference(OnnxArgMax, model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) self.assertEqualArray(numpy.argmax(X, axis=1).ravel(), @@ -478,6 +535,7 @@ def test_onnxt_runtime_argmax(self): self.assertEqual(list(sorted(got)), ['Y']) self.assertEqualArray(numpy.argmax(X, axis=1).ravel(), got['Y'].ravel()) + self._check_shape_inference(OnnxArgMax, model_def) # sparse X = make_coo_matrix(X, dtype=numpy.float32) @@ -536,6 +594,7 @@ def test_onnxt_runtime_argmin(self): op_version=opset) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(clarg, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -556,6 +615,7 @@ def test_onnxt_runtime_argmin(self): op_version=opset) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxArgMin, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -566,6 +626,7 @@ def test_onnxt_runtime_argmin(self): op_version=opset) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxArgMin, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -911,6 +972,7 @@ def test_onnxt_runtime_batch_normalization(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxBatchNormalization, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqual(list(sorted(got)), ['Y']) @@ -934,6 +996,7 @@ def test_onnxt_runtime_batch_normalization(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxBatchNormalization, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqual(list(sorted(got)), ['Y']) @@ -1044,6 +1107,7 @@ def test_onnxt_runtime_cast_out(self): model_def = onx.to_onnx( {'X': x}, outputs=[('Y', outp())], target_opset=opset) + self._check_shape_inference(OnnxCast, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) if nptp == numpy.str_: @@ -1103,6 +1167,7 @@ def test_onnxt_runtime_cast_in(self): model_def = onx.to_onnx( {'X': xi}, outputs=[('Y', StringTensorType())], target_opset=opset) + self._check_shape_inference(OnnxCast, model_def) got = OnnxInference(model_def).run({'X': xi}) self.assertEqual( xi.astype(str).tolist(), got['Y'].tolist()) @@ -1121,8 +1186,9 @@ def test_onnxt_runtime_cast_like(self): op_version=opset) model_def = onx.to_onnx( {'X': x, 'Y': y}, - outputs=[('Z', Int64TensorType())], + outputs=[('Z', Int64TensorType([None]))], target_opset=opset) + self._check_shape_inference(OnnxCastLike, model_def) got = OnnxInference(model_def).run({'X': x, 'Y': y}) self.assertEqual(x.astype(numpy.int64), got['Z']) @@ -1185,6 +1251,7 @@ def test_onnxt_runtime_compress(self): model_def = onx.to_onnx({'X': x, 'cond': cond}, outputs=[('Y', FloatTensorType())], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxCompress, model_def) exp = numpy.compress(cond, x) oinf = OnnxInference(model_def) got = oinf.run({'X': x, 'cond': cond}) @@ -1234,6 +1301,7 @@ def test_onnxt_runtime_concat(self): 'Y': Y.astype(numpy.float32)}, outputs=[('Z', FloatTensorType([2]))], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConcat, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32), 'Y': Y.astype(numpy.float32)}) @@ -1262,6 +1330,7 @@ def test_onnxt_runtime_constant_of_shape(self): model_def = onx.to_onnx({'X': x.astype(numpy.int64)}, outputs=[('Y', FloatTensorType())], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConstantOfShape, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x.astype(numpy.int64)}) self.assertEqualArray(y, got['Y']) @@ -1297,6 +1366,7 @@ def test_onnxt_runtime_conv0(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConv, model_def) for rt in ['python', 'onnxruntime1']: with self.subTest(runtime=rt): oinf = OnnxInference(model_def, runtime=rt) @@ -1376,6 +1446,7 @@ def test_onnxt_runtime_conv1(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConvTranspose, model_def) for rt in ['python', 'onnxruntime1']: with self.subTest(runtime=rt): oinf = OnnxInference(model_def, runtime=rt) @@ -1467,6 +1538,7 @@ def test_onnxt_runtime_conv_transpose(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConvTranspose, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqual(list(sorted(got)), ['Y']) @@ -1485,6 +1557,7 @@ def test_onnxt_runtime_conv_transpose_B(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x, 'W': W, 'B': B}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConvTranspose, model_def) ys = [] for rt in ['python', 'onnxruntime1']: oinf = OnnxInference(model_def, runtime=rt) @@ -1508,6 +1581,7 @@ def test_onnxt_runtime_conv_transpose_1d(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConvTranspose, model_def) oinf = OnnxInference(model_def, runtime="onnxruntime1") got = oinf.run({'X': x}) @@ -1603,6 +1677,7 @@ def test_onnxt_runtime_conv_transpose_3d(self): model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) oinf = OnnxInference(model_def) + self._check_shape_inference(OnnxConvTranspose, model_def) got = oinf.run({'X': x}) self.assertEqual(list(sorted(got)), ['Y']) self.assertEqualArray(y_with_padding, got['Y']) @@ -1727,6 +1802,7 @@ def test_onnxt_runtime_conv_transpose_dilation(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConvTranspose, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqual(list(sorted(got)), ['Y']) @@ -1760,6 +1836,7 @@ def test_onnxt_runtime_conv_transpose_pads(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxConvTranspose, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqual(list(sorted(got)), ['Y']) @@ -1783,6 +1860,7 @@ def test_onnxt_runtime_cum_sum(self): model_def = onx.to_onnx({'X': x, 'axis': axis}, outputs=[('Y', DoubleTensorType())], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxCumSum, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x.astype(numpy.float64), 'axis': axis}) @@ -1923,6 +2001,7 @@ def test_onnxt_runtime_dequantize_linear(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxDequantizeLinear, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqualArray(exp, got['Y']) @@ -1938,6 +2017,7 @@ def test_onnxt_runtime_dequantize_linear(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxDequantizeLinear, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqualArray(exp, got['Y']) @@ -1988,6 +2068,7 @@ def test_onnxt_runtime_dropout(self): outputs=[('Y', FloatTensorType()), ('Z', FloatTensorType())], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxDropout, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y', 'Z']) @@ -2059,6 +2140,7 @@ def test_onnxt_runtime_eyelike(self): model_def = onx.to_onnx({'X': X.astype(numpy.int64)}, target_opset=get_opset_number_from_onnx(), outputs=[('Y', FloatTensorType())]) + self._check_shape_inference(OnnxEyeLike, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -2095,6 +2177,7 @@ def test_onnxt_runtime_flatten(self): model_def = node.to_onnx( {'X': x}, outputs=[('Y', FloatTensorType())], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxFlatten, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) new_shape = ((1, -1) if i == 0 @@ -2125,6 +2208,7 @@ def test_onnxt_runtime_gather_elements0(self): model_def = onx.to_onnx({'X': data, 'Y': indices}, outputs=[('Z', FloatTensorType())], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxGatherElements, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': data, 'Y': indices}) self.assertEqual(got['Z'].size, 0) @@ -2185,6 +2269,7 @@ def test_onnxt_runtime_gather_elements(self): model_def = onx.to_onnx({'X': data, 'Y': indices}, outputs=[('Z', FloatTensorType())], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxGatherElements, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': data, 'Y': indices}) exp = numpy.array([[4, 8, 3], @@ -2241,6 +2326,7 @@ def do_test_onnxt_runtime_gemm(self, runtime): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': idi.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxGemm, model_def) model_def.ir_version = get_ir_version_from_onnx() oinf = OnnxInference(model_def, runtime=runtime) got = oinf.run({'X': X.astype(numpy.float32)}) @@ -2263,6 +2349,7 @@ def do_test_onnxt_runtime_gemm(self, runtime): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': idi.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxGemm, model_def) if 'onnxruntime' in runtime: model_def.ir_version = get_ir_version_from_onnx() oinf = OnnxInference(model_def, runtime=runtime) @@ -2294,6 +2381,7 @@ def test_onnxt_runtime_global_average_pool(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxGlobalAveragePool, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqual(list(sorted(got)), ['Y']) @@ -2380,6 +2468,7 @@ def test_onnxt_runtime_lp_normalization(self): X = numpy.array([[1, 2], [3, -4]], dtype=numpy.float32) model_def = onx.to_onnx({'X': X}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxLpNormalization, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) exp = numpy.array([[0.3162278, 0.4472136], @@ -2410,6 +2499,7 @@ def test_onnxt_runtime_max_pool_1d_default(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx( {'X': X}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxMaxPool, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqualArray(exp, got['Y']) @@ -2454,6 +2544,7 @@ def test_onnxt_runtime_max_pool_2d(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx( {'X': X}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxMaxPool, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqualArray(exp, got['Y']) @@ -2488,6 +2579,7 @@ def test_onnxt_runtime_max_pool_2d(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx( {'X': X}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxMaxPool, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqualArray(exp, got['Y']) @@ -2561,6 +2653,7 @@ def test_onnxt_runtime_mean(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': idi.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceMean, model_def) X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float64) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) @@ -2613,6 +2706,7 @@ def test_onnxt_runtime_pad(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'data': data, 'pads': pads}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxPad, model_def) oinf = OnnxInference(model_def) got = oinf.run({'data': data, 'pads': pads}) self.assertEqualArray(exp, got['Y']) @@ -2632,6 +2726,7 @@ def test_onnxt_runtime_pad(self): mode='reflect', op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'data': data, 'pads': pads}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxPad, model_def) oinf = OnnxInference(model_def) got = oinf.run({'data': data, 'pads': pads}) self.assertEqualArray(exp, got['Y']) @@ -2728,6 +2823,7 @@ def test_onnxt_runtime_qlinear_conv(self): 'y_scale': y_scale, 'y_zero_point': y_zero_point} model_def = node.to_onnx(inputs, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxQLinearConv, model_def) oinf = OnnxInference(model_def) got = oinf.run(inputs) self.assertEqualArray(output, got['y']) @@ -2931,6 +3027,7 @@ def test_onnxt_runtime_quantize_linear(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxQuantizeLinear, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqualArray(exp, got['Y']) @@ -2961,6 +3058,7 @@ def test_onnxt_runtime_range(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'starts': starts, 'ends': ends}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxRange, model_def) oinf = OnnxInference(model_def) exp = numpy.array([0, 4, 8], dtype=numpy.float32) got = oinf.run({'starts': starts, 'ends': ends}) @@ -2984,6 +3082,7 @@ def reduce_l1(x, axis, keepdims): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceL1, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3007,6 +3106,7 @@ def reduce_l1(x, axis, keepdims): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceL1, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3025,6 +3125,7 @@ def reduce_l2(x, axis, keepdims): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceL2, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3038,6 +3139,7 @@ def reduce_l2(x, axis, keepdims): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceL2, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3048,6 +3150,7 @@ def reduce_l2(x, axis, keepdims): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceL2, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3117,6 +3220,7 @@ def test_onnxt_runtime_reduce_max(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceMax, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3130,6 +3234,7 @@ def test_onnxt_runtime_reduce_max(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceMax, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3155,6 +3260,7 @@ def test_onnxt_runtime_reduce_mean(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceMean, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3167,6 +3273,7 @@ def test_onnxt_runtime_reduce_mean(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceMean, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3177,6 +3284,7 @@ def test_onnxt_runtime_reduce_mean(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceMean, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3272,6 +3380,7 @@ def test_onnxt_runtime_reduce_sum(self): op_version=opset) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxReduceSum, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3289,6 +3398,7 @@ def test_onnxt_runtime_reduce_sum(self): op_version=opset) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxReduceSum, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3344,6 +3454,7 @@ def test_onnxt_runtime_reduce_sum_noop_with_empty_axes(self): op_version=opset, noop_with_empty_axes=1) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxReduceSum, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3376,6 +3487,7 @@ def test_onnxt_runtime_reduce_sum_square(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReduceSumSquare, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X.astype(numpy.float32)}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3453,6 +3565,7 @@ def test_onnxt_runtime_reshape(self): X = numpy.array([[1, 2], [3, -4]], dtype=numpy.float32) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxReshape, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3566,6 +3679,7 @@ def test_onnxt_runtime_shape(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxShape, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqualArray(y, got['Y']) @@ -3598,6 +3712,7 @@ def test_onnxt_runtime_size(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxSize, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqualArray(y, got['Y']) @@ -3628,6 +3743,7 @@ def test_onnxt_runtime_slice(self): output_names=['Y'], op_version=opset) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxSlice, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqualArray(y, got['Y']) @@ -3648,6 +3764,7 @@ def test_onnxt_runtime_slice(self): op_version=opset) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxSlice, model_def) got = OnnxInference(model_def).run({'X': x}) self.assertEqualArray(y, got['Y']) @@ -3665,6 +3782,7 @@ def test_onnxt_runtime_slice(self): op_version=opset) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxSlice, model_def) got = OnnxInference(model_def).run({'X': x}) self.assertEqualArray(y, got['Y']) @@ -3772,6 +3890,7 @@ def test_onnxt_runtime_squeeze(self): 'X', axes=[1], output_names=['Y'], op_version=opset) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxSqueeze, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqualArray(y, got['Y']) @@ -3785,6 +3904,7 @@ def test_onnxt_runtime_squeeze(self): 'X', axes=[0], output_names=['Y'], op_version=opset) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxSqueeze, model_def) got = OnnxInference(model_def).run({'X': x}) self.assertEqualArray(y, got['Y']) python_tested.append(OnnxSqueeze) @@ -3824,6 +3944,7 @@ def test_onnxt_runtime_topk0(self): outputs=[('Y', FloatTensorType(X.shape)), ('Yi', Int64TensorType(X.shape))], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxTopK, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y', 'Yi']) @@ -3869,6 +3990,7 @@ def test_onnxt_runtime_topk(self): outputs=[('Y', FloatTensorType(X.shape)), ('Yi', Int64TensorType(X.shape))], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxTopK, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y', 'Yi']) @@ -3908,6 +4030,7 @@ def test_onnxt_runtime_topk2(self): outputs=[('Y', FloatTensorType(X.shape)), ('Yi', Int64TensorType(X.shape))], target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxTopK, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y', 'Yi']) @@ -3929,6 +4052,7 @@ def test_onnxt_runtime_transpose(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxTranspose, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3945,6 +4069,7 @@ def test_onnxt_runtime_transpose(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': X.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) + self._check_shape_inference(OnnxTranspose, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': X}) self.assertEqual(list(sorted(got)), ['Y']) @@ -3964,6 +4089,7 @@ def test_onnxt_runtime_unsqueeze(self): 'X', axes=[-2], output_names=['Y'], op_version=opset) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxUnsqueeze, model_def) oinf = OnnxInference(model_def) got = oinf.run({'X': x}) self.assertEqualArray(y, got['Y']) @@ -3978,6 +4104,7 @@ def test_onnxt_runtime_unsqueeze(self): 'X', axes=[2, 4, 5], output_names=['Y'], op_version=opset) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=opset) + self._check_shape_inference(OnnxUnsqueeze, model_def) got = OnnxInference(model_def).run({'X': x}) self.assertEqualArray(y, got['Y']) python_tested.append(OnnxUnsqueeze) @@ -4359,5 +4486,5 @@ def test_op_constant(self): if __name__ == "__main__": # Working - # TestOnnxrtPythonRuntime().test_onnxt_runtime_average_pool() + # TestOnnxrtPythonRuntime().test_onnxt_runtime_sub() unittest.main(verbosity=2) diff --git a/mlprodict/onnxrt/ops_shape/__init__.py b/mlprodict/onnxrt/ops_shape/__init__.py index 3e70ff0f1..aa8fd0da7 100644 --- a/mlprodict/onnxrt/ops_shape/__init__.py +++ b/mlprodict/onnxrt/ops_shape/__init__.py @@ -5,14 +5,23 @@ from ._element_unary import ( shape_abs, shape_acos, shape_acosh, shape_asin, shape_asinh, shape_atan, shape_atanh, - shape_ceil, shape_celu, + shape_castlike, shape_ceil, shape_celu, shape_clip, shape_cos, shape_cosh, shape_erf, shape_exp, shape_floor, shape_identity, shape_isnan, shape_leakyrelu, shape_log, shape_neg, shape_not, shape_reciprocal, shape_relu, shape_round, shape_sigmoid, shape_sign, shape_sin, shape_sinh, shape_softmax, shape_sqrt, shape_tan, shape_tanh) -from ._element_wise import shape_add, shape_mul, shape_div, shape_sub +from ._element_wise import ( + shape_add, shape_and, + shape_div, + shape_equal, + shape_greater, shape_greaterorequal, + shape_less, shape_lessorequal, + shape_max, shape_min, shape_mod, shape_mul, + shape_or, + shape_pow, + shape_sub) from ._op_shape_op import shape_det diff --git a/mlprodict/onnxrt/ops_shape/_element_unary.py b/mlprodict/onnxrt/ops_shape/_element_unary.py index 59db37c65..21a888d1f 100644 --- a/mlprodict/onnxrt/ops_shape/_element_unary.py +++ b/mlprodict/onnxrt/ops_shape/_element_unary.py @@ -64,6 +64,21 @@ def shape_atanh(known_shapes, node): return _element_unary(known_shapes, node) +def shape_castlike(known_shapes, node): + "Infers shape for operator CastLike." + x = known_shapes[node.input[0]] + if x.mtype != OnnxKind.Tensor: + raise ShapeInferenceException( # pragma: no cover + "Result %r must be a tensor." % x) + y = known_shapes[node.input[1]] + if y.mtype != OnnxKind.Tensor: + raise ShapeInferenceException( # pragma: no cover + "Result %r must be a tensor." % y) + cp = x.copy() + cp.dtype = y.dtype + return known_shapes.update(node.output[0], cp) + + def shape_ceil(known_shapes, node): "Infers shape for operator Ceil." return _element_unary(known_shapes, node) diff --git a/mlprodict/onnxrt/ops_shape/_element_wise.py b/mlprodict/onnxrt/ops_shape/_element_wise.py index 4a1539e16..6dcc985aa 100644 --- a/mlprodict/onnxrt/ops_shape/_element_wise.py +++ b/mlprodict/onnxrt/ops_shape/_element_wise.py @@ -32,8 +32,8 @@ def shape_add(known_shapes, node): return _element_wise(known_shapes, node) -def shape_sub(known_shapes, node): - "Infers shape for operator Sub." +def shape_and(known_shapes, node): + "Infers shape for operator And." return _element_wise(known_shapes, node) @@ -42,6 +42,61 @@ def shape_div(known_shapes, node): return _element_wise(known_shapes, node) +def shape_equal(known_shapes, node): + "Infers shape for operator Equal." + return _element_wise(known_shapes, node) + + +def shape_greater(known_shapes, node): + "Infers shape for operator Greater." + return _element_wise(known_shapes, node) + + +def shape_greaterorequal(known_shapes, node): + "Infers shape for operator GreaterOrEqual." + return _element_wise(known_shapes, node) + + +def shape_less(known_shapes, node): + "Infers shape for operator Less." + return _element_wise(known_shapes, node) + + +def shape_lessorequal(known_shapes, node): + "Infers shape for operator LessOrEqual." + return _element_wise(known_shapes, node) + + +def shape_max(known_shapes, node): + "Infers shape for operator Max." + return _element_wise(known_shapes, node) + + +def shape_min(known_shapes, node): + "Infers shape for operator Min." + return _element_wise(known_shapes, node) + + +def shape_mod(known_shapes, node): + "Infers shape for operator Mod." + return _element_wise(known_shapes, node) + + def shape_mul(known_shapes, node): "Infers shape for operator Mul." return _element_wise(known_shapes, node) + + +def shape_or(known_shapes, node): + "Infers shape for operator Or." + return _element_wise(known_shapes, node) + + +def shape_pow(known_shapes, node): + "Infers shape for operator Pow." + return _element_wise(known_shapes, node) + + +def shape_sub(known_shapes, node): + "Infers shape for operator Sub." + return _element_wise(known_shapes, node)