diff --git a/HISTORY.rst b/HISTORY.rst
index 143998c49..42f0e2b5b 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -5,16 +5,18 @@
History
=======
-current - 2020-02-15 - 0.00Mb
+current - 2020-02-19 - 0.00Mb
=============================
+* `93`: Use pointer for TreeClassifier (2020-02-19)
+* `99`: Fixes #93, use same code for TreeEnsembleClassifier and TreeEnsembleRegression (2020-02-19)
* `98`: mlprodict i broken after onnxruntime, skl2onnx update (2020-02-15)
* `97`: Add runtime for operator Conv (2020-01-24)
* `96`: Fixes #97, add runtime for operator Conv (2020-01-24)
* `95`: Fix OnnxInference where an output and an operator share the same name (2020-01-15)
* `94`: Raw scores are always positive for TreeEnsembleClassifier (binary) (2020-01-13)
-* `86`: Use pointers to replace treeindex in tree ensemble cpp runtime (2019-12-17)
* `90`: Implements a C++ runtime for topk (2019-12-17)
+* `86`: Use pointers to replace treeindex in tree ensemble cpp runtime (2019-12-17)
* `92`: Implements a C++ version of ArrayFeatureExtractor (2019-12-14)
* `89`: Implements a function which extracts some informations on the models (2019-12-14)
* `88`: Fix bug in runtime of GatherElements (2019-12-14)
diff --git a/_doc/notebooks/onnx_tree_ensemble_parallel.ipynb b/_doc/notebooks/onnx_tree_ensemble_parallel.ipynb
new file mode 100644
index 000000000..55f68e29e
--- /dev/null
+++ b/_doc/notebooks/onnx_tree_ensemble_parallel.ipynb
@@ -0,0 +1,918 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TreeEnsembleRegressor and parallelisation\n",
+ "\n",
+ "The operator [TreeEnsembleRegressor](https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md#ai.onnx.ml.TreeEnsembleRegressor) describes any tree model (decision tree, random forest, gradient boosting). The runtime is usually implemented in C/C++ and uses parallelisation. This notebook studies the impact of that parallelisation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "run previous cell, wait for 2 seconds\n",
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from jyquickhelper import add_notebook_menu\n",
+ "add_notebook_menu()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Graph\n",
+ "\n",
+ "The following graph shows the time ratio between two runtimes depending on the number of observations in a batch (N) and the number of trees in the forest."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import numpy\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "from matplotlib.colors import LogNorm\n",
+ "\n",
+ "\n",
+ "def heatmap(data, row_labels, col_labels, ax=None,\n",
+ " cbar_kw={}, cbarlabel=\"\", **kwargs):\n",
+ " \"\"\"\n",
+ " Create a heatmap from a numpy array and two lists of labels.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " data\n",
+ " A 2D numpy array of shape (N, M).\n",
+ " row_labels\n",
+ " A list or array of length N with the labels for the rows.\n",
+ " col_labels\n",
+ " A list or array of length M with the labels for the columns.\n",
+ " ax\n",
+ " A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If\n",
+ " not provided, use current axes or create a new one. Optional.\n",
+ " cbar_kw\n",
+ " A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.\n",
+ " cbarlabel\n",
+ " The label for the colorbar. Optional.\n",
+ " **kwargs\n",
+ " All other arguments are forwarded to `imshow`.\n",
+ " \"\"\"\n",
+ "\n",
+ " if not ax:\n",
+ " ax = plt.gca()\n",
+ "\n",
+ " # Plot the heatmap\n",
+ " im = ax.imshow(data, **kwargs)\n",
+ "\n",
+ " # Create colorbar\n",
+ " cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)\n",
+ " cbar.ax.set_ylabel(cbarlabel, rotation=-90, va=\"bottom\")\n",
+ "\n",
+ " # We want to show all ticks...\n",
+ " ax.set_xticks(numpy.arange(data.shape[1]))\n",
+ " ax.set_yticks(numpy.arange(data.shape[0]))\n",
+ " # ... and label them with the respective list entries.\n",
+ " ax.set_xticklabels(col_labels)\n",
+ " ax.set_yticklabels(row_labels)\n",
+ "\n",
+ " # Let the horizontal axes labeling appear on top.\n",
+ " ax.tick_params(top=True, bottom=False,\n",
+ " labeltop=True, labelbottom=False)\n",
+ "\n",
+ " # Rotate the tick labels and set their alignment.\n",
+ " plt.setp(ax.get_xticklabels(), rotation=-30, ha=\"right\",\n",
+ " rotation_mode=\"anchor\")\n",
+ "\n",
+ " # Turn spines off and create white grid.\n",
+ " for edge, spine in ax.spines.items():\n",
+ " spine.set_visible(False)\n",
+ "\n",
+ " ax.set_xticks(numpy.arange(data.shape[1]+1)-.5, minor=True)\n",
+ " ax.set_yticks(numpy.arange(data.shape[0]+1)-.5, minor=True)\n",
+ " ax.grid(which=\"minor\", color=\"w\", linestyle='-', linewidth=3)\n",
+ " ax.tick_params(which=\"minor\", bottom=False, left=False)\n",
+ "\n",
+ " return im, cbar\n",
+ "\n",
+ "\n",
+ "def annotate_heatmap(im, data=None, valfmt=\"{x:.2f}\",\n",
+ " textcolors=[\"black\", \"white\"],\n",
+ " threshold=None, **textkw):\n",
+ " \"\"\"\n",
+ " A function to annotate a heatmap.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " im\n",
+ " The AxesImage to be labeled.\n",
+ " data\n",
+ " Data used to annotate. If None, the image's data is used. Optional.\n",
+ " valfmt\n",
+ " The format of the annotations inside the heatmap. This should either\n",
+ " use the string format method, e.g. \"$ {x:.2f}\", or be a\n",
+ " `matplotlib.ticker.Formatter`. Optional.\n",
+ " textcolors\n",
+ " A list or array of two color specifications. The first is used for\n",
+ " values below a threshold, the second for those above. Optional.\n",
+ " threshold\n",
+ " Value in data units according to which the colors from textcolors are\n",
+ " applied. If None (the default) uses the middle of the colormap as\n",
+ " separation. Optional.\n",
+ " **kwargs\n",
+ " All other arguments are forwarded to each call to `text` used to create\n",
+ " the text labels.\n",
+ " \"\"\"\n",
+ "\n",
+ " if not isinstance(data, (list, numpy.ndarray)):\n",
+ " data = im.get_array()\n",
+ "\n",
+ " # Normalize the threshold to the images color range.\n",
+ " if threshold is not None:\n",
+ " threshold = im.norm(threshold)\n",
+ " else:\n",
+ " threshold = im.norm(data.max())/2.\n",
+ "\n",
+ " kw = dict(horizontalalignment=\"center\",\n",
+ " verticalalignment=\"center\")\n",
+ " kw.update(textkw)\n",
+ "\n",
+ " # Get the formatter in case a string is supplied\n",
+ " if isinstance(valfmt, str):\n",
+ " valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)\n",
+ "\n",
+ " texts = []\n",
+ " for i in range(data.shape[0]):\n",
+ " for j in range(data.shape[1]):\n",
+ " kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])\n",
+ " text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)\n",
+ " texts.append(text)\n",
+ "\n",
+ " return texts\n",
+ "\n",
+ "def plot_metric(metric, ax, xlabel=\"N\", ylabel=\"trees\"):\n",
+ " x = numpy.array(list(sorted(set(k[0] for k in metric))))\n",
+ " y = numpy.array(list(sorted(set(k[1] for k in metric)))) \n",
+ " rx = {v: i for i, v in enumerate(x)}\n",
+ " ry = {v: i for i, v in enumerate(y)}\n",
+ "\n",
+ " X, Y = numpy.meshgrid(x, y)\n",
+ " zm = numpy.zeros(X.shape, dtype=numpy.float64)\n",
+ " for k, v in metric.items():\n",
+ " zm[ry[k[1]], rx[k[0]]] = v\n",
+ "\n",
+ " xs = [str(_) for _ in x]\n",
+ " ys = [str(_) for _ in y]\n",
+ " vmin = min(metric.values())\n",
+ " vmax = max(metric.values())\n",
+ " im, cbar = heatmap(zm, ys, xs, ax=ax, cmap=\"bwr\", cbarlabel=\"ratio\",\n",
+ " norm=LogNorm(vmin=vmin, vmax=vmax))\n",
+ " texts = annotate_heatmap(im, valfmt=\"{x:.2f}x\")\n",
+ " ax.set_xlabel(xlabel)\n",
+ " ax.set_ylabel(ylabel)\n",
+ "\n",
+ " \n",
+ "data = {(1, 1): 0.1, (10, 1): 1, (1, 10): 2, (10, 10): 100, (100, 1): 100, (100, 10): 1000}\n",
+ "\n",
+ "fig, ax = plt.subplots()\n",
+ "plot_metric(data, ax)\n",
+ "fig.tight_layout();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## scikit-learn: T trees vs 1 tree\n",
+ "\n",
+ "Let's first compare a random forest from *scikit-learn* with 1 tree against random forests with multiple trees."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.datasets import make_regression\n",
+ "ntest = 1000\n",
+ "X, y = make_regression(n_samples=10000 + ntest, n_features=10, n_informative=5,\n",
+ " n_targets=1, random_state=11)\n",
+ "X_train, X_test, y_train, y_test = X[:-ntest], X[-ntest:], y[:-ntest], y[-ntest:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5/5 [00:02<00:00, 1.33it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.ensemble import RandomForestRegressor\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "N = [1, 10, 100, 1000, 10000]\n",
+ "T = [1, 2, 10, 20, 50]\n",
+ "\n",
+ "models = {}\n",
+ "for nt in tqdm(T):\n",
+ " rf = RandomForestRegressor(n_estimators=nt, max_depth=5).fit(X_train, y_train)\n",
+ " models[nt] = rf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5/5 [01:20<00:00, 19.35s/it]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[((1, 1), 0.9760457855545297),\n",
+ " ((10, 1), 0.9147578680294024),\n",
+ " ((100, 1), 0.7756775134240294)]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sklearn\n",
+ "from mlprodict.onnxrt.validate.validate_benchmark import benchmark_fct\n",
+ "\n",
+ "def benchmark(X, fct1, fct2, N, repeat=10, number=20):\n",
+ " \n",
+ " def ti(r, n):\n",
+ " if n <= 1:\n",
+ " return 40 * r\n",
+ " if n <= 10:\n",
+ " return 10 * r\n",
+ " if n <= 100:\n",
+ " return 4 * r\n",
+ " if n <= 1000:\n",
+ " return r\n",
+ " return r // 2\n",
+ " \n",
+ " with sklearn.config_context(assume_finite=True):\n",
+ " # to warm up the engine\n",
+ " time_kwargs = {n: dict(repeat=10, number=10) for n in N}\n",
+ " benchmark_fct(fct1, X, time_kwargs=time_kwargs, skip_long_test=False)\n",
+ " benchmark_fct(fct2, X, time_kwargs=time_kwargs, skip_long_test=False)\n",
+ " # real measure\n",
+ " time_kwargs = {n: dict(repeat=ti(repeat, n), number=number) for n in N}\n",
+ " res1 = benchmark_fct(fct1, X, time_kwargs=time_kwargs, skip_long_test=False)\n",
+ " res2 = benchmark_fct(fct2, X, time_kwargs=time_kwargs, skip_long_test=False)\n",
+ " res = {}\n",
+ " for r in sorted(res1):\n",
+ " r1 = res1[r]\n",
+ " r2 = res2[r]\n",
+ " ratio = r2['total'] / r1['total']\n",
+ " res[r] = ratio\n",
+ " return res\n",
+ "\n",
+ "\n",
+ "def tree_benchmark(X, fct1, fct2, T, N, repeat=20, number=10):\n",
+ " bench = {}\n",
+ " for t in tqdm(T):\n",
+ " r = benchmark(X, fct1(t), fct2(t), N, repeat=repeat, number=number)\n",
+ " for n, v in r.items():\n",
+ " bench[n, t] = v\n",
+ " return bench\n",
+ "\n",
+ "bench = tree_benchmark(X_test.astype(numpy.float32),\n",
+ " lambda t: models[1].predict,\n",
+ " lambda t: models[t].predict, T, N)\n",
+ "\n",
+ "list(bench.items())[:3]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, ax = plt.subplots()\n",
+ "plot_metric(bench, ax)\n",
+ "fig.tight_layout();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As expected, all ratios on the first line are close to 1 since both models are the same. The fourth line, second column (T=20, N=10) means that a random forest with 20 trees is around 5 times slower at computing the predictions for a batch of 10 observations than a random forest with 1 tree."
+ ]
+ },
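+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the reading of the heatmap explicit, the plotted ratio can be written as follows. The notation $t_1$, $t_2$ and $r$ is introduced here for illustration only: $t_1(N)$ is the total time measured for the forest with one tree and $t_2(N, T)$ the total time measured for the forest with $T$ trees, both on a batch of $N$ observations.\n",
+ "\n",
+ "$$r(N, T) = \\frac{t_2(N, T)}{t_1(N)}$$\n",
+ "\n",
+ "A value $r(N, T) > 1$ means the forest with $T$ trees is slower than the single tree, and $r(N, 1) \\approx 1$ because both functions then call the same model."
+ ]
+ },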
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## scikit-learn against onnxruntime"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from skl2onnx import to_onnx\n",
+ "X32 = X_test.astype(numpy.float32)\n",
+ "models_onnx = {t: to_onnx(m, X32[:1]) for t, m in models.items()}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from onnxruntime import InferenceSession\n",
+ "sess_models = {t: InferenceSession(mo.SerializeToString()) for t, mo in models_onnx.items()}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5/5 [01:27<00:00, 21.10s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "bench_ort = tree_benchmark(X_test.astype(numpy.float32),\n",
+ " lambda t: models[t].predict,\n",
+ " lambda t: (lambda x, t_=t, se=sess_models: se[t_].run(None, {'X': x})),\n",
+ " T, N)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, ax = plt.subplots()\n",
+ "plot_metric(bench_ort, ax)\n",
+ "ax.set_title(\"scikit-learn vs onnxruntime\\n < 1 means onnxruntime is faster\")\n",
+ "fig.tight_layout();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'1.1.997'"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from onnxruntime import __version__ as ort_version\n",
+ "ort_version"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This version of onnxruntime is clearly faster than scikit-learn for small batches. It is still faster for big batches, but by a smaller margin."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Implementation details for mlprodict runtime\n",
+ "\n",
+ "The runtime implemented in [mlprodict]() mostly relies on two files:\n",
+ "* [op_tree_ensemble_common_p_agg_.hpp](https://github.com/sdpython/mlprodict/blob/master/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_common_p_agg_.hpp)\n",
+ "* [op_tree_ensemble_common_p_.hpp](https://github.com/sdpython/mlprodict/blob/master/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_common_p_.hpp)\n",
+ "\n",
+ "The runtime builds a tree structure, computes the output of every tree and then aggregates them. The implementation distinguishes between a batch that contains a single observation and a batch that contains many. It parallelizes under the following conditions:\n",
+ "* if the batch size $N \\geqslant N_0$, it parallelizes per observation, assuming every observation is independent,\n",
+ "* if the batch size $N = 1$ and the number of trees $T \\geqslant T_0$, it parallelizes per tree."
+ ]
+ },
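+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The two conditions above can be summarised as a single rule. This is only a sketch of the description, not a transcription of the C++ code: it assumes $N_0 > 1$ so that the cases do not overlap, and that every other situation runs sequentially. The attributes `omp_N_` and `omp_tree_` set later in this notebook appear to play the role of $N_0$ and $T_0$, although this is an inference rather than something stated explicitly.\n",
+ "\n",
+ "$$\\text{strategy}(N, T) = \\begin{cases} \\text{parallel per observation} & \\text{if } N \\geqslant N_0 \\\\ \\text{parallel per tree} & \\text{if } N = 1 \\text{ and } T \\geqslant T_0 \\\\ \\text{sequential} & \\text{otherwise} \\end{cases}$$"
+ ]
+ },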
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## scikit-learn against mlprodict, no parallelisation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mlprodict.onnxrt import OnnxInference\n",
+ "oinf_models = {t: OnnxInference(mo, runtime=\"python_compiled\") for t, mo in models_onnx.items()}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's disable the parallelisation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for _, oinf in oinf_models.items():\n",
+ " oinf.sequence_[0].ops_.rt_.omp_tree_ = 10000000\n",
+ " oinf.sequence_[0].ops_.rt_.omp_N_ = 10000000"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5/5 [01:14<00:00, 18.38s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "bench_mlp = tree_benchmark(X_test.astype(numpy.float32),\n",
+ " lambda t: models[t].predict,\n",
+ " lambda t: (lambda x, t_=t, oi=oinf_models: oi[t_].run({'X': x})),\n",
+ " T, N)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, ax = plt.subplots()\n",
+ "plot_metric(bench_mlp, ax)\n",
+ "ax.set_title(\"scikit-learn vs mlprodict\\n < 1 means mlprodict is faster\")\n",
+ "fig.tight_layout();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'0.3'"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from mlprodict import __version__ as mlp_version\n",
+ "mlp_version"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This implementation seems to be faster. Let's now enable parallelisation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for _, oinf in oinf_models.items():\n",
+ " oinf.sequence_[0].ops_.rt_.omp_tree_ = 2\n",
+ " oinf.sequence_[0].ops_.rt_.omp_N_ = 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5/5 [01:09<00:00, 17.04s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "bench_mlp_para = tree_benchmark(X_test.astype(numpy.float32),\n",
+ " lambda t: models[t].predict,\n",
+ " lambda t: (lambda x, t_=t, oi=oinf_models: oi[t_].run({'X': x})),\n",
+ " T, N)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, ax = plt.subplots()\n",
+ "plot_metric(bench_mlp_para, ax)\n",
+ "ax.set_title(\"scikit-learn vs mlprodict\\n < 1 means mlprodict is faster\\nparallelisation\")\n",
+ "fig.tight_layout();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Parallelisation does improve the computation time when N is big. Let's compare the runs with and without parallelisation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bench_para = {}\n",
+ "for k, v in bench_mlp.items():\n",
+ " bench_para[k] = bench_mlp_para[k] / v"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, ax = plt.subplots()\n",
+ "plot_metric(bench_para, ax)\n",
+ "ax.set_title(\"mlprodict vs mlprodict parallelized\\n < 1 means parallelisation is faster\")\n",
+ "fig.tight_layout();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Parallelisation per tree does not seem to be efficient. Let's confirm with a proper benchmark."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5/5 [00:10<00:00, 2.56s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "for _, oinf in oinf_models.items():\n",
+ " oinf.sequence_[0].ops_.rt_.omp_tree_ = 1000000\n",
+ " oinf.sequence_[0].ops_.rt_.omp_N_ = 1000000\n",
+ "\n",
+ "oinf_models_para = {t: OnnxInference(mo, runtime=\"python_compiled\") for t, mo in models_onnx.items()}\n",
+ "for _, oinf in oinf_models_para.items():\n",
+ " oinf.sequence_[0].ops_.rt_.omp_tree_ = 2\n",
+ " oinf.sequence_[0].ops_.rt_.omp_N_ = 2\n",
+ "\n",
+ "bench_mlp_para = tree_benchmark(X_test.astype(numpy.float32),\n",
+ " lambda t: (lambda x, t_=t, oi=oinf_models: oi[t_].run({'X': x})),\n",
+ " lambda t: (lambda x, t_=t, oi=oinf_models_para: oi[t_].run({'X': x})),\n",
+ " T, N)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, ax = plt.subplots()\n",
+ "plot_metric(bench_mlp_para, ax)\n",
+ "ax.set_title(\"mlprodict vs mlprodict parallelized\\n < 1 means parallelisation is faster\\nsame baseline\")\n",
+ "fig.tight_layout();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This benchmark should be run on different machines. On the current one, parallelisation per tree (when N=1) does not seem to help. Parallelisation for a small number of observations does not seem to help either. So we need to find suitable thresholds."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
\ No newline at end of file