diff --git a/_unittests/ut__skl2onnx/test_sklearn_label_encoder_converter.py b/_unittests/ut__skl2onnx/test_sklearn_label_encoder_converter.py
index 42f3f33bb..558ee0399 100644
--- a/_unittests/ut__skl2onnx/test_sklearn_label_encoder_converter.py
+++ b/_unittests/ut__skl2onnx/test_sklearn_label_encoder_converter.py
@@ -50,7 +50,8 @@ def test_model_label_encoder_int(self):
model = LabelEncoder()
data = numpy.array([10, 3, 5, -34, 0], dtype=numpy.int64)
model.fit(data)
- for op in sorted(set([9, 10, 11, 12, 13, 14, TARGET_OPSET])): # opset=13, 14, ...
+ # opset=13, 14, ...
+ for op in sorted(set([9, 10, 11, 12, 13, 14, TARGET_OPSET])):
if op > TARGET_OPSET:
continue
with self.subTest(opset=op):
diff --git a/_unittests/ut_onnxrt/test_benchmark_replay.py b/_unittests/ut_onnxrt/test_benchmark_replay.py
index abf3b4cb8..684430f10 100644
--- a/_unittests/ut_onnxrt/test_benchmark_replay.py
+++ b/_unittests/ut_onnxrt/test_benchmark_replay.py
@@ -17,7 +17,8 @@ def test_benchmark_replay(self):
enumerate_benchmark_replay(temp, runtime='python')),
FileNotFoundError)
res = list(enumerate_validated_operator_opsets(
- 0, fLOG=None, models={"LogisticRegression"}, opset_min=14, # opset=13, 14, ...
+ # opset=13, 14, ...
+ 0, fLOG=None, models={"LogisticRegression"}, opset_min=14,
opset_max=14, benchmark=False, store_models=True, dump_all=True,
dump_folder=temp, filter_exp=lambda m, p: (
"64" not in p and "b-cl" in p and "dec" not in p)))
diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
index a99c73c48..5b5dab74b 100644
--- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
+++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
@@ -2874,7 +2874,8 @@ def test_onnxt_runtime_reduce_prod(self):
def test_onnxt_runtime_reduce_sum(self):
X = numpy.array([[2, 1], [0, 1]], dtype=float)
- for opset in (10, 11, 12, 13, 14, get_opset_number_from_onnx()): # opset=13, 14, ...
+ # opset=13, 14, ...
+ for opset in (10, 11, 12, 13, 14, get_opset_number_from_onnx()):
if onnx_opset_version() < opset:
continue
if opset < 13:
@@ -2977,7 +2978,8 @@ def test_onnxt_runtime_reduce_sum_square(self):
def test_onnxt_runtime_reduce_sum_noop(self):
X = numpy.array([], dtype=float).reshape((2, 0))
- for opset in (13, 14, get_opset_number_from_onnx()): # opset=13, 14, ...
+ # opset=13, 14, ...
+ for opset in (13, 14, get_opset_number_from_onnx()):
if onnx_opset_version() < opset:
continue
cl = OnnxReduceSum
@@ -3191,7 +3193,8 @@ def test_onnxt_runtime_slice(self):
@wraplog()
def test_onnxt_runtime_slice_step_none(self):
- for opset in [13, 14, get_opset_number_from_onnx()]: # opset=13, 14, ...
+ # opset=13, 14, ...
+ for opset in [13, 14, get_opset_number_from_onnx()]:
if opset > get_opset_number_from_onnx():
continue
with self.subTest(opset=opset):
@@ -3211,7 +3214,8 @@ def test_onnxt_runtime_slice_step_none(self):
@wraplog()
def test_onnxt_runtime_split(self):
- for opset in [10, 11, 12, 13, 14, get_opset_number_from_onnx()]: # opset=13, 14, ...
+ # opset=13, 14, ...
+ for opset in [10, 11, 12, 13, 14, get_opset_number_from_onnx()]:
if opset > get_opset_number_from_onnx():
continue
with self.subTest(opset=opset):
@@ -3259,7 +3263,8 @@ def test_onnxt_runtime_sqrt(self):
@wraplog()
def test_onnxt_runtime_squeeze(self):
- for opset in [10, 11, 12, 13, 14, get_opset_number_from_onnx()]: # opset=13, 14, ...
+ # opset=13, 14, ...
+ for opset in [10, 11, 12, 13, 14, get_opset_number_from_onnx()]:
if opset > get_opset_number_from_onnx():
continue
with self.subTest(opset=opset):
@@ -3444,7 +3449,8 @@ def test_onnxt_runtime_transpose(self):
@wraplog()
def test_onnxt_runtime_unsqueeze(self):
- for opset in [10, 11, 12, 13, 14, get_opset_number_from_onnx()]: # opset=13, 14, ...
+ # opset=13, 14, ...
+ for opset in [10, 11, 12, 13, 14, get_opset_number_from_onnx()]:
if opset > get_opset_number_from_onnx():
continue
with self.subTest(opset=opset):
diff --git a/_unittests/ut_testing/test_custom_add.py b/_unittests/ut_testing/test_custom_add.py
new file mode 100644
index 000000000..b97b0b757
--- /dev/null
+++ b/_unittests/ut_testing/test_custom_add.py
@@ -0,0 +1,50 @@
+"""
+@brief test log(time=8s)
+"""
+import unittest
+import numpy
+from pyquickhelper.pycode import ExtTestCase
+from mlprodict.testing.experimental_c import ( # pylint: disable=E0611
+ BroadcastMatrixAddLeftInplaceDouble,
+ BroadcastMatrixAddLeftInplaceFloat,
+ BroadcastMatrixAddLeftInplaceInt64)
+
+
+class TestCustomAdd(ExtTestCase):
+
+ add_dtypes = {
+ numpy.float64: BroadcastMatrixAddLeftInplaceDouble,
+ numpy.float32: BroadcastMatrixAddLeftInplaceFloat,
+ numpy.int64: BroadcastMatrixAddLeftInplaceInt64
+ }
+
+ def _common_broadcast_matrix(self, dt):
+ with self.subTest(dtype=dt):
+ fct = TestCustomAdd.add_dtypes[dt]
+
+ m1 = numpy.array([1, 2, 3, 4, 5, 6], dtype=dt).reshape((-1, 2))
+ m2 = numpy.array([1, 2], dtype=dt).reshape((1, 2))
+ m3 = m1 + m2
+ fct(m1, m2)
+ self.assertEqualArray(m3, m1)
+
+ m1 = numpy.array([1, 2, 3, 4, 5, 6], dtype=dt).reshape((-1, 3))
+ m2 = numpy.array([1, 2], dtype=dt).reshape((2, 1))
+ m3 = m1 + m2
+ fct(m1, m2)
+ self.assertEqualArray(m3, m1)
+
+ m1 = numpy.array([1, 2, 3, 4, 5, 6], dtype=dt).reshape((-1, 3))
+ m2 = numpy.array([1, 2], dtype=dt).reshape((2, ))
+ m3 = m1 + m2.reshape((2, 1))
+ fct(m1, m2)
+ self.assertEqualArray(m3, m1)
+
+ def test_broadcast_matrix(self):
+ for dt in [numpy.float64, numpy.float32, numpy.int64]:
+ self._common_broadcast_matrix(dt)
+
+
+if __name__ == "__main__":
+ # TestEinsum().test_np_test_broadcasting_dot_cases1()
+ unittest.main()
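Note for readers of this patch: the test above is currently the only documentation of the new bindings. BroadcastMatrixAddLeftInplace{Double,Float,Int64} perform X += Y in place, and only Y may be broadcast onto X. A minimal usage sketch, assuming the extension builds and is importable exactly as in the test above:

    import numpy
    from mlprodict.testing.experimental_c import (  # pylint: disable=E0611
        BroadcastMatrixAddLeftInplaceDouble)

    X = numpy.arange(6, dtype=numpy.float64).reshape((3, 2))
    Y = numpy.array([10., 100.], dtype=numpy.float64).reshape((1, 2))
    expected = X + Y                           # plain numpy addition for reference
    BroadcastMatrixAddLeftInplaceDouble(X, Y)  # modifies X in place
    assert numpy.array_equal(X, expected)
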
diff --git a/bin/debug/debug.cpp b/bin/debug/debug.cpp
index 19f5def72..65ca9caab 100644
--- a/bin/debug/debug.cpp
+++ b/bin/debug/debug.cpp
@@ -5,6 +5,7 @@
#include "op_qlinear_conv_.hpp"
#include "op_qlinear_cpp_qgemm_tester_.hpp"
#include "op_qlinear_cpp_tester_.hpp"
+#include "experimental_c.h"
void test_qlinear_conv1() {
@@ -100,6 +101,9 @@ void test_qlinear_conv2(bool random) {
}
int main() {
+ experimental_ut_add();
+ experimental_ut_einsum();
+ experimental_ut_reduce();
TestQGemm0();
TestQGemm1();
test_qlinear_conv2(false);
diff --git a/bin/debug/debug.vcxproj b/bin/debug/debug.vcxproj
index 9152705c4..d2951c3a4 100644
--- a/bin/debug/debug.vcxproj
+++ b/bin/debug/debug.vcxproj
@@ -116,7 +116,7 @@
true
_DEBUG;_CONSOLE;%(PreprocessorDefinitions);SKIP_PYTHON
true
- ../../mlprodict/onnxrt/ops_cpu
+ ../../mlprodict/onnxrt/ops_cpu;../../mlprodict/testing
Console
@@ -150,6 +150,16 @@
+
+
+
+
+
+
+
+
+
+
diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.cpp b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.cpp
index 6535d1492..228a1441e 100644
--- a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.cpp
+++ b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.cpp
@@ -1,3 +1,6 @@
+// Inspired from
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/nn/qlinearconv.cc.
+
#include "op_conv_matrices_.hpp"
diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp
index 7c83ed48e..95d91da59 100644
--- a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp
+++ b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp
@@ -1,7 +1,7 @@
#pragma once
// Inspired from
-// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc.
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/nn/qlinearconv.cc.
#if !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS
diff --git a/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.cpp b/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.cpp
index 86f7dda1a..7b1ed0432 100644
--- a/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.cpp
+++ b/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.cpp
@@ -1,11 +1,10 @@
// Inspired from
-// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc.
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/nn/qlinearconv.cc.
#if !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS
#endif
-
#include "op_qlinear_conv_.hpp"
#include "op_qlinear_cpp_tester_.hpp"
#include "op_qlinear_cpp_qgemm_tester_.hpp"
@@ -20,7 +19,6 @@
#undef max
#endif
-
class RuntimeTesterQLinearConv : public RuntimeTester {
public:
RuntimeTesterQLinearConv(const char* op_name, int opset = 13) : RuntimeTester(op_name, opset) {}
diff --git a/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.hpp b/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.hpp
index 04146da63..a23a162e3 100644
--- a/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.hpp
+++ b/mlprodict/onnxrt/ops_cpu/op_qlinear_conv_.hpp
@@ -1,7 +1,7 @@
#pragma once
// Inspired from
-// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc.
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/nn/qlinearconv.cc.
#if !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS
@@ -49,8 +49,7 @@ inline uint32_t BitsOfFp32(float FloatValue) {
*/
template <typename T>
void RequantizeOutput(
- const int32_t* Input, T* Output, const int32_t* Bias,
- size_t M, size_t N,
+ const int32_t* Input, T* Output, const int32_t* Bias, size_t M, size_t N,
const float* Scale, bool PerColumnScale, T ZeroPoint) {
const float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale;
const float MinimumValue = float(0 - ZeroPoint);
diff --git a/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.cpp b/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.cpp
index 16a12e959..35b43c130 100644
--- a/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.cpp
+++ b/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.cpp
@@ -1,5 +1,5 @@
// Inspired from
-// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc.
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc.
#if !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS
diff --git a/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.hpp b/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.hpp
index 03766597d..760202d2c 100644
--- a/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.hpp
+++ b/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_qgemm_tester_.hpp
@@ -1,7 +1,7 @@
#pragma once
// Inspired from
-// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc.
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc.
#if !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS
diff --git a/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_tester_.hpp b/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_tester_.hpp
index bafa226dc..a1c5109e7 100644
--- a/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_tester_.hpp
+++ b/mlprodict/onnxrt/ops_cpu/op_qlinear_cpp_tester_.hpp
@@ -1,7 +1,7 @@
#pragma once
// Inspired from
-// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc.
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc.
#if !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS
diff --git a/mlprodict/testing/experimental_c.cpp b/mlprodict/testing/experimental_c.cpp
index 77af75947..0205ca24e 100644
--- a/mlprodict/testing/experimental_c.cpp
+++ b/mlprodict/testing/experimental_c.cpp
@@ -1,630 +1,92 @@
-// Inspired from
-// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_regressor.cc.
-
-#if !defined(_CRT_SECURE_NO_WARNINGS)
-#define _CRT_SECURE_NO_WARNINGS
-#endif
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#ifndef SKIP_PYTHON
-//#include
-#include
-#include
-#include
-//#include
-
-#if USE_OPENMP
-#include
-#endif
-
-namespace py = pybind11;
-#endif
-
+#include "experimental_c.h"
#include "experimental_c_helper.hpp"
-
-#include
-#include
-#include
-
-
-////////////////
-// begin: einsum
-////////////////
-
-typedef std::pair mapshape_element;
-
-class mapshape_type {
- protected:
- std::map container;
- std::vector order;
- public:
- mapshape_type() : container() {}
- inline size_t size() const { return container.size(); }
- inline const mapshape_element& at(const char& c) const { return container.at(c); }
- inline const mapshape_element& value(size_t i) const { return container.at(order[i]); }
- inline char key(size_t i) const { return order[i]; }
- void clear() {
- container.clear();
- order.clear();
- }
- void add(char c, const mapshape_element& el) {
- container[c] = el;
- order.push_back(c);
- }
- bool has_key(const char& key) const {
- return container.find(key) != container.end();
- }
-};
-
-template <>
-inline void MakeStringInternal(std::ostringstream& ss, const mapshape_type& t) noexcept {
- for(size_t i = 0; i < t.size(); ++i) {
- ss << t.key(i) << ":" << t.value(i).first << "," << t.value(i).second << " ";
- }
-}
-
-
-template <typename TYPE>
-void _check_eq(const std::string&eq, const TYPE& sh) {
- if (eq.size() != sh.size())
- throw std::runtime_error(MakeString(
- "Unable to map equation ", eq, " to shape ", sh, "."));
-}
-
-void _split(const std::string& eq, const mapshape_type& sh, mapshape_type& dx) {
- dx.clear();
- for (size_t i = 0; i < sh.size(); ++i) {
- dx.add(eq[i], mapshape_element(sh.at(eq[i]).first, i));
- }
-}
-
-void _split(const std::string& eq, const std::vector& sh, mapshape_type& dx) {
- dx.clear();
- for (size_t i = 0; i < sh.size(); ++i) {
- dx.add(eq[i], mapshape_element(sh[i], i));
- }
-}
-
-void _equation_split(const std::string& equation,
- std::string& eqx, std::string& eqy, std::string& eqr) {
- size_t comma = equation.find_first_of(",");
- size_t dash = equation.find_first_of("-", comma);
- eqx = equation.substr(0, comma);
- eqy = equation.substr(comma + 1, dash - comma - 1);
- eqr = equation.substr(dash+2, equation.size() - dash - 2);
-}
-
-void _interpret(const mapshape_type& dx, const mapshape_type& dy, const std::string& eqr,
- mapshape_type& shape, std::vector>& c_uni,
- std::vector& c_trp, std::vector& c_sum) {
- c_uni.clear();
- c_trp.clear();
- c_sum.clear();
- c_uni.reserve(eqr.size());
- c_trp.reserve(eqr.size());
- c_sum.reserve(eqr.size());
- for (char r: eqr) {
- if (dx.has_key(r)) {
- if (dy.has_key(r)) {
- if (dx.at(r).first != dy.at(r).first)
- throw std::runtime_error(MakeString(
- "Dimension mismatch for letter ", r, " dx=", dx, " dy=", dy, "."));
- c_trp.push_back(r);
- }
- else
- c_uni.push_back(std::pair(r, '#'));
- }
- else if (dy.has_key(r))
- c_uni.push_back(std::pair('#', r));
- else
- throw std::runtime_error(MakeString(
- "Unexpected letter ", r, " in result ", eqr, "."));
- }
- for (size_t i = 0; i < dx.size(); ++i) {
- char c = dx.key(i);
- if (std::find(eqr.begin(), eqr.end(), c) == eqr.end()) {
- if (!dy.has_key(c))
- throw std::runtime_error(MakeString(
- "Unable to guess what to do with column ", c, " (left side)."));
- if (dx.at(c).first != dy.at(c).first)
- throw std::runtime_error(MakeString(
- "Dimension mismatch for letter ", c, " dx=", dx, " dy=", dy, "."));
- c_sum.push_back(c);
- }
- }
- for (size_t i = 0; i < dy.size(); ++i) {
- char c = dy.key(i);
- if (std::find(eqr.begin(), eqr.end(), c) == eqr.end() && !dx.has_key(c))
- throw std::runtime_error(MakeString(
- "Unable to guess what to do with column ", c, " (right side)."));
- }
- shape.clear();
- for (size_t i = 0; i < eqr.size(); ++i) {
- char r = eqr[i];
- if (std::find(c_trp.begin(), c_trp.end(), r) != c_trp.end())
- shape.add(r, mapshape_element(dx.at(r).first, i));
- else {
- for (auto p: c_uni) {
- if (p.first == r) {
- shape.add(r, mapshape_element(dx.at(r).first, i));
- break;
- }
- if (p.second == r) {
- shape.add(r, mapshape_element(dy.at(r).first, i));
- break;
- }
- }
- }
- }
- if (shape.size() != eqr.size())
- throw std::runtime_error(MakeString(
- "Unable to compute the output shape dx=", dx , "dy=", dy, " eqr=", eqr, " got shape=", shape, "."));
-}
-
-void _inc(const mapshape_type &d, mapshape_type& res) {
- int64_t t = 1;
- std::vector> temp;
- temp.reserve(d.size());
- for (int i = (int)d.size()-1; i >= 0; --i) {
- temp.push_back(std::pair(
- d.key(i), mapshape_element(t, d.value(i).second)));
- t *= d.value(i).first;
- }
- res.clear();
- for(auto it = temp.rbegin(); it != temp.rend(); ++it)
- res.add(it->first, it->second);
-}
-
-int64_t prod(const mapshape_type& seq) {
- int64_t p = 1;
- for (size_t i = 0; i < seq.size(); ++i)
- p *= seq.value(i).first;
- return p;
-}
-
-int64_t get_index(const std::vector &incs, const std::vector& index) {
- int64_t ind = 0;
- for(size_t i = 0; i < index.size(); ++i)
- ind += incs[i] * index[i];
- return ind;
-}
-
-void get_incs(const mapshape_type &cd, const mapshape_type &shape,
- std::vector& incs) {
- incs.clear();
- incs.reserve(cd.size());
- for(size_t i = 0; i < shape.size(); ++i)
- incs.push_back(cd.has_key(shape.key(i)) ? cd.at(shape.key(i)).first : 0);
-}
-
-void mapshape2shape(const mapshape_type &shape, std::vector& out_shape) {
- out_shape.clear();
- out_shape.reserve(shape.size());
- for(size_t i = 0; i < shape.size(); ++i)
- out_shape.push_back(shape.value(i).first);
-}
-
-void mapshape2shape(const mapshape_type &shape, std::vector& out_shape) {
- out_shape.clear();
- out_shape.reserve(shape.size());
- for(size_t i = 0; i < shape.size(); ++i)
- out_shape.push_back(static_cast(shape.value(i).first));
-}
-
-template <typename NTYPE>
-NTYPE vector_dot_product_pointer16(const NTYPE *p1, const NTYPE *p2, size_t size) {
- NTYPE sum = 0;
- for (; size != 0; ++p1, ++p2, --size)
- sum += *p1 * *p2;
- return sum;
-}
-
-std::string code_optimisation() {
- #if USE_OPENMP
- std::string omp = MakeString("omp=", omp_get_num_procs());
- #else
- std::string omp = MakeString("th=", 1);
- #endif
- #if defined(_CMP_EQ_OQ) // defined in immintrin
- return MakeString("AVX-", omp);
- #else
- return MakeString("SSE-", omp);
- #endif
-}
-
-template <>
-float vector_dot_product_pointer16(const float *p1, const float *p2, size_t size) {
- float sum = 0;
- #if defined(__AVX__)
- if (size > 8) {
- __m256 r256 = _mm256_setzero_ps();
- for (; size > 8; p1 += 8, p2 += 8, size -= 8)
- r256 = _mm256_add_ps(r256, _mm256_mul_ps(_mm256_load_ps(p1), _mm256_load_ps(p2)));
- __m128 c1, c2, r1;
- c1 = _mm256_extractf128_ps(r256, 1);
- c2 = _mm256_extractf128_ps(r256, 0);
- r1 = _mm_add_ps(c1, c2);
- c1 = _mm_shuffle_ps(r1, r1, _MM_SHUFFLE(2, 3, 0, 1));
- c2 = _mm_add_ps(r1, c1);
- c1 = _mm_movehl_ps(c1, c2);
- c2 = _mm_add_ss(c2, c1);
- sum += _mm_cvtss_f32(c2);
- }
- #else
- if (size > 4) {
- __m128 c1, c2;
- __m128 r1 = _mm_setzero_ps();
- for (; size > 4; p1 += 4, p2 += 4, size -= 4)
- r1 = _mm_add_ps(r1, _mm_mul_ps(_mm_load_ps(p1), _mm_load_ps(p2)));
- c1 = _mm_shuffle_ps(r1, r1, _MM_SHUFFLE(2, 3, 0, 1));
- c2 = _mm_add_ps(r1, c1);
- c1 = _mm_movehl_ps(c1, c2);
- c2 = _mm_add_ss(c2, c1);
- sum += _mm_cvtss_f32(c2);
- }
- #endif
- for (; size != 0; ++p1, ++p2, --size)
- sum += *p1 * *p2;
- return sum;
-}
-
-template <typename NTYPE>
-NTYPE vector_dot_product_pointer_stride(const NTYPE *xp, const NTYPE *yp, size_t size,
- int64_t inc_left, int64_t inc_right) {
- NTYPE sum = (NTYPE)0;
- for (int64_t i_loop = size; i_loop != 0; xp += inc_left, yp += inc_right, --i_loop)
- sum += *xp * *yp;
- return sum;
-}
-
-void set_index(int64_t begin, const std::vector& shape_dims, std::vector& index) {
- for(size_t i = shape_dims.size()-1; i > 0; --i) {
- index[i] = begin % shape_dims[i];
- begin -= index[i];
- begin /= shape_dims[i];
- }
- index[0] = begin;
-}
-
-
-template <typename NTYPE>
-void custom_einsum_matmul(const NTYPE* x_data, const NTYPE* y_data,
- int64_t loop_size,
- const mapshape_type& cdx, const mapshape_type& cdy,
- const mapshape_type& shape,
- const std::vector& left_incs,
- const std::vector& right_incs,
- NTYPE* z_data, int64_t begin, int64_t end,
- char col_sum) {
- const NTYPE *xp, *yp;
- NTYPE *zp;
- size_t pos;
- NTYPE *z_end = z_data + end;
- size_t len_index = shape.size();
-
- std::vector shape_dims(len_index);
- for(size_t i = 0; i < len_index; ++i)
- shape_dims[i] = shape.value(i).first;
-
- std::vector index(len_index);
- int64_t i_left_loop, inc_left, i_right_loop, inc_right;
- set_index(begin, shape_dims, index);
- i_left_loop = get_index(left_incs, index);
- i_right_loop = get_index(right_incs, index);
- inc_left = cdx.at(col_sum).first;
- inc_right = cdy.at(col_sum).first;
-
- for(zp = z_data + begin; zp != z_end; ++zp) {
- // summation
- xp = x_data + i_left_loop;
- yp = y_data + i_right_loop;
-
- if (inc_left == 1 && inc_right == 1) {
- *zp = vector_dot_product_pointer16(xp, yp, loop_size);
- }
- else {
- *zp = vector_dot_product_pointer_stride(xp, yp, loop_size, inc_left, inc_right);
- }
-
- // increment
- pos = len_index - 1;
- ++index[pos];
- i_left_loop += left_incs[pos];
- i_right_loop += right_incs[pos];
- while (pos > 0 && index[pos] >= shape_dims[pos]) {
- i_left_loop -= left_incs[pos] * index[pos];
- i_right_loop -= right_incs[pos] * index[pos];
- index[pos] = 0;
- --pos;
- ++index[pos];
- i_left_loop += left_incs[pos];
- i_right_loop += right_incs[pos];
- }
- }
-}
-
-
-template <typename NTYPE>
-py::array_t custom_einsum(const std::string& equation,
- py::array_t x,
- py::array_t y,
- int nthread) {
-
- std::vector x_shape, y_shape;
- arrayshape2vector(x_shape, x);
- arrayshape2vector(y_shape, y);
-
- const NTYPE* x_data = x.data();
- const NTYPE* y_data = y.data();
-
+#include "experimental_c_einsum.h"
+#include "experimental_c_einsum.hpp"
+#include "experimental_c_reduce.h"
+#include "experimental_c_reduce.hpp"
+#include "experimental_c_add.h"
+#include "experimental_c_add.hpp"
+
+void experimental_ut_einsum() {
+ std::vector v{ 1, 2, 3 };
+ vector_dot_product_pointer16(v.data(), v.data(), 3);
+
+ std::string equation = "ij,jk->ik";
std::string eqx, eqy, eqr;
_equation_split(equation, eqx, eqy, eqr);
- _check_eq(eqx, x_shape);
- _check_eq(eqy, y_shape);
- mapshape_type dx, dy;
- _split(eqx, x_shape, dx);
- _split(eqy, y_shape, dy);
-
- mapshape_type shape;
- std::vector> c_uni;
- std::vector c_trp, c_sum;
- _interpret(dx, dy, eqr, shape, c_uni, c_trp, c_sum);
-
- if (c_sum.size() != 1)
- throw std::runtime_error(MakeString(
- "More than one summation indices ", c_sum, " in equation ", equation, "."));
-
- mapshape_type cdx, cdy;
- _inc(dx, cdx);
- _inc(dy, cdy);
- int64_t full_size = prod(shape);
-
- std::vector z_vector(full_size);
- NTYPE* z_data = z_vector.data();
-
- // loop
- int64_t loop_size = dx.at(c_sum[0]).first;
-
- std::vector left_incs, right_incs;
- get_incs(cdx, shape, left_incs);
- get_incs(cdy, shape, right_incs);
-
- #if USE_OPENMP
- if (nthread == 1) {
- #endif
- custom_einsum_matmul(x_data, y_data, loop_size,
- cdx, cdy, shape,
- left_incs, right_incs, z_data,
- 0 /*begin*/, full_size /*end*/,
- c_sum[0]);
- #if USE_OPENMP
- }
- else {
- if (nthread > 1)
- omp_set_num_threads(nthread);
- else
- nthread = omp_get_num_procs();
- int N = nthread * 4;
- int64_t h = full_size / N;
- if (h == 0) {
- h = full_size;
- N = 1;
- }
-
- #pragma omp parallel for
- for(int i = 0; i < N; ++i) {
- int64_t begin = h * i;
- int64_t end = (i == N-1) ? full_size : begin + h;
- custom_einsum_matmul(x_data, y_data, loop_size,
- cdx, cdy, shape,
- left_incs, right_incs, z_data,
- begin /*begin*/, end /*end*/,
- c_sum[0]);
- }
- }
- #endif
-
- std::vector z_shape;
- std::vector strides;
-
- mapshape2shape(shape, z_shape);
- shape2strides(z_shape, strides, (NTYPE)0);
-
- return py::array_t(
- py::buffer_info(
- &z_vector[0],
- sizeof(NTYPE),
- py::format_descriptor::format(),
- z_shape.size(),
- z_shape, /* shape of the matrix */
- strides /* strides for each axis */
- ));
}
-
-py::array_t custom_einsum_float(
- const std::string& equation,
- py::array_t x,
- py::array_t y,
- int nthread) {
- return custom_einsum(equation, x, y, nthread);
-}
-
-
-py::array_t custom_einsum_double(
- const std::string& equation,
- py::array_t x,
- py::array_t y,
- int nthread) {
- return custom_einsum(equation, x, y, nthread);
-}
-
-
-py::array_t custom_einsum_int64(
- const std::string& equation,
- py::array_t x,
- py::array_t y,
- int nthread) {
- return custom_einsum(equation, x, y, nthread);
-}
-
-
-py::array_t custom_einsum_int32(
- const std::string& equation,
- py::array_t x,
- py::array_t y,
- int nthread) {
- return custom_einsum(equation, x, y, nthread);
-}
-
-//////////////
-// end: einsum
-//////////////
-
-
-////////////////
-// begin: reduce
-////////////////
-
-template <typename NTYPE>
-void vector_add_pointer(NTYPE *acc, const NTYPE *x, size_t size) {
- for (; size != 0; ++acc, ++x, --size)
- *acc += *x;
-}
-
-
-template <>
-void vector_add_pointer(float *acc, const float *x, size_t size) {
- // _mm_store_ps fails if acc not aligned.
- // _mm_storeu_ps does not need alignment.
- #if defined(__AVX__)
- if (size > 8) {
- for (; size > 8; acc += 8, x += 8, size -= 8) {
- _mm256_storeu_ps(acc, _mm256_add_ps(_mm256_loadu_ps(acc), _mm256_loadu_ps(x)));
- }
- }
- #else
- if (size > 4) {
- for (; size > 4; acc += 4, x += 4, size -= 4) {
- _mm_storeu_ps(acc, _mm_add_ps(_mm_loadu_ps(acc), _mm_loadu_ps(x)));
- }
- }
- #endif
- for (; size != 0; ++acc, ++x, --size)
- *acc += *x;
+void experimental_ut_reduce() {
}
+void experimental_ut_add() {
+ TensorShape shape1(4);
+ shape1.p_dims[0] = 1;
+ shape1.p_dims[1] = 2;
+ shape1.p_dims[2] = 5;
+ shape1.p_dims[3] = 3;
-// This function assumes x is a 2D matrix to be reduced on the first axis.
-template <typename NTYPE>
-py::array_t custom_reducesum_rk(py::array_t x,
- int nthread) {
- std::vector x_shape;
- arrayshape2vector(x_shape, x);
- if (x_shape.size() != 2)
- throw std::runtime_error("Input array must have two dimensions.");
- if (flattened_dimension(x_shape) == 0)
- throw std::runtime_error("Input array must not be empty.");
-
- int64_t N = x_shape[1];
- std::vector y_vector(N);
- // int64_t Nred = x_shape[0];
- const NTYPE* x_data = x.data();
- // const NTYPE* x_data_end = x_data + x_shape[0] * x_shape[1];
- NTYPE* y_data = y_vector.data();
+ TensorShape shape2(3);
+ shape2.p_dims[0] = 1;
+ shape2.p_dims[1] = 1;
+ shape2.p_dims[2] = 5;
- #if USE_OPENMP
- if (nthread == 1 || N <= nthread * 2) {
- #endif
- int64_t n_rows = x_shape[0];
- // NTYPE *y_data_end = y_data + N;
- memcpy(y_data, x_data, N * sizeof(NTYPE));
- for(int64_t row = 1; row < n_rows; ++row) {
- vector_add_pointer(y_data, x_data + row * N, N);
- }
- #if USE_OPENMP
- }
- else {
- if (nthread > 1)
- omp_set_num_threads(nthread);
- else
- nthread = omp_get_num_procs();
+ if (!shape1.right_broadcast(&shape2))
+ throw std::invalid_argument("experimental_ut_add 1");
+ if (shape2.right_broadcast(&shape1))
+ throw std::invalid_argument("experimental_ut_add 2");
- int64_t batch_size = N / nthread / 2;
- int64_t n_rows = x_shape[0];
- batch_size = batch_size < 4 ? 4 : batch_size;
- batch_size = batch_size > 1024 ? 1024 : batch_size;
- int64_t batch = N / batch_size + (N % batch_size > 0 ? 1 : 0);
- memcpy(y_data, x_data, N * sizeof(NTYPE));
+ Tensor t1(&shape1);
+ Tensor t2(&shape2);
+ Tensor::type_index n = shape1.Size();
+ for (Tensor::type_index i = 0; i < n; ++i)
+ t1.p_values[i] = static_cast::type_value>(i + 1);
+ n = shape2.Size();
+ for (Tensor::type_index i = 0; i < n; ++i)
+ t2.p_values[i] = static_cast::type_value>(1);
- #pragma omp parallel for
- for (int64_t b = 0; b < batch; ++b) {
- int64_t begin = batch_size * b;
- int64_t end = begin + batch_size < N ? begin + batch_size : N;
- for(int64_t row = 1; row < n_rows; ++row) {
- vector_add_pointer(y_data + begin, x_data + row * N + begin, end - begin);
- }
- }
+ BroadcastMatrixAddLeftInplace(&t1, &t2);
+ n = shape1.Size();
+ for (Tensor::type_index i = 0; i < n; ++i)
+ std::cout << t1.p_values[i] << ", ";
+ for (Tensor::type_index i = 0; i < n; ++i) {
+ if (t1.p_values[i] != static_cast::type_value>(i + 2))
+            throw std::invalid_argument(MakeString("discrepancy:", t1.p_values[i], "!=", i + 2));
}
- #endif
-
- std::vector y_shape{N};
- std::vector strides;
- shape2strides(y_shape, strides, (NTYPE)0);
-
- return py::array_t(
- py::buffer_info(
- &y_vector[0],
- sizeof(NTYPE),
- py::format_descriptor::format(),
- y_shape.size(),
- y_shape, /* shape of the matrix */
- strides /* strides for each axis */
- ));
-}
-
-
-py::array_t custom_reducesum_rk_float(py::array_t x,
- int nthread) {
- return custom_reducesum_rk(x, nthread);
-}
-
-py::array_t custom_reducesum_rk_double(py::array_t x,
- int nthread) {
- return custom_reducesum_rk(x, nthread);
+ TensorShape sh1(3, shape1.p_dims);
+ Tensor v1(&sh1, t1.p_values);
}
-
-/////////////////
-// end: reducesum
-/////////////////
-
-
#ifndef SKIP_PYTHON
PYBIND11_MODULE(experimental_c, m) {
- m.doc() =
- #if defined(__APPLE__)
- "C++ experimental implementations."
- #else
- R"pbdoc(C++ experimental implementations.)pbdoc"
- #endif
- ;
+ m.doc() =
+#if defined(__APPLE__)
+ "C++ experimental implementations."
+#else
+ R"pbdoc(C++ experimental implementations.)pbdoc"
+#endif
+ ;
+
+ m.def("experimental_ut_reduce", &experimental_ut_reduce, R"pbdoc(C++ unit test for reduce)pbdoc");
+ m.def("experimental_ut_add", &experimental_ut_add, R"pbdoc(C++ unit test for add)pbdoc");
+ m.def("experimental_ut_einsum", &experimental_ut_einsum, R"pbdoc(C++ unit test for einsum)pbdoc");
+
+ m.def("BroadcastMatrixAddLeftInplaceInt64", &BroadcastMatrixAddLeftInplaceInt64,
+ R"pbdoc(Inplace addition, does X += Y. The function only allows broadcast in one way.)pbdoc");
+ m.def("BroadcastMatrixAddLeftInplaceFloat", &BroadcastMatrixAddLeftInplaceFloat,
+ R"pbdoc(Inplace addition, does X += Y. The function only allows broadcast in one way.)pbdoc");
+ m.def("BroadcastMatrixAddLeftInplaceDouble", &BroadcastMatrixAddLeftInplaceDouble,
+ R"pbdoc(Inplace addition, does X += Y. The function only allows broadcast in one way.)pbdoc");
m.def("code_optimisation", &code_optimisation,
- R"pbdoc(Returns a string giving some insights about optimisations.)pbdoc");
+ R"pbdoc(Returns a string giving some insights about optimisations.)pbdoc");
m.def("custom_einsum_float",
- &custom_einsum_float,
- py::arg("equation"), py::arg("x"), py::arg("y"), py::arg("nthread") = 0,
- R"pbdoc(Custom C++ implementation of operator *einsum* with float.
+ &custom_einsum_float,
+ py::arg("equation"), py::arg("x"), py::arg("y"), py::arg("nthread") = 0,
+ R"pbdoc(Custom C++ implementation of operator *einsum* with float.
The function only works with contiguous arrays.
It does not do any explicit transposes. It does not support
diagonal operator (repetition of the same letter).
@@ -632,9 +94,9 @@ See python's version :func:`custom_einsum
+template <typename T1, typename T2, typename TS, typename N>
+struct BroadcastIteratorRight {
+ typedef T1 type_value1;
+ typedef T2 type_value2;
+ typedef TS type_shape_value;
+ typedef N type_index;
+
+ const TensorShape* p_shape1;
+ const TensorShape* p_shape2;
+ T1* p1_;
+ const T2* p2_;
+
+ T1* p1_end;
+ TS* p_cum_shape2;
+
+ TS* p_index1_;
+ N last;
+
+ BroadcastIteratorRight(const TensorShape* shape1, const TensorShape* shape2, T1* p1, const T2* p2);
+ ~BroadcastIteratorRight();
+
+ bool end();
+ void next();
+};
+
+template
+void BroadcastMatrixAddLeftInplace(Tensor* X, const Tensor* Y);
+
+void experimental_ut_add();
diff --git a/mlprodict/testing/experimental_c_add.hpp b/mlprodict/testing/experimental_c_add.hpp
new file mode 100644
index 000000000..924f76430
--- /dev/null
+++ b/mlprodict/testing/experimental_c_add.hpp
@@ -0,0 +1,102 @@
+#pragma once
+
+// Inspired from
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_regressor.cc.
+
+#include "experimental_c_add.h"
+
+template
+BroadcastIteratorRight::BroadcastIteratorRight(
+ const TensorShape* shape1, const TensorShape* shape2, T1* p1, const T2* p2) {
+ p1_ = p1;
+ p2_ = p2;
+ p_shape1 = shape1;
+ p_shape2 = shape2;
+ if (!p_shape1->right_broadcast(p_shape2))
+        throw std::runtime_error("Cannot broadcast tensor 2 with tensor 1.");
+
+ last = p_shape1->n_dims;
+ p_cum_shape2 = new TS[last];
+ p_index1_ = new TS[last];
+ p1_end = p1_ + p_shape1->Size();
+
+ p_cum_shape2[last - 1] = 1;
+ for (N i = 1; i < last; ++i) {
+ p_index1_[i] = 0;
+ p_cum_shape2[last - i - 1] = p_cum_shape2[last - i] * (
+ last - i < p_shape2->n_dims ? p_shape2->p_dims[last - i] : 1);
+ }
+ --last;
+}
+
+template
+BroadcastIteratorRight::~BroadcastIteratorRight() {
+ delete[] p_cum_shape2;
+ delete[] p_index1_;
+}
+
+template
+bool BroadcastIteratorRight::end() {
+ return p1_ == p1_end;
+}
+
+template
+void BroadcastIteratorRight::next() {
+ ++p_index1_[last];
+ ++p1_;
+ if (last < p_shape2->n_dims && p_shape2->p_dims[last] != 1)
+ ++p2_;
+ N dim = last;
+ while (dim > 0 && p_index1_[dim] >= p_shape1->p_dims[dim]) {
+ p_index1_[dim] = 0;
+ if (dim < p_shape2->n_dims && p_shape2->p_dims[dim] != 1)
+ p2_ -= p_cum_shape2[dim] * p_shape2->p_dims[dim];
+ --dim;
+ ++p_index1_[dim];
+ if (dim < p_shape2->n_dims && p_shape2->p_dims[dim] != 1) {
+ p2_ += p_cum_shape2[dim];
+ }
+ }
+}
+
+template
+void BroadcastMatrixAddLeftInplace(Tensor* X, const Tensor* Y) {
+ BroadcastIteratorRight iter(X->p_shape, Y->p_shape, X->p_values, Y->p_values);
+ while (!iter.end()) {
+ *iter.p1_ += *iter.p2_;
+ iter.next();
+ }
+}
+
+#ifndef SKIP_PYTHON
+
+template
+void BroadcastMatrixAddLeftInplace(py::array_t& X,
+ py::array_t& Y) {
+ std::vector x_dims;
+ arrayshape2vector(x_dims, X);
+ std::vector y_dims;
+ arrayshape2vector(y_dims, Y);
+ TensorShape shape_x(x_dims.size(), x_dims.data());
+ TensorShape shape_y(y_dims.size(), y_dims.data());
+ Tensor vx(&shape_x, X.mutable_data());
+ Tensor vy(&shape_y, Y.mutable_data());
+ BroadcastMatrixAddLeftInplace(&vx, &vy);
+}
+
+void BroadcastMatrixAddLeftInplaceFloat(py::array_t X,
+ py::array_t Y) {
+ BroadcastMatrixAddLeftInplace(X, Y);
+}
+
+void BroadcastMatrixAddLeftInplaceDouble(py::array_t X,
+ py::array_t Y) {
+ BroadcastMatrixAddLeftInplace(X, Y);
+}
+
+void BroadcastMatrixAddLeftInplaceInt64(py::array_t X,
+ py::array_t Y) {
+ BroadcastMatrixAddLeftInplace(X, Y);
+}
+
+#endif
\ No newline at end of file
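Note on the broadcasting rule implemented by BroadcastIteratorRight: the constructor indexes Y's dimensions from the left and treats missing trailing dimensions as 1, so the shapes appear to be aligned on the leading axes rather than on the trailing axes as numpy does. The third case of test_custom_add.py relies on this. A hedged sketch of that behaviour, under the same import assumption as before:

    import numpy
    from mlprodict.testing.experimental_c import (  # pylint: disable=E0611
        BroadcastMatrixAddLeftInplaceDouble)

    X = numpy.ones((2, 3), dtype=numpy.float64)
    Y = numpy.array([1., 2.], dtype=numpy.float64)   # shape (2,)
    # Y is aligned on the leading axis of X, i.e. it behaves like a (2, 1) column.
    expected = X + Y.reshape((2, 1))
    BroadcastMatrixAddLeftInplaceDouble(X, Y)
    assert numpy.array_equal(X, expected)
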
diff --git a/mlprodict/testing/experimental_c_einsum.h b/mlprodict/testing/experimental_c_einsum.h
new file mode 100644
index 000000000..dc3882fec
--- /dev/null
+++ b/mlprodict/testing/experimental_c_einsum.h
@@ -0,0 +1,62 @@
+#pragma once
+
+// Inspired from
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_regressor.cc.
+
+#include "experimental_c_helper.h"
+
+#include
+#include
+#include
+
+
+template <typename TYPE>
+void _check_eq(const std::string& eq, const TYPE& sh);
+
+void _split(const std::string& eq, const mapshape_type& sh, mapshape_type& dx);
+
+void _split(const std::string& eq, const std::vector& sh, mapshape_type& dx);
+
+void _equation_split(const std::string& equation, std::string& eqx, std::string& eqy, std::string& eqr);
+
+void _interpret(
+ const mapshape_type& dx, const mapshape_type& dy, const std::string& eqr,
+ mapshape_type& shape, std::vector>& c_uni,
+ std::vector& c_trp, std::vector& c_sum);
+
+void _inc(const mapshape_type& d, mapshape_type& res);
+
+int64_t prod(const mapshape_type& seq);
+
+int64_t get_index(const std::vector& incs, const std::vector& index);
+
+void get_incs(const mapshape_type& cd, const mapshape_type& shape, std::vector& incs);
+
+void mapshape2shape(const mapshape_type& shape, std::vector& out_shape);
+
+void mapshape2shape(const mapshape_type& shape, std::vector& out_shape);
+
+template <typename NTYPE>
+NTYPE vector_dot_product_pointer16(const NTYPE* p1, const NTYPE* p2, size_t size);
+
+std::string code_optimisation();
+
+template <>
+float vector_dot_product_pointer16(const float* p1, const float* p2, size_t size);
+
+template <typename NTYPE>
+NTYPE vector_dot_product_pointer_stride(
+ const NTYPE* xp, const NTYPE* yp, size_t size,
+ int64_t inc_left, int64_t inc_right);
+
+void set_index(int64_t begin, const std::vector& shape_dims, std::vector& index);
+
+template <typename NTYPE>
+void custom_einsum_matmul(
+ const NTYPE* x_data, const NTYPE* y_data,
+ int64_t loop_size,
+ const mapshape_type& cdx, const mapshape_type& cdy, const mapshape_type& shape,
+ const std::vector& left_incs, const std::vector& right_incs,
+ NTYPE* z_data, int64_t begin, int64_t end, char col_sum);
+
+void experimental_ut_einsum();
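A worked example of the decomposition these helpers implement: for the equation "ij,jk->ik" (the one used in experimental_ut_einsum), _equation_split produces eqx="ij", eqy="jk", eqr="ik". _interpret then yields c_trp empty (no output letter occurs in both inputs), c_uni = [('i','#'), ('#','k')] for the letters owned by one side only, and c_sum = ['j'], the single summation letter that custom_einsum requires; the output shape takes the 'i' dimension from x and the 'k' dimension from y.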
diff --git a/mlprodict/testing/experimental_c_einsum.hpp b/mlprodict/testing/experimental_c_einsum.hpp
new file mode 100644
index 000000000..e2950d701
--- /dev/null
+++ b/mlprodict/testing/experimental_c_einsum.hpp
@@ -0,0 +1,423 @@
+#pragma once
+
+// Inspired from
+// https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_regressor.cc.
+
+#include "experimental_c_einsum.h"
+
+#include
+#include
+#include
+
+
+template <typename TYPE>
+void _check_eq(const std::string& eq, const TYPE& sh) {
+ if (eq.size() != sh.size())
+ throw std::runtime_error(MakeString(
+ "Unable to map equation ", eq, " to shape ", sh, "."));
+}
+
+void _split(const std::string& eq, const mapshape_type& sh, mapshape_type& dx) {
+ dx.clear();
+ for (size_t i = 0; i < sh.size(); ++i) {
+ dx.add(eq[i], mapshape_element(sh.at(eq[i]).first, i));
+ }
+}
+
+void _split(const std::string& eq, const std::vector& sh, mapshape_type& dx) {
+ dx.clear();
+ for (size_t i = 0; i < sh.size(); ++i) {
+ dx.add(eq[i], mapshape_element(sh[i], i));
+ }
+}
+
+void _equation_split(const std::string& equation, std::string& eqx, std::string& eqy, std::string& eqr) {
+ size_t comma = equation.find_first_of(",");
+ size_t dash = equation.find_first_of("-", comma);
+ eqx = equation.substr(0, comma);
+ eqy = equation.substr(comma + 1, dash - comma - 1);
+ eqr = equation.substr(dash + 2, equation.size() - dash - 2);
+}
+
+void _interpret(const mapshape_type& dx, const mapshape_type& dy, const std::string& eqr,
+ mapshape_type& shape, std::vector>& c_uni,
+ std::vector& c_trp, std::vector& c_sum) {
+ c_uni.clear();
+ c_trp.clear();
+ c_sum.clear();
+ c_uni.reserve(eqr.size());
+ c_trp.reserve(eqr.size());
+ c_sum.reserve(eqr.size());
+ for (char r : eqr) {
+ if (dx.has_key(r)) {
+ if (dy.has_key(r)) {
+ if (dx.at(r).first != dy.at(r).first)
+ throw std::runtime_error(MakeString(
+ "Dimension mismatch for letter ", r, " dx=", dx, " dy=", dy, "."));
+ c_trp.push_back(r);
+ }
+ else
+ c_uni.push_back(std::pair(r, '#'));
+ }
+ else if (dy.has_key(r))
+ c_uni.push_back(std::pair('#', r));
+ else
+ throw std::runtime_error(MakeString(
+ "Unexpected letter ", r, " in result ", eqr, "."));
+ }
+ for (size_t i = 0; i < dx.size(); ++i) {
+ char c = dx.key(i);
+ if (std::find(eqr.begin(), eqr.end(), c) == eqr.end()) {
+ if (!dy.has_key(c))
+ throw std::runtime_error(MakeString(
+ "Unable to guess what to do with column ", c, " (left side)."));
+ if (dx.at(c).first != dy.at(c).first)
+ throw std::runtime_error(MakeString(
+ "Dimension mismatch for letter ", c, " dx=", dx, " dy=", dy, "."));
+ c_sum.push_back(c);
+ }
+ }
+ for (size_t i = 0; i < dy.size(); ++i) {
+ char c = dy.key(i);
+ if (std::find(eqr.begin(), eqr.end(), c) == eqr.end() && !dx.has_key(c))
+ throw std::runtime_error(MakeString(
+ "Unable to guess what to do with column ", c, " (right side)."));
+ }
+ shape.clear();
+ for (size_t i = 0; i < eqr.size(); ++i) {
+ char r = eqr[i];
+ if (std::find(c_trp.begin(), c_trp.end(), r) != c_trp.end())
+ shape.add(r, mapshape_element(dx.at(r).first, i));
+ else {
+ for (auto p : c_uni) {
+ if (p.first == r) {
+ shape.add(r, mapshape_element(dx.at(r).first, i));
+ break;
+ }
+ if (p.second == r) {
+ shape.add(r, mapshape_element(dy.at(r).first, i));
+ break;
+ }
+ }
+ }
+ }
+ if (shape.size() != eqr.size())
+ throw std::runtime_error(MakeString(
+ "Unable to compute the output shape dx=", dx, "dy=", dy, " eqr=", eqr, " got shape=", shape, "."));
+}
+
+void _inc(const mapshape_type& d, mapshape_type& res) {
+ int64_t t = 1;
+ std::vector> temp;
+ temp.reserve(d.size());
+ for (int i = (int)d.size() - 1; i >= 0; --i) {
+ temp.push_back(std::pair(
+ d.key(i), mapshape_element(t, d.value(i).second)));
+ t *= d.value(i).first;
+ }
+ res.clear();
+ for (auto it = temp.rbegin(); it != temp.rend(); ++it)
+ res.add(it->first, it->second);
+}
+
+int64_t prod(const mapshape_type& seq) {
+ int64_t p = 1;
+ for (size_t i = 0; i < seq.size(); ++i)
+ p *= seq.value(i).first;
+ return p;
+}
+
+int64_t get_index(const std::vector& incs, const std::vector& index) {
+ int64_t ind = 0;
+ for (size_t i = 0; i < index.size(); ++i)
+ ind += incs[i] * index[i];
+ return ind;
+}
+
+void get_incs(const mapshape_type& cd, const mapshape_type& shape,
+ std::vector& incs) {
+ incs.clear();
+ incs.reserve(cd.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ incs.push_back(cd.has_key(shape.key(i)) ? cd.at(shape.key(i)).first : 0);
+}
+
+void mapshape2shape(const mapshape_type& shape, std::vector& out_shape) {
+ out_shape.clear();
+ out_shape.reserve(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ out_shape.push_back(shape.value(i).first);
+}
+
+void mapshape2shape(const mapshape_type& shape, std::vector& out_shape) {
+ out_shape.clear();
+ out_shape.reserve(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ out_shape.push_back(static_cast(shape.value(i).first));
+}
+
+template <typename NTYPE>
+NTYPE vector_dot_product_pointer16(const NTYPE* p1, const NTYPE* p2, size_t size) {
+ NTYPE sum = 0;
+ for (; size != 0; ++p1, ++p2, --size)
+ sum += *p1 * *p2;
+ return sum;
+}
+
+std::string code_optimisation() {
+#if USE_OPENMP
+ std::string omp = MakeString("omp=", omp_get_num_procs());
+#else
+ std::string omp = MakeString("th=", 1);
+#endif
+#if defined(_CMP_EQ_OQ) // defined in immintrin
+ return MakeString("AVX-", omp);
+#else
+ return MakeString("SSE-", omp);
+#endif
+}
+
+template <>
+float vector_dot_product_pointer16(const float* p1, const float* p2, size_t size) {
+ float sum = 0;
+#if defined(__AVX__)
+ if (size > 8) {
+ __m256 r256 = _mm256_setzero_ps();
+ for (; size > 8; p1 += 8, p2 += 8, size -= 8)
+ r256 = _mm256_add_ps(r256, _mm256_mul_ps(_mm256_load_ps(p1), _mm256_load_ps(p2)));
+ __m128 c1, c2, r1;
+ c1 = _mm256_extractf128_ps(r256, 1);
+ c2 = _mm256_extractf128_ps(r256, 0);
+ r1 = _mm_add_ps(c1, c2);
+ c1 = _mm_shuffle_ps(r1, r1, _MM_SHUFFLE(2, 3, 0, 1));
+ c2 = _mm_add_ps(r1, c1);
+ c1 = _mm_movehl_ps(c1, c2);
+ c2 = _mm_add_ss(c2, c1);
+ sum += _mm_cvtss_f32(c2);
+ }
+#else
+ if (size > 4) {
+ __m128 c1, c2;
+ __m128 r1 = _mm_setzero_ps();
+ for (; size > 4; p1 += 4, p2 += 4, size -= 4)
+ r1 = _mm_add_ps(r1, _mm_mul_ps(_mm_load_ps(p1), _mm_load_ps(p2)));
+ c1 = _mm_shuffle_ps(r1, r1, _MM_SHUFFLE(2, 3, 0, 1));
+ c2 = _mm_add_ps(r1, c1);
+ c1 = _mm_movehl_ps(c1, c2);
+ c2 = _mm_add_ss(c2, c1);
+ sum += _mm_cvtss_f32(c2);
+ }
+#endif
+ for (; size != 0; ++p1, ++p2, --size)
+ sum += *p1 * *p2;
+ return sum;
+}
+
+template <typename NTYPE>
+NTYPE vector_dot_product_pointer_stride(const NTYPE* xp, const NTYPE* yp, size_t size,
+ int64_t inc_left, int64_t inc_right) {
+ NTYPE sum = (NTYPE)0;
+ for (int64_t i_loop = size; i_loop != 0; xp += inc_left, yp += inc_right, --i_loop)
+ sum += *xp * *yp;
+ return sum;
+}
+
+void set_index(int64_t begin, const std::vector& shape_dims, std::vector& index) {
+ for (size_t i = shape_dims.size() - 1; i > 0; --i) {
+ index[i] = begin % shape_dims[i];
+ begin -= index[i];
+ begin /= shape_dims[i];
+ }
+ index[0] = begin;
+}
+
+template <typename NTYPE>
+void custom_einsum_matmul(
+ const NTYPE* x_data, const NTYPE* y_data,
+ int64_t loop_size,
+ const mapshape_type& cdx, const mapshape_type& cdy, const mapshape_type& shape,
+ const std::vector& left_incs, const std::vector& right_incs,
+ NTYPE* z_data, int64_t begin, int64_t end, char col_sum) {
+ const NTYPE* xp, * yp;
+ NTYPE* zp;
+ size_t pos;
+ NTYPE* z_end = z_data + end;
+ size_t len_index = shape.size();
+
+ std::vector shape_dims(len_index);
+ for (size_t i = 0; i < len_index; ++i)
+ shape_dims[i] = shape.value(i).first;
+
+ std::vector index(len_index);
+ int64_t i_left_loop, inc_left, i_right_loop, inc_right;
+ set_index(begin, shape_dims, index);
+ i_left_loop = get_index(left_incs, index);
+ i_right_loop = get_index(right_incs, index);
+ inc_left = cdx.at(col_sum).first;
+ inc_right = cdy.at(col_sum).first;
+
+ for (zp = z_data + begin; zp != z_end; ++zp) {
+ // summation
+ xp = x_data + i_left_loop;
+ yp = y_data + i_right_loop;
+
+ if (inc_left == 1 && inc_right == 1) {
+ *zp = vector_dot_product_pointer16(xp, yp, loop_size);
+ }
+ else {
+ *zp = vector_dot_product_pointer_stride(xp, yp, loop_size, inc_left, inc_right);
+ }
+
+ // increment
+ pos = len_index - 1;
+ ++index[pos];
+ i_left_loop += left_incs[pos];
+ i_right_loop += right_incs[pos];
+ while (pos > 0 && index[pos] >= shape_dims[pos]) {
+ i_left_loop -= left_incs[pos] * index[pos];
+ i_right_loop -= right_incs[pos] * index[pos];
+ index[pos] = 0;
+ --pos;
+ ++index[pos];
+ i_left_loop += left_incs[pos];
+ i_right_loop += right_incs[pos];
+ }
+ }
+}
+
+#ifndef SKIP_PYTHON
+
+template <typename NTYPE>
+py::array_t custom_einsum(
+ const std::string& equation,
+ py::array_t x,
+ py::array_t y,
+ int nthread) {
+
+ std::vector x_shape, y_shape;
+ arrayshape2vector(x_shape, x);
+ arrayshape2vector(y_shape, y);
+
+ const NTYPE* x_data = x.data();
+ const NTYPE* y_data = y.data();
+
+ std::string eqx, eqy, eqr;
+ _equation_split(equation, eqx, eqy, eqr);
+ _check_eq(eqx, x_shape);
+ _check_eq(eqy, y_shape);
+ mapshape_type dx, dy;
+ _split(eqx, x_shape, dx);
+ _split(eqy, y_shape, dy);
+
+ mapshape_type shape;
+ std::vector> c_uni;
+ std::vector c_trp, c_sum;
+ _interpret(dx, dy, eqr, shape, c_uni, c_trp, c_sum);
+
+ if (c_sum.size() != 1)
+ throw std::runtime_error(MakeString(
+ "More than one summation indices ", c_sum, " in equation ", equation, "."));
+
+ mapshape_type cdx, cdy;
+ _inc(dx, cdx);
+ _inc(dy, cdy);
+ int64_t full_size = prod(shape);
+
+ std::vector z_vector(full_size);
+ NTYPE* z_data = z_vector.data();
+
+ // loop
+ int64_t loop_size = dx.at(c_sum[0]).first;
+
+ std::vector left_incs, right_incs;
+ get_incs(cdx, shape, left_incs);
+ get_incs(cdy, shape, right_incs);
+
+#if USE_OPENMP
+ if (nthread == 1) {
+#endif
+ custom_einsum_matmul(x_data, y_data, loop_size,
+ cdx, cdy, shape,
+ left_incs, right_incs, z_data,
+ 0 /*begin*/, full_size /*end*/,
+ c_sum[0]);
+#if USE_OPENMP
+ }
+ else {
+ if (nthread > 1)
+ omp_set_num_threads(nthread);
+ else
+ nthread = omp_get_num_procs();
+ int N = nthread * 4;
+ int64_t h = full_size / N;
+ if (h == 0) {
+ h = full_size;
+ N = 1;
+ }
+
+#pragma omp parallel for
+ for (int i = 0; i < N; ++i) {
+ int64_t begin = h * i;
+ int64_t end = (i == N - 1) ? full_size : begin + h;
+ custom_einsum_matmul(x_data, y_data, loop_size,
+ cdx, cdy, shape,
+ left_incs, right_incs, z_data,
+ begin /*begin*/, end /*end*/,
+ c_sum[0]);
+ }
+ }
+#endif
+
+ std::vector z_shape;
+ std::vector strides;
+
+ mapshape2shape(shape, z_shape);
+ shape2strides(z_shape, strides, (NTYPE)0);
+
+ return py::array_t(
+ py::buffer_info(
+ &z_vector[0],
+ sizeof(NTYPE),
+ py::format_descriptor::format(),
+ z_shape.size(),
+ z_shape, /* shape of the matrix */
+ strides /* strides for each axis */
+ ));
+}
+
+py::array_t custom_einsum_float(
+ const std::string& equation,
+ py::array_t x,
+ py::array_t y,
+ int nthread) {
+ return custom_einsum(equation, x, y, nthread);
+}
+
+py::array_t custom_einsum_double(
+ const std::string& equation,
+ py::array_t x,
+ py::array_t y,
+ int nthread) {
+ return custom_einsum(equation, x, y, nthread);
+}
+
+py::array_t custom_einsum_int64(
+ const std::string& equation,
+ py::array_t x,
+ py::array_t y,
+ int nthread) {
+ return custom_einsum(equation, x, y, nthread);
+}
+
+
+py::array_t custom_einsum_int32(
+ const std::string& equation,
+ py::array_t x,
+ py::array_t y,
+ int nthread) {
+ return custom_einsum(equation, x, y, nthread);
+}
+
+#endif
+
+void experimental_ut_einsum();
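Since the pybind declarations above only hint at the calling convention, here is a minimal usage sketch for the float variant, assuming the module is importable as mlprodict.testing.experimental_c (the path used by the new unit test). The equation must use the explicit "xy,yz->xz" form, contain exactly one summation letter and no repeated letter, and the inputs must be contiguous:

    import numpy
    from mlprodict.testing.experimental_c import (  # pylint: disable=E0611
        custom_einsum_float)

    x = numpy.arange(12, dtype=numpy.float32).reshape((3, 4))
    y = numpy.arange(20, dtype=numpy.float32).reshape((4, 5))
    # 'j' is the single summation index; both arrays are C-contiguous.
    got = custom_einsum_float("ij,jk->ik", x, y, nthread=1)
    expected = numpy.einsum("ij,jk->ik", x, y)
    assert numpy.allclose(got, expected)
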
diff --git a/mlprodict/testing/experimental_c_helper.h b/mlprodict/testing/experimental_c_helper.h
new file mode 100644
index 000000000..9a00332ec
--- /dev/null
+++ b/mlprodict/testing/experimental_c_helper.h
@@ -0,0 +1,172 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include <iostream>  // cout
+#include
+#include
+#include
+#include