diff --git a/_doc/examples/plot_experimental.py b/_doc/examples/plot_experimental.py new file mode 100644 index 000000000..07da936b1 --- /dev/null +++ b/_doc/examples/plot_experimental.py @@ -0,0 +1,94 @@ +""" +.. _l-example-experimental: + +Compares implementation of stanard function +=========================================== + +The following function benchmark different implementation +of standard function. + + + +.. contents:: + :local: + +Einsum +++++++ +""" +import numpy +import pandas +import matplotlib.pyplot as plt +from tqdm import tqdm +from cpyquickhelper.numbers.speed_measure import measure_time +from mlprodict.testing.experimental_c import custom_einsum_double +from onnxruntime import InferenceSession +from skl2onnx.algebra.onnx_ops import OnnxEinsum +from skl2onnx.common.data_types import DoubleTensorType +import onnx + + +def build_ort_einsum(equation, op_version=12): + node = OnnxEinsum('x', 'y', equation=equation, + op_version=op_version, + output_names=['z']) + onx = node.to_onnx(inputs=[('x', DoubleTensorType()), ('y', DoubleTensorType())], + target_opset=op_version) + sess = InferenceSession(onx.SerializeToString()) + return lambda x, y: sess.run(None, {'x': x, 'y': y}) + + +equation = "bsnh,btnh->bnts" +ort_einsum = build_ort_einsum(equation) +res = [] +for dim in tqdm([8, 16, 32, 64, 128, 256, 512]): + x = numpy.random.rand(1, dim, 12, 64) + y = numpy.random.rand(1, dim, 12, 64) + + ort_einsum(x, y) + + ctx = dict(equation=equation, x=x, y=y, einsum=numpy.einsum) + obs = measure_time("einsum(equation, x, y)", div_by_number=True, context=ctx, + repeat=5, number=5) + obs['dim'] = dim + obs['fct'] = 'numpy.einsum' + res.append(obs) + + ctx['einsum'] = ort_einsum + obs = measure_time("einsum(x, y)", div_by_number=True, context=ctx, + repeat=5, number=5) + obs['dim'] = dim + obs['fct'] = 'ort_einsum' + res.append(obs) + + ctx['einsum'] = custom_einsum_double + obs = measure_time("einsum(equation, x, y)", div_by_number=True, context=ctx, + repeat=5, number=5) + obs['dim'] = dim + obs['fct'] = 'custom_einsum_double' + res.append(obs) + +df = pandas.DataFrame(res) +df +print(df.T) + +########################################### +# Pivot + +piv = df.pivot('dim', 'fct', 'average') +piv + +########################################### +# Ratios + +rs = piv.copy() +rs['custom_einsum_double'] = rs['numpy.einsum'] / rs['custom_einsum_double'] +rs['ort_einsum'] = rs['numpy.einsum'] / rs['ort_einsum'] +rs['numpy.einsum'] = 1. +rs + +########################################### +# Graphs. +fig, ax = plt.subplots(1, 2, figsize=(12, 4)) +piv.plot(logx=True, logy=True, ax=ax[0], title="Einsum benchmark") +rs.plot(logx=True, ax=ax[1], title="Einsum Speedup, baseline=numpy") +plt.show() diff --git a/_doc/sphinxdoc/source/api/index.rst b/_doc/sphinxdoc/source/api/index.rst index e8b0bb656..bac2a3976 100644 --- a/_doc/sphinxdoc/source/api/index.rst +++ b/_doc/sphinxdoc/source/api/index.rst @@ -13,4 +13,5 @@ API sklapi asv validation + testing tools diff --git a/_doc/sphinxdoc/source/api/testing.rst b/_doc/sphinxdoc/source/api/testing.rst new file mode 100644 index 000000000..15cb93200 --- /dev/null +++ b/_doc/sphinxdoc/source/api/testing.rst @@ -0,0 +1,17 @@ + +testing +======= + +.. contents:: + :local: + +Experimental +++++++++++++ + +Experimental implementations for algorithm. + +.. autosignature:: mlprodict.testing.experimental.custom_einsum + +.. autosignature:: mlprodict.testing.experimental_c.custom_einsum_double + +.. autosignature:: mlprodict.testing.experimental.custom_pad diff --git a/_unittests/ut__skl2onnx/test_sklearn_cast_transformer.py b/_unittests/ut__skl2onnx/test_sklearn_cast_transformer.py index b9d21cdf8..d0d97f850 100644 --- a/_unittests/ut__skl2onnx/test_sklearn_cast_transformer.py +++ b/_unittests/ut__skl2onnx/test_sklearn_cast_transformer.py @@ -4,12 +4,12 @@ import unittest import math import numpy +from onnxruntime import InferenceSession from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler from sklearn.tree import DecisionTreeRegressor -from onnxruntime import InferenceSession from skl2onnx.sklapi import CastTransformer from skl2onnx import convert_sklearn, to_onnx from skl2onnx.common.data_types import ( diff --git a/_unittests/ut__skl2onnx/test_sklearn_gaussian_mixture_converter.py b/_unittests/ut__skl2onnx/test_sklearn_gaussian_mixture_converter.py index 8fbf4361c..2cbeb60b1 100644 --- a/_unittests/ut__skl2onnx/test_sklearn_gaussian_mixture_converter.py +++ b/_unittests/ut__skl2onnx/test_sklearn_gaussian_mixture_converter.py @@ -3,10 +3,10 @@ """ import unittest import numpy as np -from sklearn.datasets import load_iris -from sklearn.mixture import GaussianMixture, BayesianGaussianMixture from onnxruntime import InferenceSession from onnxruntime.capi.onnxruntime_pybind11_state import Fail as OrtFail # pylint: disable=E0611 +from sklearn.datasets import load_iris +from sklearn.mixture import GaussianMixture, BayesianGaussianMixture from skl2onnx import convert_sklearn, to_onnx from skl2onnx.common.data_types import FloatTensorType from mlprodict.testing.test_utils import dump_data_and_model, TARGET_OPSET diff --git a/_unittests/ut__skl2onnx/test_sklearn_gaussian_process.py b/_unittests/ut__skl2onnx/test_sklearn_gaussian_process.py index cc84c397d..f9e7f1dd7 100644 --- a/_unittests/ut__skl2onnx/test_sklearn_gaussian_process.py +++ b/_unittests/ut__skl2onnx/test_sklearn_gaussian_process.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd from numpy.testing import assert_almost_equal +from onnxruntime import __version__ as ort_version from sklearn.datasets import load_iris from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import ( @@ -18,7 +19,6 @@ from pyquickhelper.texthelper import compare_module_version from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType from skl2onnx import to_onnx, __version__ as skl2_vers -from onnxruntime import __version__ as ort_version from mlprodict.onnxrt import OnnxInference from mlprodict.testing.test_utils import ( dump_data_and_model, fit_regression_model, TARGET_OPSET) diff --git a/_unittests/ut__skl2onnx/test_sklearn_pipeline.py b/_unittests/ut__skl2onnx/test_sklearn_pipeline.py index 48e7b2008..4b8dce6bf 100644 --- a/_unittests/ut__skl2onnx/test_sklearn_pipeline.py +++ b/_unittests/ut__skl2onnx/test_sklearn_pipeline.py @@ -8,6 +8,7 @@ import numpy from numpy.testing import assert_almost_equal import pandas +from onnxruntime import __version__ as ort_version, InferenceSession from sklearn import __version__ as sklearn_version from sklearn import datasets from sklearn.compose import ColumnTransformer @@ -23,7 +24,6 @@ from skl2onnx import convert_sklearn from skl2onnx.common.data_types import ( FloatTensorType, Int64TensorType, StringTensorType) -from onnxruntime import __version__ as ort_version, InferenceSession from mlprodict.testing.test_utils import ( dump_data_and_model, fit_classification_model) diff --git a/_unittests/ut_documentation/test_run_notebooks_onnx_sbs.py b/_unittests/ut_documentation/test_run_notebooks_onnx_sbs.py index d6cb968a0..eaa1e0e3f 100644 --- a/_unittests/ut_documentation/test_run_notebooks_onnx_sbs.py +++ b/_unittests/ut_documentation/test_run_notebooks_onnx_sbs.py @@ -4,6 +4,7 @@ """ import os import unittest +from onnxruntime import __version__ as ort_version from sklearn.exceptions import ConvergenceWarning try: from sklearn.utils._testing import ignore_warnings @@ -16,7 +17,6 @@ add_missing_development_version, ExtTestCase ) from skl2onnx import __version__ as skl2onnx_version -from onnxruntime import __version__ as ort_version import mlprodict diff --git a/_unittests/ut_onnx_conv/test_onnx_conv_knn.py b/_unittests/ut_onnx_conv/test_onnx_conv_knn.py index cc3de40b4..9b6370911 100644 --- a/_unittests/ut_onnx_conv/test_onnx_conv_knn.py +++ b/_unittests/ut_onnx_conv/test_onnx_conv_knn.py @@ -7,6 +7,7 @@ import numpy from pandas import DataFrame from scipy.spatial.distance import cdist as scipy_cdist +from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument as OrtInvalidArgument # pylint: disable=E0611 from pyquickhelper.pycode import ExtTestCase from sklearn.calibration import CalibratedClassifierCV from sklearn.datasets import load_iris, make_regression @@ -24,7 +25,6 @@ from skl2onnx.common.data_types import Int64TensorType import skl2onnx from skl2onnx.algebra.complex_functions import onnx_cdist -from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument as OrtInvalidArgument # pylint: disable=E0611 from mlprodict.onnx_conv import ( register_converters, to_onnx) from mlprodict.onnxrt import OnnxInference diff --git a/_unittests/ut_onnxrt/test_onnxrt_side_by_side.py b/_unittests/ut_onnxrt/test_onnxrt_side_by_side.py index f914ed0ba..513e55aac 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_side_by_side.py +++ b/_unittests/ut_onnxrt/test_onnxrt_side_by_side.py @@ -6,10 +6,10 @@ from logging import getLogger import numpy import pandas +from onnxruntime import __version__ as ort_version from sklearn.gaussian_process.kernels import RBF, ConstantKernel as CK, Sum from pyquickhelper.pycode import ExtTestCase from pyquickhelper.texthelper.version_helper import compare_module_version -from onnxruntime import __version__ as ort_version from skl2onnx.common.data_types import FloatTensorType try: from skl2onnx.operator_converters.gaussian_process import convert_kernel diff --git a/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py b/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py index b27260359..53c60bb44 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py +++ b/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py @@ -5,9 +5,9 @@ import unittest import numpy import onnx +from onnxruntime import InferenceSession from pyquickhelper.pycode import ExtTestCase from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxMatMul # pylint: disable=E0611 -from onnxruntime import InferenceSession from mlprodict.onnxrt import OnnxInference from mlprodict.tools import get_opset_number_from_onnx diff --git a/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort.py b/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort.py index 5cab10209..cc7e052a8 100644 --- a/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort.py +++ b/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort.py @@ -4,6 +4,7 @@ import unittest from logging import getLogger import numpy +from onnxruntime import __version__ as ort_version from pyquickhelper.loghelper import fLOG from pyquickhelper.pycode import ExtTestCase, skipif_circleci from pyquickhelper.texthelper.version_helper import compare_module_version @@ -16,7 +17,6 @@ from sklearn.gaussian_process.kernels import RBF, ExpSineSquared from skl2onnx import __version__ as skl2onnx_version from skl2onnx.common.data_types import FloatTensorType -from onnxruntime import __version__ as ort_version from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets from mlprodict.onnxrt import OnnxInference from mlprodict.tools.asv_options_helper import get_ir_version_from_onnx diff --git a/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort2.py b/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort2.py index 29480b9b4..e45820a86 100644 --- a/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort2.py +++ b/_unittests/ut_onnxrt/test_rt_valid_model_gaussian_process_ort2.py @@ -3,6 +3,7 @@ """ import unittest from logging import getLogger +from onnxruntime import __version__ as ort_version from pyquickhelper.loghelper import fLOG from pyquickhelper.pycode import ExtTestCase, skipif_circleci from pyquickhelper.texthelper.version_helper import compare_module_version @@ -12,7 +13,6 @@ except ImportError: from sklearn.utils.testing import ignore_warnings from skl2onnx import __version__ as skl2onnx_version -from onnxruntime import __version__ as ort_version from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets diff --git a/_unittests/ut_testing/test_experimental.py b/_unittests/ut_testing/test_experimental.py index f066c0d81..b086d8499 100644 --- a/_unittests/ut_testing/test_experimental.py +++ b/_unittests/ut_testing/test_experimental.py @@ -5,14 +5,16 @@ import numpy from onnx import helper, TensorProto from onnxruntime import InferenceSession -from pyquickhelper.pycode import ExtTestCase -from mlprodict.testing.experimental import custom_pad +from pyquickhelper.pycode import ExtTestCase, is_travis_or_appveyor +from mlprodict.testing.experimental import custom_pad, custom_einsum +from mlprodict.testing.experimental_c import ( # pylint: disable=E0611 + custom_einsum_double, custom_einsum_int64) from mlprodict.tools import get_opset_number_from_onnx class TestExperimental(ExtTestCase): - def ort_path(self, x, pads): + def ort_path_pad(self, x, pads): pads = list(pads[:, 0]) + list(pads[:, 1]) X = helper.make_tensor_value_info( 'X', TensorProto.FLOAT, x.shape) # pylint: disable=E1101 @@ -86,36 +88,36 @@ def test_experimental_pad_552(self): arr = numpy.random.rand(5, 5, 2).astype(numpy.float32) paddings = numpy.array([1, 1, 1, 1, 1, 1]).reshape((-1, 2)) - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) def test_experimental_pad_positive_ort(self): arr = (numpy.arange(6) + 10).astype(numpy.float32) paddings = numpy.array([1, 1]).reshape((-1, 2)) * 2 - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) arr = (numpy.arange(6) + 10).astype(numpy.float32) paddings = numpy.array([1, 1]).reshape((-1, 2)) - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) arr = (numpy.arange(6).reshape((2, -1)) + 10).astype(numpy.float32) paddings = numpy.array([1, 1, 1, 1]).reshape((-1, 2)) * 2 - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) arr = (numpy.arange(6).reshape((2, -1)) + 10).astype(numpy.float32) paddings = numpy.array([1, 1, 2, 2]).reshape((-1, 2)) - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) arr = (numpy.arange(6).reshape((2, -1)) + 10).astype(numpy.float32) paddings = numpy.array([1, 1, 1, 1]).reshape((-1, 2)) - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) arr = (numpy.arange(6).reshape((1, 2, -1)) + 10).astype(numpy.float32) paddings = numpy.array([1, 1, 1, 1, 1, 1]).reshape((-1, 2)) - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) arr = (numpy.arange(6).reshape((1, 2, -1)) + 10).astype(numpy.float32) paddings = numpy.array([1, 1, 1, 1, 1, 1]).reshape((-1, 2)) * 2 - self.fct_test(custom_pad, self.ort_path, arr, paddings) + self.fct_test(custom_pad, self.ort_path_pad, arr, paddings) def test_experimental_pad_negative(self): arr = numpy.arange(6) + 10 @@ -123,6 +125,80 @@ def test_experimental_pad_negative(self): self.assertRaise(lambda: custom_pad( arr, paddings), NotImplementedError) + def test_experimental_einsum(self): + eq = "bsnh,btnh->bnts" + + x = numpy.arange(8).reshape((1, 2, 2, 2)) + y = numpy.arange(8).reshape((1, 2, 2, 2)) + 100 + ein = numpy.einsum(eq, x, y) + ein2 = custom_einsum(eq, x, y) + self.assertEqual(ein.shape, ein2.shape) + self.assertEqualArray(ein, ein2) + + x = numpy.random.rand(1, 8, 3, 5) + y = numpy.random.rand(1, 8, 3, 5) + bady1 = numpy.random.rand(2, 8, 3, 5) + bady2 = numpy.random.rand(1, 8, 3, 6) + ein = numpy.einsum(eq, x, y) + self.assertRaise(lambda: custom_einsum( + eq, x.astype(int), y), RuntimeError) + self.assertRaise(lambda: custom_einsum( + "bsnhj,btnh->bnts", x, y), ValueError) + self.assertRaise(lambda: custom_einsum( + "bsnh,btnhj->bnts", x, y), ValueError) + self.assertRaise(lambda: custom_einsum(eq, x, bady1), ValueError) + self.assertRaise(lambda: custom_einsum(eq, x, bady2), ValueError) + self.assertRaise(lambda: custom_einsum(eq, bady1, x), ValueError) + self.assertRaise(lambda: custom_einsum(eq, bady2, x), ValueError) + self.assertRaise( + lambda: custom_einsum( + "bsnhv,btnhv->bnts", numpy.random.rand(1, 8, 3, 5, 2), + numpy.random.rand(1, 8, 3, 5, 2)), NotImplementedError) + ein2 = custom_einsum(eq, x, y) + self.assertEqual(ein.shape, ein2.shape) + self.assertEqualArray(ein, ein2) + + def is_ci_win(self): + return is_travis_or_appveyor() == "appveyor" + + def test_experimental_einsum_c(self): + eq = "bsnh,btnh->bnts" + + x = numpy.arange(8).reshape((1, 2, 2, 2)).astype(numpy.int64) + y = (numpy.arange(8).reshape((1, 2, 2, 2)) + 100).astype(numpy.int64) + ein = numpy.einsum(eq, x, y) + ein2 = custom_einsum_int64(eq, x, y) + self.assertEqual(ein.shape, ein2.shape) + self.assertEqualArray(ein, ein2) + + x = numpy.random.rand(1, 8, 3, 5) + y = numpy.random.rand(1, 8, 3, 5) + bady1 = numpy.random.rand(2, 8, 3, 5) + bady2 = numpy.random.rand(1, 8, 3, 6) + ein = numpy.einsum(eq, x, y) + if not self.is_ci_win(): + # It crashes on appveyor. + self.assertRaise(lambda: custom_einsum_double( + "bsnhj,btnh->bnts", x, y), RuntimeError) + self.assertRaise(lambda: custom_einsum_double( + "bsnh,btnhj->bnts", x, y), RuntimeError) + self.assertRaise(lambda: custom_einsum_double( + eq, x, bady1), RuntimeError) + self.assertRaise(lambda: custom_einsum_double( + eq, x, bady2), RuntimeError) + self.assertRaise(lambda: custom_einsum_double( + eq, bady1, x), RuntimeError) + self.assertRaise(lambda: custom_einsum_double( + eq, bady2, x), RuntimeError) + self.assertRaise( + lambda: custom_einsum_double( + "bsnhv,btnhv->bnts", numpy.random.rand(1, 8, 3, 5, 2), + numpy.random.rand(1, 8, 3, 5, 2)), RuntimeError) + ein2 = custom_einsum_double(eq, x, y) + self.assertEqual(ein.shape, ein2.shape) + self.assertEqualArray(ein, ein2) + if __name__ == "__main__": + # TestExperimental().test_experimental_einsum_c() unittest.main() diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp index ceb98209d..87a4ce3df 100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp +++ b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp @@ -319,7 +319,7 @@ void gemm(bool transA, bool transB, else { // a A B + b C, dimension = M * N NTYPE* begin; - register NTYPE val; + NTYPE val; NTYPE val0; size_t i, j, k, maxc=0; const NTYPE *pA, *pB; @@ -346,7 +346,7 @@ void gemm(bool transA, bool transB, else { // a A B + b C, dimension = M * N NTYPE* begin; - register NTYPE val; + NTYPE val; NTYPE val0; size_t i, j, k, maxc=0; const NTYPE *pA, *pB; diff --git a/mlprodict/testing/experimental.py b/mlprodict/testing/experimental.py index 2e27b79e2..341c2127a 100644 --- a/mlprodict/testing/experimental.py +++ b/mlprodict/testing/experimental.py @@ -2,6 +2,7 @@ @file @brief Experimental implementation. """ +from collections import OrderedDict import numpy @@ -80,3 +81,198 @@ def custom_pad(arr, paddings, constant=0, debug=False): # final return res.reshape(new_shape) + + +def custom_einsum(equation, x, y, debug=False): + """ + Experimental implementation of operator Einsum + when it does a matrix multiplication. + Case: ``bsnh,btnh->bnts`` with shapes + `(1,512,12,64)` and `(1,512,12,64)`. + + :param equation: equation + :param x: first matrix + :param y: second matrix + :param debug: display internal information + :return: result of *einsum* + + This implementation does not any transpose, + it does a direct computation of the final result. + It does not implementation diagonal summation (square product). + """ + def _check_eq(eq, sh): + if len(eq) != len(sh): + raise ValueError( + "Unable to map equation %r to shape %r." % (eq, sh)) + + def _split(eq, sh): + dx = OrderedDict((e, (v, i)) for i, (e, v) in enumerate(zip(eq, sh))) + return dx + + def _interpret(dx, dy, eqr): + c_uni = [] + c_trp = [] + c_sum = [] + for r in eqr: + if r in dx: + if r in dy: + if dx[r][0] != dy[r][0]: + raise ValueError( + "Dimension mismatch for letter " + "%r dx=%r dy=%r." % (r, dx, dy)) + c_trp.append(r) + else: + c_uni.append((r, None)) + elif r in dy: + c_uni.append((None, r)) + else: + raise ValueError( + "Unexpected letter %r in result %r." % (r, eqr)) + for c in dx: + if c not in eqr: + if c not in dy: + raise ValueError( + "Unable to guess what to do with column %r (left side)" % c) + if dx[c][0] != dy[c][0]: + raise ValueError( + "Dimension mismatch for letter " + "%r dx=%r dy=%r." % (c, dx, dy)) + c_sum.append(c) + for c in dy: + if c not in eqr and c not in dx: + raise ValueError( + "Unable to guess what to do with column %r (right side)" % c) + shape = OrderedDict() + for i, r in enumerate(eqr): + if r in c_trp: + shape[r] = (dx[r][0], i) + else: + for a, b in c_uni: + if a == r: + shape[r] = (dx[r][0], i) + break + if b == r: + shape[r] = (dy[r][0], i) + break + if len(shape) != len(eqr): + raise RuntimeError( + "Unable to compute the output shape " + "dx=%r dy=%r eqr=%r got shape=%r." % (dx, dy, eqr, shape)) + return shape, c_trp, c_uni, c_sum + + def _inc(d): + t = 1 + drev = list(reversed(d.items())) + res = [] + for c, (sh, p) in drev: + res.append((c, (t, p))) + t *= sh + return OrderedDict(reversed(res)) + + def prod(seq): + p = 1 + for s in seq: + p *= s + return p + + def get_index(cd, shape, index, col_sum): + ind = 0 + for c, i in zip(shape, index): + if c in cd: + inc = cd[c][0] + ind += inc * i + return ind, cd[col_sum][0] + + def get_incs(cd, shape): + incs = [] + for c in shape: + inc = cd[c][0] if c in cd else 0 + incs.append(inc) + return incs + + if x.dtype != y.dtype: + raise RuntimeError("x and y must have the same dtype.") + eqx = equation.split(',')[0] + eqy = equation.split(',')[-1].split('->')[0] + eqr = equation.split('->')[-1] + _check_eq(eqx, x.shape) + _check_eq(eqy, y.shape) + dx = _split(eqx, x.shape) + dy = _split(eqy, y.shape) + shape, __, _, c_sum = _interpret(dx, dy, eqr) + cdx = _inc(dx) + cdy = _inc(dy) + xrav = x.ravel() + yrav = y.ravel() + full_size = prod(v[0] for v in shape.values()) + zrav = numpy.empty((full_size, ), dtype=x.dtype) + + # loop + if len(c_sum) != 1: + raise NotImplementedError( + "More than one summation indices %r in equation %r." % ( + c_sum, equation)) + zeros = numpy.zeros((1, ), dtype=x.dtype) + shape_dims = [v[0] for v in shape.values()] + index = [0 for s in shape] + len_index = len(index) + loop_size = dx[c_sum[0]][0] + + i_left_loop, inc_left = get_index(cdx, shape, index, c_sum[0]) + i_right_loop, inc_right = get_index(cdy, shape, index, c_sum[0]) + left_incs = get_incs(cdx, shape) + right_incs = get_incs(cdy, shape) + + if debug: + def MakeString(*args): + return "".join(map(str, args)) + + print(MakeString("equation=", equation)) + print(MakeString("c_sum=", c_sum)) + print(MakeString("full_size=", full_size)) + print(MakeString("loop_size=", loop_size)) + print(MakeString("i_left_loop=", i_left_loop)) + print(MakeString("i_right_loop=", i_right_loop)) + print(MakeString("inc_left=", inc_left)) + print(MakeString("inc_right=", inc_right)) + print(MakeString("left_incs=", left_incs)) + print(MakeString("right_incs=", right_incs)) + print(MakeString("shape=", shape)) + print(MakeString("cdx=", cdx)) + print(MakeString("cdy=", cdy)) + + for i in range(0, full_size): + + i_left = i_left_loop + i_right = i_right_loop + + # summation + add = zeros[0] + for _ in range(loop_size): + add += xrav[i_left] * yrav[i_right] + i_left += inc_left + i_right += inc_right + zrav[i] = add + + if debug: + print(MakeString( + " -- index=", index, " ii=", i, + " i_left_loop=", i_left_loop, " i_right_loop=", i_right_loop, + " add=", add)) + + # increment + pos = len_index - 1 + index[pos] += 1 + i_left_loop += left_incs[pos] + i_right_loop += right_incs[pos] + while pos > 0 and index[pos] >= shape_dims[pos]: + i_left_loop -= left_incs[pos] * index[pos] + i_right_loop -= right_incs[pos] * index[pos] + index[pos] = 0 + pos -= 1 + index[pos] += 1 + i_left_loop += left_incs[pos] + i_right_loop += right_incs[pos] + + new_shape = tuple(v[0] for v in shape.values()) + return zrav.reshape(new_shape) diff --git a/mlprodict/testing/experimental_c.cpp b/mlprodict/testing/experimental_c.cpp new file mode 100644 index 000000000..b6b33ec2b --- /dev/null +++ b/mlprodict/testing/experimental_c.cpp @@ -0,0 +1,409 @@ +// Inspired from +// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_regressor.cc. + +#if !defined(_CRT_SECURE_NO_WARNINGS) +#define _CRT_SECURE_NO_WARNINGS +#endif + +#include +#include +#include +#include +#include +#include + +#ifndef SKIP_PYTHON +//#include +#include +#include +#include +//#include + +#if USE_OPENMP +#include +#endif + +namespace py = pybind11; +#endif + +#include "experimental_c_helper.hpp" + + +//////////////// +// begin: einsum +//////////////// + +typedef std::pair mapshape_element; + +class mapshape_type { + protected: + std::map container; + std::vector order; + public: + mapshape_type() : container() {} + inline size_t size() const { return container.size(); } + inline const mapshape_element& at(const char& c) const { return container.at(c); } + inline const mapshape_element& value(size_t i) const { return container.at(order[i]); } + inline char key(size_t i) const { return order[i]; } + void clear() { + container.clear(); + order.clear(); + } + void add(char c, const mapshape_element& el) { + container[c] = el; + order.push_back(c); + } + bool has_key(const char& key) const { + return container.find(key) != container.end(); + } +}; + +template <> +inline void MakeStringInternal(std::ostringstream& ss, const mapshape_type& t) noexcept { + for(size_t i = 0; i < t.size(); ++i) { + ss << t.key(i) << ":" << t.value(i).first << "," << t.value(i).second << " "; + } +} + + +template +void _check_eq(const std::string&eq, const TYPE& sh) { + if (eq.size() != sh.size()) + throw std::runtime_error(MakeString( + "Unable to map equation ", eq, " to shape ", sh, ".")); +} + +void _split(const std::string& eq, const mapshape_type& sh, mapshape_type& dx) { + dx.clear(); + for (size_t i = 0; i < sh.size(); ++i) { + dx.add(eq[i], mapshape_element(sh.at(eq[i]).first, i)); + } +} + +void _split(const std::string& eq, const std::vector& sh, mapshape_type& dx) { + dx.clear(); + for (size_t i = 0; i < sh.size(); ++i) { + dx.add(eq[i], mapshape_element(sh[i], i)); + } +} + +void _equation_split(const std::string& equation, + std::string& eqx, std::string& eqy, std::string& eqr) { + size_t comma = equation.find_first_of(","); + size_t dash = equation.find_first_of("-", comma); + eqx = equation.substr(0, comma); + eqy = equation.substr(comma + 1, dash - comma - 1); + eqr = equation.substr(dash+2, equation.size() - dash - 2); +} + +void _interpret(const mapshape_type& dx, const mapshape_type& dy, const std::string& eqr, + mapshape_type& shape, std::vector>& c_uni, + std::vector& c_trp, std::vector& c_sum) { + c_uni.clear(); + c_trp.clear(); + c_sum.clear(); + c_uni.reserve(eqr.size()); + c_trp.reserve(eqr.size()); + c_sum.reserve(eqr.size()); + for (char r: eqr) { + if (dx.has_key(r)) { + if (dy.has_key(r)) { + if (dx.at(r).first != dy.at(r).first) + throw std::runtime_error(MakeString( + "Dimension mismatch for letter ", r, " dx=", dx, " dy=", dy, ".")); + c_trp.push_back(r); + } + else + c_uni.push_back(std::pair(r, '#')); + } + else if (dy.has_key(r)) + c_uni.push_back(std::pair('#', r)); + else + throw std::runtime_error(MakeString( + "Unexpected letter ", r, " in result ", eqr, ".")); + } + for (size_t i = 0; i < dx.size(); ++i) { + char c = dx.key(i); + if (std::find(eqr.begin(), eqr.end(), c) == eqr.end()) { + if (!dy.has_key(c)) + throw std::runtime_error(MakeString( + "Unable to guess what to do with column ", c, " (left side).")); + if (dx.at(c).first != dy.at(c).first) + throw std::runtime_error(MakeString( + "Dimension mismatch for letter ", c, " dx=", dx, " dy=", dy, ".")); + c_sum.push_back(c); + } + } + for (size_t i = 0; i < dy.size(); ++i) { + char c = dy.key(i); + if (std::find(eqr.begin(), eqr.end(), c) == eqr.end() && !dx.has_key(c)) + throw std::runtime_error(MakeString( + "Unable to guess what to do with column ", c, " (right side).")); + } + shape.clear(); + for (size_t i = 0; i < eqr.size(); ++i) { + char r = eqr[i]; + if (std::find(c_trp.begin(), c_trp.end(), r) != c_trp.end()) + shape.add(r, mapshape_element(dx.at(r).first, i)); + else { + for (auto p: c_uni) { + if (p.first == r) { + shape.add(r, mapshape_element(dx.at(r).first, i)); + break; + } + if (p.second == r) { + shape.add(r, mapshape_element(dy.at(r).first, i)); + break; + } + } + } + } + if (shape.size() != eqr.size()) + throw std::runtime_error(MakeString( + "Unable to compute the output shape dx=", dx , "dy=", dy, " eqr=", eqr, " got shape=", shape, ".")); +} + +void _inc(const mapshape_type &d, mapshape_type& res) { + int64_t t = 1; + std::vector> temp; + temp.reserve(d.size()); + for (int i = (int)d.size()-1; i >= 0; --i) { + temp.push_back(std::pair( + d.key(i), mapshape_element(t, d.value(i).second))); + t *= d.value(i).first; + } + res.clear(); + for(auto it = temp.rbegin(); it != temp.rend(); ++it) + res.add(it->first, it->second); +} + +int64_t prod(const mapshape_type& seq) { + int64_t p = 1; + for (size_t i = 0; i < seq.size(); ++i) + p *= seq.value(i).first; + return p; +} + +void get_index(const mapshape_type &cd, const mapshape_type &shape, + const std::vector& index, char col_sum, + int64_t& ind, int64_t& out_inc) { + ind = 0; + for(size_t i = 0; i < shape.size(); ++i) { + if (cd.has_key(shape.key(i))) + ind += shape.value(i).first * index[i]; + } + out_inc = cd.at(col_sum).first; +} + +void get_incs(const mapshape_type &cd, const mapshape_type &shape, + std::vector& incs) { + incs.clear(); + incs.reserve(cd.size()); + for(size_t i = 0; i < shape.size(); ++i) + incs.push_back(cd.has_key(shape.key(i)) ? cd.at(shape.key(i)).first : 0); +} + +void mapshape2shape(const mapshape_type &shape, std::vector& out_shape) { + out_shape.clear(); + out_shape.reserve(shape.size()); + for(size_t i = 0; i < shape.size(); ++i) + out_shape.push_back(shape.value(i).first); +} + +void mapshape2shape(const mapshape_type &shape, std::vector& out_shape) { + out_shape.clear(); + out_shape.reserve(shape.size()); + for(size_t i = 0; i < shape.size(); ++i) + out_shape.push_back(static_cast(shape.value(i).first)); +} + + +template +py::array_t custom_einsum(const std::string& equation, + py::array_t x, + py::array_t y) { + + std::vector x_shape, y_shape; + arrayshape2vector(x_shape, x); + arrayshape2vector(y_shape, y); + + const NTYPE* x_data = x.data(); + const NTYPE* y_data = y.data(); + + std::string eqx, eqy, eqr; + _equation_split(equation, eqx, eqy, eqr); + _check_eq(eqx, x_shape); + _check_eq(eqy, y_shape); + mapshape_type dx, dy; + _split(eqx, x_shape, dx); + _split(eqy, y_shape, dy); + + mapshape_type shape; + std::vector> c_uni; + std::vector c_trp, c_sum; + _interpret(dx, dy, eqr, shape, c_uni, c_trp, c_sum); + + if (c_sum.size() != 1) + throw std::runtime_error(MakeString( + "More than one summation indices ", c_sum, " in equation ", equation, ".")); + + mapshape_type cdx, cdy; + _inc(dx, cdx); + _inc(dy, cdy); + int64_t full_size = prod(shape); + + std::vector z_vector(full_size); + NTYPE* z_data = z_vector.data(); + + // loop + std::vector shape_dims(shape.size()); + std::vector index(shape.size()); + for(size_t i = 0; i < shape.size(); ++i) { + shape_dims[i] = shape.value(i).first; + index[i] = 0; + } + + size_t len_index = index.size(); + int64_t loop_size = dx.at(c_sum[0]).first; + + int64_t i_left_loop, inc_left, i_right_loop, inc_right; + get_index(cdx, shape, index, c_sum[0], i_left_loop, inc_left); + get_index(cdy, shape, index, c_sum[0], i_right_loop, inc_right); + + std::vector left_incs, right_incs; + get_incs(cdx, shape, left_incs); + get_incs(cdy, shape, right_incs); + NTYPE add; + const NTYPE *xp, *yp; + NTYPE *zp; + NTYPE *z_end = z_data + full_size; + size_t pos; + int64_t i_loop; + + for(zp = z_data; zp != z_end; ++zp) { + + // summation + add = (NTYPE)0; + xp = x_data + i_left_loop; + yp = y_data + i_right_loop; + + for (i_loop = loop_size; i_loop != 0; xp += inc_left, yp += inc_right, --i_loop) { + add += *xp * *yp; + } + *zp = add; + + // increment + pos = len_index - 1; + ++index[pos]; + i_left_loop += left_incs[pos]; + i_right_loop += right_incs[pos]; + while (pos > 0 && index[pos] >= shape_dims[pos]) { + i_left_loop -= left_incs[pos] * index[pos]; + i_right_loop -= right_incs[pos] * index[pos]; + index[pos] = 0; + --pos; + ++index[pos]; + i_left_loop += left_incs[pos]; + i_right_loop += right_incs[pos]; + } + } + + std::vector z_shape; + std::vector strides; + + mapshape2shape(shape, z_shape); + shape2strides(z_shape, strides, (NTYPE)0); + + return py::array_t( + py::buffer_info( + &z_vector[0], + sizeof(NTYPE), + py::format_descriptor::format(), + z_shape.size(), + z_shape, /* shape of the matrix */ + strides /* strides for each axis */ + )); +} + + +py::array_t custom_einsum_float( + const std::string& equation, + py::array_t x, + py::array_t y) { + return custom_einsum(equation, x, y); +} + + +py::array_t custom_einsum_double( + const std::string& equation, + py::array_t x, + py::array_t y) { + return custom_einsum(equation, x, y); +} + + +py::array_t custom_einsum_int64( + const std::string& equation, + py::array_t x, + py::array_t y) { + return custom_einsum(equation, x, y); +} + + +py::array_t custom_einsum_int32( + const std::string& equation, + py::array_t x, + py::array_t y) { + return custom_einsum(equation, x, y); +} + +////////////// +// end: einsum +////////////// + + +#ifndef SKIP_PYTHON + +PYBIND11_MODULE(experimental_c, m) { + m.doc() = + #if defined(__APPLE__) + "C++ experimental implementations." + #else + R"pbdoc(C++ experimental implementations.)pbdoc" + #endif + ; + + m.def("custom_einsum_float", &custom_einsum_float, + R"pbdoc(Custom C++ implementation of operator *einsum* with float. +The function only works with contiguous arrays. +It does not any explicit transposes. It does not support +diagonal operator (repetition of the same letter). +See python's version :func:`custom_einsum `. +)pbdoc"); + + m.def("custom_einsum_double", &custom_einsum_double, + R"pbdoc(Custom C++ implementation of operator *einsum* with double. +The function only works with contiguous arrays. +It does not any explicit transposes. It does not support +diagonal operator (repetition of the same letter). +See python's version :func:`custom_einsum `. +)pbdoc"); + + m.def("custom_einsum_int32", &custom_einsum_int32, + R"pbdoc(Custom C++ implementation of operator *einsum* with int32. +The function only works with contiguous arrays. +It does not any explicit transposes. It does not support +diagonal operator (repetition of the same letter). +See python's version :func:`custom_einsum `. +)pbdoc"); + + m.def("custom_einsum_int64", &custom_einsum_int64, + R"pbdoc(Custom C++ implementation of operator *einsum* with int64. +The function only works with contiguous arrays. +It does not any explicit transposes. It does not support +diagonal operator (repetition of the same letter). +See python's version :func:`custom_einsum `. +)pbdoc"); +} + +#endif diff --git a/mlprodict/testing/experimental_c_helper.hpp b/mlprodict/testing/experimental_c_helper.hpp new file mode 100644 index 000000000..e562c1591 --- /dev/null +++ b/mlprodict/testing/experimental_c_helper.hpp @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include +#include +#include // cout +#include + +#if defined(_WIN32) || defined(WIN32) + +inline bool _isnan_(float x) { return _isnanf(x); } +inline bool _isnan_(double x) { return _isnan(x); } + +#elif defined(__MACOSX__) || defined(__APPLE__) + +inline bool _isnan_(float x) { return (float)::isnan((double)x); } +inline bool _isnan_(double x) { return ::isnan(x); } + +#else + +// See https://stackoverflow.com/questions/2249110/how-do-i-make-a-portable-isnan-isinf-function +inline bool _isnan_(double x) { + union { uint64_t u; double f; } ieee754; + ieee754.f = x; + return ( (unsigned)(ieee754.u >> 32) & 0x7fffffff ) + + ( (unsigned)ieee754.u != 0 ) > 0x7ff00000; +} + +inline bool _isnan_(float x) { return _isnan_((double)x); } + +#endif + + +template +inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept { + ss << t; +} + +template <> +inline void MakeStringInternal(std::ostringstream& ss, const std::vector& t) noexcept { + for(auto it: t) { + ss << it << ","; + } +} + +template <> +inline void MakeStringInternal(std::ostringstream& ss, const std::vector& t) noexcept { + for(auto it: t) { + ss << it << ","; + } +} + +template <> +inline void MakeStringInternal(std::ostringstream& ss, const std::vector& t) noexcept { + for(auto it: t) { + ss << it << ","; + } +} + + +template <> +inline void MakeStringInternal(std::ostringstream& ss, const std::pair& t) noexcept { + ss << "(" << t.first << "," << t.second << ")"; +} + +template +inline void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept { + MakeStringInternal(ss, t); + MakeStringInternal(ss, args...); +} + +template +inline std::string MakeString(const Args&... args) { + std::ostringstream ss; + MakeStringInternal(ss, args...); + return std::string(ss.str()); +} + + +#define array2vector(vec, arr, dtype) { \ + if (arr.size() > 0) { \ + auto n = arr.size(); \ + auto p = (dtype*) arr.data(0); \ + vec = std::vector(p, p + n); \ + } \ +} + + +#define arrayshape2vector(vec, arr) { \ + if (arr.size() > 0) { \ + vec.resize(arr.ndim()); \ + for(size_t i = 0; i < vec.size(); ++i) \ + vec[i] = (int64_t) arr.shape(i); \ + } \ +} + + +template +NTYPE flattened_dimension(const std::vector& values) { + NTYPE r = 1; + for(auto it = values.begin(); it != values.end(); ++it) + r *= *it; + return r; +} + + +template +NTYPE flattened_dimension(const std::vector& values, int64_t first) { + NTYPE r = 1; + auto end = values.begin() + first; + for(auto it = values.begin(); it != end; ++it) + r *= *it; + return r; +} + + +template +void shape2strides(const std::vector& shape, + std::vector& strides, NTYPE cst) { + strides.resize(shape.size()); + strides[strides.size()-1] = static_cast(sizeof(NTYPE)); + for(ssize_t i = strides.size()-2; i >= 0; --i) + strides[i] = strides[i+1] * static_cast(shape[i+1]); +} + + +template +DIMTYPE SizeFromDimension(const std::vector& shape, size_t start, size_t end) { + DIMTYPE size = 1; + for (size_t i = start; i < end; i++) { + if (shape[i] < 0) + return -1; + size *= shape[i]; + } + return size; +} diff --git a/setup.py b/setup.py index 06a894d73..0b26979e6 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ project_var_name + ".asv_benchmark": ["*.json"], project_var_name + ".onnxrt.ops_cpu": ["*.cpp", "*.hpp"], project_var_name + ".onnxrt.validate.data": ["*.csv"], + project_var_name + ".testing": ["*.cpp", "*.hpp"], } ############ @@ -405,9 +406,24 @@ def write_version(): define_macros=define_macros, language='c++') + ext_experimental_c = Extension( + 'mlprodict.testing.experimental_c', + [os.path.join(root, 'mlprodict/testing/experimental_c.cpp')], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + include_dirs=[ + # Path to pybind11 headers + get_pybind_include(), + get_pybind_include(user=True), + os.path.join(root, 'mlprodict/testing') + ], + define_macros=define_macros, + language='c++') + ext_modules = [ ext_conv, ext_conv_transpose, + ext_experimental_c, ext_gather, ext_max_pool, ext_svm_classifier,