diff --git a/_doc/notebooks/onnx_fft.ipynb b/_doc/notebooks/onnx_fft.ipynb
new file mode 100644
index 000000000..6efb475e2
--- /dev/null
+++ b/_doc/notebooks/onnx_fft.ipynb
@@ -0,0 +1,822 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "51bc89fc",
+ "metadata": {},
+ "source": [
+ "# ONNX and FFT\n",
+ "\n",
+ "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "7b2add97",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from jyquickhelper import add_notebook_menu\n",
+ "add_notebook_menu()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "acfdc3b0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext mlprodict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "abb5fa88",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'1.21.0'"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy\n",
+ "numpy.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2e4f68e4",
+ "metadata": {},
+ "source": [
+ "## Python implementation of RFFT\n",
+ "\n",
+ "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "cb1cc910",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[ 0.92935219+0.j , 1.1166406 +0.18610885j,\n",
+ " 2.98881347-0.86137828j, 0.57062752-3.17075076j],\n",
+ " [-0.81071034+0.j , 4.04571912+1.34415298j,\n",
+ " -0.75316593+1.87375117j, -3.73972034+1.19963451j],\n",
+ " [ 0.49893169+0.j , -2.38853745+0.91784964j,\n",
+ " -2.3230939 +2.42467461j, 2.84973582+0.96874118j],\n",
+ " [-0.85518897+0.j , -1.07457921+2.14618057j,\n",
+ " 0.67522719-2.17320735j, 1.31480887+2.2782433j ],\n",
+ " [ 2.80867666+0.j , -2.79453396-2.22901834j,\n",
+ " 0.492986 +0.10661537j, 2.65317564+0.57651319j]])"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy\n",
+ "\n",
+ "\n",
+ "def almost_equal(a, b, error=1e-5):\n",
+ " \"\"\"\n",
+ " The function compares two matrices, one may be complex. In that case,\n",
+ " this matrix is changed into a new matrix with a new first dimension,\n",
+ " [0,::] means real part, [1,::] means imaginary part.\n",
+ " \"\"\"\n",
+ " if a.dtype in (numpy.complex64, numpy.complex128):\n",
+ " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n",
+ " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n",
+ " new_a[0] = numpy.real(a)\n",
+ " new_a[1] = numpy.imag(a)\n",
+ " return almost_equal(new_a, b, error)\n",
+ " if b.dtype in (numpy.complex64, numpy.complex128):\n",
+ " return almost_equal(b, a, error)\n",
+ " if a.shape != b.shape:\n",
+ " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n",
+ " diff = numpy.abs(a.ravel() - b.ravel()).max()\n",
+ " if diff > error:\n",
+ " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n",
+ "\n",
+ "\n",
+ "def dft_real_cst(N, fft_length):\n",
+ " n = numpy.arange(N)\n",
+ " k = n.reshape((N, 1)).astype(numpy.float64)\n",
+ " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n",
+ " both = numpy.empty((2,) + M.shape)\n",
+ " both[0, :, :] = numpy.real(M)\n",
+ " both[1, :, :] = numpy.imag(M)\n",
+ " return both\n",
+ "\n",
+ "\n",
+ "def dft_real(x, fft_length=None, transpose=True):\n",
+ " if len(x.shape) == 1:\n",
+ " x = x.reshape((1, -1))\n",
+ " N = 1\n",
+ " else:\n",
+ " N = x.shape[0] \n",
+ " C = x.shape[-1] if transpose else x.shape[-2]\n",
+ " if fft_length is None:\n",
+ " fft_length = x.shape[-1]\n",
+ " size = fft_length // 2 + 1\n",
+ "\n",
+ " cst = dft_real_cst(C, fft_length)\n",
+ " if transpose:\n",
+ " x = numpy.transpose(x, (1, 0))\n",
+ " res = numpy.dot(cst[:, :, :fft_length], x[:fft_length])[:, :size, :]\n",
+ " return numpy.transpose(res, (0, 2, 1))\n",
+ " else:\n",
+ " return numpy.dot(cst[:, :, :fft_length], x[:fft_length])\n",
+ "\n",
+ "\n",
+ "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n",
+ "fft_np = numpy.fft.rfft(rnd)\n",
+ "fft_cus = dft_real(rnd)\n",
+ "fft_np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0c052ea1",
+ "metadata": {},
+ "source": [
+ "Function `almost_equal` verifies both functions return the same results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "3ca040cb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "almost_equal(fft_np, fft_cus)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7fe77440",
+ "metadata": {},
+ "source": [
+ "Let's do the same with `fft_length < shape[1]`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "3a747a4a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[ 0.58212829+0.j , 1.91211772-1.78320393j],\n",
+ " [-0.3185378 +0.j , -0.20609781-1.18129868j],\n",
+ " [-0.81120646+0.j , -0.28543806+3.05769342j],\n",
+ " [-1.06384408+0.j , 0.74100591+0.43276681j],\n",
+ " [ 1.77509081+0.j , -0.13498855+1.82011058j]])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fft_np3 = numpy.fft.rfft(rnd, n=3)\n",
+ "fft_cus3 = dft_real(rnd, fft_length=3)\n",
+ "fft_np3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "0db6247b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "almost_equal(fft_np3, fft_cus3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "31a6ac9c",
+ "metadata": {},
+ "source": [
+ "## RFFT in ONNX\n",
+ "\n",
+ "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "efb67b9b",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[[ 0.92935216, 1.1166406 , 2.9888134 , 0.5706275 ],\n",
+ " [-0.81071043, 4.045719 , -0.7531659 , -3.7397203 ],\n",
+ " [ 0.4989317 , -2.3885374 , -2.323094 , 2.849736 ],\n",
+ " [-0.85518885, -1.0745792 , 0.6752271 , 1.3148088 ],\n",
+ " [ 2.8086765 , -2.794534 , 0.49298596, 2.6531756 ]],\n",
+ "\n",
+ " [[ 0. , 0.18610872, -0.8613782 , -3.1707506 ],\n",
+ " [ 0. , 1.3441529 , 1.8737512 , 1.1996344 ],\n",
+ " [ 0. , 0.9178499 , 2.4246747 , 0.96874106],\n",
+ " [ 0. , 2.1461806 , -2.1732073 , 2.2782433 ],\n",
+ " [ 0. , -2.2290184 , 0.10661539, 0.5765133 ]]],\n",
+ " dtype=float32)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from typing import Any\n",
+ "import mlprodict.npy.numpy_onnx_impl as npnx\n",
+ "from mlprodict.npy import onnxnumpy_np\n",
+ "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n",
+ "# from mlprodict.onnxrt import OnnxInference\n",
+ "\n",
+ "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n",
+ "def onnx_rfft(x, fft_length=None):\n",
+ " if fft_length is None:\n",
+ " raise RuntimeError(\"fft_length must be specified.\")\n",
+ " \n",
+ " size = fft_length // 2 + 1\n",
+ " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n",
+ " xt = npnx.transpose(x, (1, 0))\n",
+ " res = npnx.dot(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n",
+ " return npnx.transpose(res, (0, 2, 1))\n",
+ "\n",
+ "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n",
+ "fft_onx"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "c4b6b1a5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "almost_equal(fft_cus, fft_onx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a8c35327",
+ "metadata": {},
+ "source": [
+ "The corresponding ONNX graph is the following:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "4d1a85b0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "key = list(onnx_rfft.signed_compiled)[0]\n",
+ "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "6cf18aca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n",
+ "almost_equal(fft_cus3, fft_onx3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6b466fd4",
+ "metadata": {},
+ "source": [
+ "## FFT 2D\n",
+ "\n",
+ "Below the code for complex features."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "e0020084",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[-2.036582 +0.j , -0.85992725+6.47780438j,\n",
+ " -3.99332006-3.11192536j, -1.32368431-2.48821071j],\n",
+ " [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,\n",
+ " -3.75306979+0.66651699j, -1.56716114+4.75028368j],\n",
+ " [ 2.76767016+1.25297955j, 5.07926144-2.23393831j,\n",
+ " 2.41908275-8.55451105j, -8.84556476-1.29356088j],\n",
+ " [ 2.76767016-1.25297955j, 2.41782872-4.44962381j,\n",
+ " -3.6501426 +4.13120322j, 4.30875103+0.96179243j],\n",
+ " [ 4.37345155+7.03173815j, 1.77135529+4.4385736j ,\n",
+ " 2.40878105+5.40109054j, -1.65462983+0.2149866j ]])"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def _DFT_cst(N, fft_length, trunc=True):\n",
+ " n = numpy.arange(N)\n",
+ " k = n.reshape((N, 1)).astype(numpy.float64)\n",
+ " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n",
+ " return M[:fft_length // 2 + 1] if trunc else M\n",
+ "\n",
+ "def DFT(x, fft_length=None, axis=1):\n",
+ " if axis == 1:\n",
+ " x = x.T\n",
+ " if fft_length is None:\n",
+ " fft_length = x.shape[0]\n",
+ " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n",
+ " if axis == 1:\n",
+ " return numpy.dot(cst, x).T\n",
+ " return numpy.dot(cst, x)\n",
+ "\n",
+ "def fft2d_(mat, fft_length):\n",
+ " mat = mat[:fft_length[0], :fft_length[1]]\n",
+ " res = mat.copy()\n",
+ " res = DFT(res, fft_length[1], axis=1)\n",
+ " res = DFT(res, fft_length[0], axis=0)\n",
+ " return res[:fft_length[0], :fft_length[1]//2 + 1]\n",
+ "\n",
+ "\n",
+ "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n",
+ "fft2d_np_ = fft2d_(rnd, rnd.shape)\n",
+ "fft2d_np = numpy.fft.rfft2(rnd)\n",
+ "fft2d_np_"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "777d2775",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "almost_equal(fft2d_np_, fft2d_np)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cfbbe2fd",
+ "metadata": {},
+ "source": [
+ "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n",
+ "\n",
+ "* $y = FFT(x, axis=1)$\n",
+ "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n",
+ "* $z = z_r + i z_i$\n",
+ "\n",
+ "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does like that."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "dd4fc711",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def fft2d(mat, fft_length):\n",
+ " mat = mat[:fft_length[0], :fft_length[1]]\n",
+ " res = mat.copy()\n",
+ " \n",
+ " # first FFT\n",
+ " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n",
+ " \n",
+ " # second FFT decomposed on FFT on real part and imaginary part\n",
+ " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n",
+ " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n",
+ " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n",
+ " res = res2_real + res2_imag2\n",
+ " size = fft_length[1]//2 + 1\n",
+ " return res[:, :fft_length[0], :size]\n",
+ "\n",
+ "\n",
+ "fft2d_np = numpy.fft.rfft2(rnd)\n",
+ "fft2d_cus = fft2d(rnd, rnd.shape)\n",
+ "almost_equal(fft2d_np, fft2d_cus)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "bb8667e6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[-2.036582 +0.j , -0.85992725+6.47780438j,\n",
+ " -3.99332006-3.11192536j, -1.32368431-2.48821071j],\n",
+ " [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,\n",
+ " -3.75306979+0.66651699j, -1.56716114+4.75028368j],\n",
+ " [ 2.76767016+1.25297955j, 5.07926144-2.23393831j,\n",
+ " 2.41908275-8.55451105j, -8.84556476-1.29356088j],\n",
+ " [ 2.76767016-1.25297955j, 2.41782872-4.44962381j,\n",
+ " -3.6501426 +4.13120322j, 4.30875103+0.96179243j],\n",
+ " [ 4.37345155+7.03173815j, 1.77135529+4.4385736j ,\n",
+ " 2.40878105+5.40109054j, -1.65462983+0.2149866j ]])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fft2d_np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "56a94d97",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[[-2.036582 , -0.85992725, -3.99332006, -1.32368431],\n",
+ " [ 4.37345155, -3.14890126, -3.75306979, -1.56716114],\n",
+ " [ 2.76767016, 5.07926144, 2.41908275, -8.84556476],\n",
+ " [ 2.76767016, 2.41782872, -3.6501426 , 4.30875103],\n",
+ " [ 4.37345155, 1.77135529, 2.40878105, -1.65462983]],\n",
+ "\n",
+ " [[ 0. , 6.47780438, -3.11192536, -2.48821071],\n",
+ " [-7.03173815, 1.59632335, 0.66651699, 4.75028368],\n",
+ " [ 1.25297955, -2.23393831, -8.55451105, -1.29356088],\n",
+ " [-1.25297955, -4.44962381, 4.13120322, 0.96179243],\n",
+ " [ 7.03173815, 4.4385736 , 5.40109054, 0.2149866 ]]])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fft2d_cus"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "faa21909",
+ "metadata": {},
+ "source": [
+ "And with a different `fft_length`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "bf98995f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n",
+ "fft2d_cus = fft2d(rnd, (4, 6))\n",
+ "almost_equal(fft2d_np[:4, :], fft2d_cus)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "caee1f84",
+ "metadata": {},
+ "source": [
+ "## FFT 2D in ONNX\n",
+ "\n",
+ "We use again the numpy API for ONNX."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "ca641274",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n",
+ " if fft_length is None:\n",
+ " raise RuntimeError(\"fft_length must be specified.\")\n",
+ " \n",
+ " size = fft_length // 2 + 1\n",
+ " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n",
+ " if transpose:\n",
+ " xt = npnx.transpose(x, (1, 0))\n",
+ " res = npnx.dot(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n",
+ " return npnx.transpose(res, (0, 2, 1))\n",
+ " else:\n",
+ " return npnx.dot(cst[:, :, :fft_length], x[:fft_length])\n",
+ "\n",
+ "\n",
+ "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n",
+ "def onnx_rfft_2d(x, fft_length=None):\n",
+ " mat = x[:fft_length[0], :fft_length[1]]\n",
+ " \n",
+ " # first FFT\n",
+ " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n",
+ " \n",
+ " # second FFT decomposed on FFT on real part and imaginary part\n",
+ " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n",
+ " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n",
+ " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n",
+ " res = res2_real + res2_imag2\n",
+ " size = fft_length[1]//2 + 1\n",
+ " return res[:, :fft_length[0], :size]\n",
+ "\n",
+ "\n",
+ "fft2d_cus = fft2d(rnd, rnd.shape)\n",
+ "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n",
+ "almost_equal(fft2d_cus, fft2d_onx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20fcd8a9",
+ "metadata": {},
+ "source": [
+ "The corresponding ONNX graph."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "b1379b06",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "key = list(onnx_rfft_2d.signed_compiled)[0]\n",
+ "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3a747f0c",
+ "metadata": {},
+ "source": [
+ "With a different `fft_length`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "16732cbb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fft2d_cus = fft2d(rnd, (4, 5))\n",
+ "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n",
+ "almost_equal(fft2d_cus, fft2d_onx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "04924e7d",
+ "metadata": {},
+ "source": [
+ "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "faeff9cd",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/_unittests/ut_documentation/test_run_notebooks_fft.py b/_unittests/ut_documentation/test_run_notebooks_fft.py
new file mode 100644
index 000000000..8525d3bdd
--- /dev/null
+++ b/_unittests/ut_documentation/test_run_notebooks_fft.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""
+@brief test log(time=7s)
+"""
+import os
+import unittest
+from pyquickhelper.loghelper import fLOG
+from pyquickhelper.ipythonhelper import test_notebook_execution_coverage
+from pyquickhelper.pycode import (
+ add_missing_development_version, ExtTestCase)
+import mlprodict
+
+
+class TestNotebookFFT(ExtTestCase):
+
+ def setUp(self):
+ add_missing_development_version(["jyquickhelper"], __file__, hide=True)
+
+ def test_notebook_fft(self):
+ fLOG(
+ __file__,
+ self._testMethodName,
+ OutputPrint=__name__ == "__main__")
+
+ self.assertNotEmpty(mlprodict is not None)
+ folder = os.path.join(os.path.dirname(__file__),
+ "..", "..", "_doc", "notebooks")
+ test_notebook_execution_coverage(__file__, "fft", folder,
+ this_module_name="mlprodict", fLOG=fLOG)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/_unittests/ut_npy/test_numpy_onnx_pyrt.py b/_unittests/ut_npy/test_numpy_onnx_pyrt.py
index b4ed94fc8..40685d1a2 100644
--- a/_unittests/ut_npy/test_numpy_onnx_pyrt.py
+++ b/_unittests/ut_npy/test_numpy_onnx_pyrt.py
@@ -401,6 +401,14 @@ def test_tanh_float32(self):
doc = nxnpy.tanh.__doc__
self.assertIn('tanh', doc)
+ def test_transpose_float32(self):
+ np_tr = lambda x, perm=None: numpy.transpose(x, perm)
+ x = numpy.array([[0.5, 0.1], [-0.5, -0.1], [1, 1]],
+ dtype=numpy.float32)
+ self.common_test1(
+ x, np_tr, nxnpy.transpose, # pylint: disable=E1101
+ FctVersion((numpy.float32, ), ((1, 0), )), perm=(1, 0))
+
def test_unsqueeze_float32(self):
x = numpy.array([[0.5, 0.1], [-0.5, -0.1]], dtype=numpy.float32)
axes = numpy.array([0], dtype=numpy.int64)
@@ -425,5 +433,5 @@ def test_where_float32(self):
if __name__ == "__main__":
- TestNumpyOnnxFunction().test_where_float32()
+ # TestNumpyOnnxFunction().test_where_float32()
unittest.main()
diff --git a/_unittests/ut_npy/test_onnx_variable.py b/_unittests/ut_npy/test_onnx_variable.py
index b69118ad5..17b02702c 100644
--- a/_unittests/ut_npy/test_onnx_variable.py
+++ b/_unittests/ut_npy/test_onnx_variable.py
@@ -1,736 +1,775 @@
-# -*- coding: utf-8 -*-
-"""
-@brief test log(time=3s)
-"""
-import unittest
-from typing import Any
-import numpy
-from pyquickhelper.pycode import ExtTestCase, ignore_warnings
-from mlprodict.npy import onnxnumpy, onnxnumpy_default, onnxnumpy_np
-import mlprodict.npy.numpy_onnx_impl as nxnp
-from mlprodict.npy import (
- OnnxNumpyCompiler as ONC, NDArray, NDArraySameTypeSameShape)
-
-
-@ignore_warnings(DeprecationWarning)
-def get_bool(unused):
- try:
- return numpy.bool_
- except AttributeError:
- return bool
-
-
-numpy_bool = get_bool(None)
-
-
-@onnxnumpy_default
-def test_abs(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy abs"
- return nxnp.abs(x)
-
-
-@onnxnumpy_default
-def test_abs_abs(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy abs abs"
- return nxnp.abs(nxnp.abs(x))
-
-
-@onnxnumpy_default
-def test_abs_add(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy addition"
- return nxnp.abs(x) + x
-
-
-@onnxnumpy_default
-def test_abs_add4(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy addition"
- x2 = x * x
- return x2 * x2
-
-
-@onnxnumpy_default
-def test_abs_addm(x1: NDArray[Any, numpy.float32],
- x2: NDArray[Any, numpy.float32]
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy addition"
- return nxnp.abs(x1) + x2
-
-
-@onnxnumpy_default
-def test_abs_add2(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy addition"
- return nxnp.abs(x) + numpy.float32(2)
-
-
-@onnxnumpy_default
-def test_abs_sub(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy addition"
- return nxnp.abs(x) - x
-
-
-@onnxnumpy_default
-def test_abs_mul(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy addition"
- return nxnp.abs(x) * x
-
-
-@onnxnumpy_default
-def test_abs_pow(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy power"
- return nxnp.abs(x) ** numpy.float32(2)
-
-
-@onnxnumpy_default
-def test_abs_mod(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy modulo"
- return nxnp.abs(x) % numpy.float32(2)
-
-
-@onnxnumpy_default
-def test_abs_matmul(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy addition"
- return nxnp.abs(x) @ x
-
-
-@onnxnumpy_default
-def test_abs_div(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy division"
- return nxnp.abs(x) / x
-
-
-@onnxnumpy_default
-def test_abs_idiv(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.int64]:
- "onnx numpy int division"
- return nxnp.abs(x).astype(numpy.int64) // x.astype(numpy.int64)
-
-
-@onnxnumpy_default
-def test_abs_equal(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy equality"
- return nxnp.abs(x) == x
-
-
-@onnxnumpy_default
-def test_abs_not_equal(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy inequality"
- return nxnp.abs(x) != x
-
-
-@onnxnumpy_default
-def test_abs_greater(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy greater"
- return nxnp.abs(x) > x
-
-
-@onnxnumpy_default
-def test_abs_greater_or_equal(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy greater or equal"
- return nxnp.abs(x) >= x
-
-
-@onnxnumpy_default
-def test_abs_less(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy less"
- return nxnp.abs(x) < x
-
-
-@onnxnumpy_default
-def test_abs_less_or_equal(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy less or equal"
- return nxnp.abs(x) <= x
-
-
-@onnxnumpy_default
-def test_abs_and(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy and"
- return (nxnp.abs(x) < x) and (nxnp.abs(x) < numpy.float32(0))
-
-
-@onnxnumpy_default
-def test_abs_and2(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy and"
- return (nxnp.abs(x) < x) & (nxnp.abs(x) < numpy.float32(0))
-
-
-@onnxnumpy_default
-def test_abs_or(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy or"
- return (nxnp.abs(x) < x) or (nxnp.abs(x) < numpy.float32(0))
-
-
-@onnxnumpy_default
-def test_abs_or2(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy_bool]:
- "onnx numpy or"
- return (nxnp.abs(x) < x) | (nxnp.abs(x) < numpy.float32(0))
-
-
-@onnxnumpy_default
-def test_abs_sum1(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy sum"
- return nxnp.sum(nxnp.abs(x), axis=0)
-
-
-@onnxnumpy_default
-def test_abs_sum2(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy sum"
- return nxnp.sum(nxnp.abs(x), axis=1, keepdims=1)
-
-
-@onnxnumpy_default
-def test_abs_transpose_t(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy transpose T"
- return nxnp.abs(x).T
-
-
-@onnxnumpy_default
-def test_abs_cast(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.int64]:
- "onnx numpy cast"
- return nxnp.abs(x).astype(numpy.int64)
-
-
-@onnxnumpy_default
-def test_abs_reshape(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy reshape"
- return nxnp.abs(x).reshape((-1, 1))
-
-
-@onnxnumpy(op_version=11)
-def test_abs_reshape_11(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy reshape with opset 11"
- return nxnp.abs(x).reshape((-1, 1))
-
-
-@onnxnumpy_default
-def test_abs_slice(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy slice 1"
- return nxnp.abs(x)[:, 1]
-
-
-@onnxnumpy_default
-def test_abs_slice2(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy slice 2"
- return nxnp.abs(x)[:1, 1]
-
-
-@onnxnumpy_default
-def test_abs_slice23(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy slice 23"
- return nxnp.abs(x)[::2, ::3]
-
-
-@onnxnumpy_default
-def test_abs_neg(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy neg"
- return - nxnp.abs(x)
-
-
-@onnxnumpy_default
-def test_abs_not(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.bool_]:
- "onnx numpy not"
- temp = nxnp.abs(x) > numpy.float32(0)
- return temp.not_()
-
-
-@onnxnumpy_default
-def test_abs_filter(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy filter"
- return nxnp.abs(x)[x[:, 0] > numpy.float32(15)]
-
-
-@onnxnumpy_default
-def test_log(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy log"
- return nxnp.log(x)
-
-
-@onnxnumpy_np(signature=NDArraySameTypeSameShape("floats"))
-def test_abs_log_multi(x):
- "onnx numpy log multiple type"
- return nxnp.log(nxnp.abs(x))
-
-
-@onnxnumpy_np(signature=NDArraySameTypeSameShape("floats"))
-def test_abs_log_multi_dtype(x):
- "onnx numpy log multiple type"
- return nxnp.log(nxnp.abs(x) + x.dtype(1))
-
-
-@onnxnumpy_default
-def test_abs_shape(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.int64]:
- "onnx numpy shape"
- return nxnp.abs(x).shape
-
-
-@onnxnumpy_default
-def test_abs_size(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.int64]:
- "onnx numpy size"
- return nxnp.abs(x).size
-
-
-@onnxnumpy_default
-def test_abs_flatten(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.int64]:
- "onnx numpy flatten"
- return nxnp.abs(x).flatten()
-
-
-@onnxnumpy_default
-def test_abs_flatten2(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.int64]:
- "onnx numpy flatten"
- return nxnp.abs(x).flatten(axis=1)
-
-
-@onnxnumpy_default
-def test_abs_set1a(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- temp = nxnp.abs(x).copy()
- temp[2] = numpy.float32(-1.5)
- return temp
-
-
-@onnxnumpy_default
-def test_abs_set1b(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- temp = nxnp.abs(x).copy()
- temp[:4] = numpy.float32(-1.5)
- return temp
-
-
-@onnxnumpy_default
-def test_abs_set1c(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- temp = nxnp.abs(x).copy()
- temp[:4:2] = numpy.float32(-1.5)
- return temp
-
-
-@onnxnumpy_default
-def test_abs_set1d(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- temp = nxnp.abs(x).copy()
- temp[:4:2] = numpy.array([-1.5, -1.6], dtype=numpy.float32)
- return temp
-
-
-@onnxnumpy_default
-def test_abs_set1e(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- temp = nxnp.abs(x).copy()
- temp[2:] = numpy.float32(-1.5)
- return temp
-
-
-@onnxnumpy_default
-def test_abs_set1f(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- temp = nxnp.abs(x).copy()
- temp[3:5] = numpy.float32(-1.5)
- return temp
-
-
-@onnxnumpy_default
-def test_abs_set1g(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- temp = nxnp.abs(x).copy()
- temp[3:] = numpy.array([-1.5] * 4, dtype=numpy.float32)
- return temp
-
-
-@onnxnumpy_default
-def test_abs_set1h(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- cp = x.copy()
- cp[x < numpy.float32(0)] = numpy.array([-1], dtype=numpy.float32)
- return cp
-
-
-@onnxnumpy_default
-def test_abs_set1i(x: NDArray[Any, numpy.float32],
- ) -> NDArray[Any, numpy.float32]:
- "onnx numpy set"
- cp = x.copy()
- z = x < numpy.float32(0)
- cp[z] = -x
- return cp
-
-
-class TestOnnxVariable(ExtTestCase):
-
- def test_py_abs(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs(x)
- self.assertEqualArray(y, numpy.abs(x))
- self.assertEqual(test_abs.__doc__, "onnx numpy abs")
- self.assertTrue(hasattr(test_abs, 'compiled'))
- self.assertIsInstance(test_abs.compiled, ONC)
- rep = repr(test_abs.compiled)
- self.assertStartsWith("OnnxNumpyCompiler(", rep)
-
- def test_py_abs_add(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_add(x)
- self.assertEqualArray(y, numpy.abs(x) + x)
-
- def test_py_abs_addm(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_addm(x, x)
- self.assertEqualArray(y, numpy.abs(x) + x)
-
- def test_py_abs_add_cst(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_add2(x)
- self.assertEqualArray(y, numpy.abs(x) + 2)
-
- def test_py_abs_add4(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_add4(x)
- text = str(test_abs_add4.compiled.onnx_).split('op_type: "Mul"')
- self.assertEqual(len(text), 3)
- self.assertEqualArray(y, (x * x) * (x * x))
-
- def test_py_abs_sub(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_sub(x)
- self.assertEqualArray(y, numpy.abs(x) - x)
-
- def test_py_abs_mul(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_mul(x)
- self.assertEqualArray(y, numpy.abs(x) * x)
-
- def test_py_abs_mod(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_mod(x)
- self.assertEqualArray(y, numpy.abs(x) % 2)
-
- def test_py_abs_pox(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_pow(x)
- self.assertEqualArray(y, numpy.abs(x) ** 2)
-
- def test_py_abs_matmul(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_matmul(x)
- self.assertEqualArray(y, numpy.abs(x) @ x)
-
- def test_py_abs_div(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_div(x)
- self.assertEqualArray(y, numpy.abs(x) / x)
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.int64)
- y = test_abs_div(x)
- self.assertEqualArray(y, numpy.abs(x) / x)
-
- def test_py_abs_idiv(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_idiv(x)
- self.assertEqualArray(y, numpy.abs(x) // x)
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.int64)
- y = test_abs_idiv(x)
- self.assertEqualArray(y, numpy.abs(x) // x)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_equal(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_equal(x)
- self.assertEqualArray(y, numpy.abs(x) == x)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_not_equal(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_not_equal(x)
- self.assertEqualArray(y, numpy.abs(x) != x)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_greater(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_greater(x)
- self.assertEqualArray(y, numpy.abs(x) > x)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_greater_or_equal(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_greater_or_equal(x)
- self.assertEqualArray(y, numpy.abs(x) >= x)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_less(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_less(x)
- self.assertEqualArray(y, numpy.abs(x) < x)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_less_or_equal(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_less_or_equal(x)
- self.assertEqualArray(y, numpy.abs(x) <= x)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_and(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_and(x)
- self.assertEqualArray(
- y, (numpy.abs(x) < x) & (numpy.abs(x) < 0))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_and2(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_and2(x)
- self.assertEqualArray(
- y, (numpy.abs(x) < x) & (numpy.abs(x) < 0))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_or(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_or(x)
- self.assertEqualArray(
- y, (numpy.abs(x) < x) | (numpy.abs(x) < 0))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_or2(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_or2(x)
- self.assertEqualArray(
- y, (numpy.abs(x) < x) | (numpy.abs(x) < 0))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_sum1(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_sum1(x)
- self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=0))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_sum2(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_sum2(x)
- self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=1, keepdims=1))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_transpose_t(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_transpose_t(x)
- self.assertEqualArray(y, numpy.abs(x).T)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_cast(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_cast(x)
- self.assertEqualArray(y, numpy.abs(x).astype(numpy.int64))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_reshape(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_reshape(x)
- self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1)))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_reshape_11(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_reshape(x)
- self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1)))
- compiled = test_abs_reshape.compiled
- self.assertNotIn("version: 11", str(compiled.onnx_))
- y = test_abs_reshape_11(x)
- self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1)))
- compiled = test_abs_reshape_11.compiled
- self.assertIn("version: 11", str(compiled.onnx_))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_slice(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_slice(x)
- self.assertEqualArray(y, numpy.abs(x)[:, 1])
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_slice23(self):
- x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
- y = test_abs_slice23(x)
- self.assertEqualArray(y, numpy.abs(x)[::2, ::3])
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_neg(self):
- x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
- y = test_abs_neg(x)
- self.assertEqualArray(y, -numpy.abs(x))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_not(self):
- x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
- y = test_abs_not(x)
- self.assertEqualArray(y, numpy.abs(x) <= 0)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_filter(self):
- x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
- y = test_abs_filter(x)
- self.assertEqualArray(y, numpy.abs(x)[x[:, 0] > 15])
-
- @ignore_warnings(DeprecationWarning)
- def test_py_log(self):
- x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32)
- y = test_log(x)
- self.assertEqualArray(y, numpy.log(x))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_log_multi(self):
- x = numpy.array([[6.1, -5], [-3.5, 7.8]], dtype=numpy.float32)
- y = test_abs_log_multi(x)
- self.assertEqualArray(y, numpy.log(numpy.abs(x)))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_log_multi_dtype(self):
- x = numpy.array([[6.1, -5], [-3.5, 7.8]], dtype=numpy.float32)
- y = test_abs_log_multi_dtype(x)
- self.assertEqualArray(y, numpy.log(numpy.abs(x) + 1))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_shape(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_shape(x)
- self.assertEqualArray(y, numpy.abs(x).shape)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_size(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_size(x)
- self.assertEqualArray(y, numpy.abs(x).size)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_flatten(self):
- x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
- y = test_abs_flatten(x)
- self.assertEqualArray(y, numpy.abs(x).flatten())
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_flatten2(self):
- x = numpy.array([[[6.11, -51], [3.51, -7.81]],
- [[6.1, -5], [3.5, -7.8]]], dtype=numpy.float32)
- y = test_abs_flatten2(x)
- self.assertEqualArray(y, numpy.abs(x).flatten().reshape((2, -1)))
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1a(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
- y = test_abs_set1a(x)
- temp = numpy.abs(x)
- temp[2] = -1.5
- self.assertEqualArray(y, temp)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1b(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
- y = test_abs_set1b(x)
- temp = numpy.abs(x)
- temp[:4] = -1.5
- self.assertEqualArray(y, temp)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1c(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
- y = test_abs_set1c(x)
- temp = numpy.abs(x)
- temp[:4:2] = -1.5
- self.assertEqualArray(y, temp)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1d(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
- y = test_abs_set1d(x)
- temp = numpy.abs(x)
- temp[:4:2] = [-1.5, -1.6]
- self.assertEqualArray(y, temp)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1e(self):
- self.assertIn('op_type: "Shape"', str(test_abs_set1e.compiled.onnx_))
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6., -7.],
- dtype=numpy.float32)
- y = test_abs_set1e(x)
- temp = numpy.abs(x)
- temp[2:] = -1.5
- self.assertEqualArray(y, temp)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1f(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
- dtype=numpy.float32)
- y = test_abs_set1f(x)
- temp = numpy.abs(x)
- temp[3:5] = -1.5
- self.assertEqualArray(y, temp)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1g(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
- dtype=numpy.float32)
- y = test_abs_set1g(x)
- temp = numpy.abs(x)
- temp[3:] = -1.5
- self.assertEqualArray(y, temp)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1h(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
- dtype=numpy.float32)
- y = test_abs_set1h(x)
- temp = x.copy()
- temp[x < 0] = -1
- self.assertEqualArray(temp, y)
-
- @ignore_warnings(DeprecationWarning)
- def test_py_abs_set1i(self):
- x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
- dtype=numpy.float32)
- y = test_abs_set1i(x)
- temp = numpy.abs(x)
- self.assertEqualArray(temp, y)
-
-
-if __name__ == "__main__":
- unittest.main()
+# -*- coding: utf-8 -*-
+"""
+@brief test log(time=3s)
+"""
+import unittest
+from typing import Any
+import numpy
+from pyquickhelper.pycode import ExtTestCase, ignore_warnings
+from mlprodict.npy import onnxnumpy, onnxnumpy_default, onnxnumpy_np
+import mlprodict.npy.numpy_onnx_impl as nxnp
+from mlprodict.npy import (
+ OnnxNumpyCompiler as ONC, NDArray, NDArraySameTypeSameShape)
+
+
+@ignore_warnings(DeprecationWarning)
+def get_bool(unused):
+ try:
+ return numpy.bool_
+ except AttributeError:
+ return bool
+
+
+numpy_bool = get_bool(None)
+
+
+@onnxnumpy_default
+def test_abs(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy abs"
+ return nxnp.abs(x)
+
+
+@onnxnumpy_default
+def test_abs_abs(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy abs abs"
+ return nxnp.abs(nxnp.abs(x))
+
+
+@onnxnumpy_default
+def test_abs_add(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy addition"
+ return nxnp.abs(x) + x
+
+
+@onnxnumpy_default
+def test_abs_add4(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy addition"
+ x2 = x * x
+ return x2 * x2
+
+
+@onnxnumpy_default
+def test_abs_addm(x1: NDArray[Any, numpy.float32],
+ x2: NDArray[Any, numpy.float32]
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy addition"
+ return nxnp.abs(x1) + x2
+
+
+@onnxnumpy_default
+def test_abs_add2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy addition"
+ return nxnp.abs(x) + numpy.float32(2)
+
+
+@onnxnumpy_default
+def test_abs_sub(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy addition"
+ return nxnp.abs(x) - x
+
+
+@onnxnumpy_default
+def test_abs_mul(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy addition"
+ return nxnp.abs(x) * x
+
+
+@onnxnumpy_default
+def test_abs_pow(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy power"
+ return nxnp.abs(x) ** numpy.float32(2)
+
+
+@onnxnumpy_default
+def test_abs_mod(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy modulo"
+ return nxnp.abs(x) % numpy.float32(2)
+
+
+@onnxnumpy_default
+def test_abs_matmul(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy addition"
+ return nxnp.abs(x) @ x
+
+
+@onnxnumpy_default
+def test_abs_div(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy division"
+ return nxnp.abs(x) / x
+
+
+@onnxnumpy_default
+def test_abs_idiv(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.int64]:
+ "onnx numpy int division"
+ return nxnp.abs(x).astype(numpy.int64) // x.astype(numpy.int64)
+
+
+@onnxnumpy_default
+def test_abs_equal(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy equality"
+ return nxnp.abs(x) == x
+
+
+@onnxnumpy_default
+def test_abs_not_equal(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy inequality"
+ return nxnp.abs(x) != x
+
+
+@onnxnumpy_default
+def test_abs_greater(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy greater"
+ return nxnp.abs(x) > x
+
+
+@onnxnumpy_default
+def test_abs_greater_or_equal(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy greater or equal"
+ return nxnp.abs(x) >= x
+
+
+@onnxnumpy_default
+def test_abs_less(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy less"
+ return nxnp.abs(x) < x
+
+
+@onnxnumpy_default
+def test_abs_less_or_equal(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy less or equal"
+ return nxnp.abs(x) <= x
+
+
+@onnxnumpy_default
+def test_abs_and(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy and"
+ return (nxnp.abs(x) < x) and (nxnp.abs(x) < numpy.float32(0))
+
+
+@onnxnumpy_default
+def test_abs_and2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy and"
+ return (nxnp.abs(x) < x) & (nxnp.abs(x) < numpy.float32(0))
+
+
+@onnxnumpy_default
+def test_abs_or(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy or"
+ return (nxnp.abs(x) < x) or (nxnp.abs(x) < numpy.float32(0))
+
+
+@onnxnumpy_default
+def test_abs_or2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy_bool]:
+ "onnx numpy or"
+ return (nxnp.abs(x) < x) | (nxnp.abs(x) < numpy.float32(0))
+
+
+@onnxnumpy_default
+def test_abs_sum1(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy sum"
+ return nxnp.sum(nxnp.abs(x), axis=0)
+
+
+@onnxnumpy_default
+def test_abs_sum2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy sum"
+ return nxnp.sum(nxnp.abs(x), axis=1, keepdims=1)
+
+
+@onnxnumpy_default
+def test_abs_transpose_t(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy transpose T"
+ return nxnp.abs(x).T
+
+
+@onnxnumpy_default
+def test_abs_cast(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.int64]:
+ "onnx numpy cast"
+ return nxnp.abs(x).astype(numpy.int64)
+
+
+@onnxnumpy_default
+def test_abs_reshape(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy reshape"
+ return nxnp.abs(x).reshape((-1, 1))
+
+
+@onnxnumpy(op_version=11)
+def test_abs_reshape_11(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy reshape with opset 11"
+ return nxnp.abs(x).reshape((-1, 1))
+
+
+@onnxnumpy_default
+def test_abs_slice(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy slice 1"
+ return nxnp.abs(x)[:, 1]
+
+
+@onnxnumpy_default
+def test_abs_slice2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy slice 2"
+ return nxnp.abs(x)[:1, 1]
+
+
+@onnxnumpy_default
+def test_abs_slice23(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy slice 23"
+ return nxnp.abs(x)[::2, ::3]
+
+
+@onnxnumpy_default
+def test_abs_slice_end(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy slice end"
+ return nxnp.abs(x)[1:, :3]
+
+
+@onnxnumpy_default
+def test_abs_gather(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy gather"
+ return nxnp.abs(x)[1]
+
+
+@onnxnumpy_default
+def test_abs_gather2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy gather"
+ return nxnp.abs(x)[:, 1]
+
+
+@onnxnumpy_default
+def test_abs_neg(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy neg"
+ return - nxnp.abs(x)
+
+
+@onnxnumpy_default
+def test_abs_not(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.bool_]:
+ "onnx numpy not"
+ temp = nxnp.abs(x) > numpy.float32(0)
+ return temp.not_()
+
+
+@onnxnumpy_default
+def test_abs_filter(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy filter"
+ return nxnp.abs(x)[x[:, 0] > numpy.float32(15)]
+
+
+@onnxnumpy_default
+def test_log(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy log"
+ return nxnp.log(x)
+
+
+@onnxnumpy_np(signature=NDArraySameTypeSameShape("floats"))
+def test_abs_log_multi(x):
+ "onnx numpy log multiple type"
+ return nxnp.log(nxnp.abs(x))
+
+
+@onnxnumpy_np(signature=NDArraySameTypeSameShape("floats"))
+def test_abs_log_multi_dtype(x):
+ "onnx numpy log multiple type"
+ return nxnp.log(nxnp.abs(x) + x.dtype(1))
+
+
+@onnxnumpy_default
+def test_abs_shape(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.int64]:
+ "onnx numpy shape"
+ return nxnp.abs(x).shape
+
+
+@onnxnumpy_default
+def test_abs_size(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.int64]:
+ "onnx numpy size"
+ return nxnp.abs(x).size
+
+
+@onnxnumpy_default
+def test_abs_flatten(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.int64]:
+ "onnx numpy flatten"
+ return nxnp.abs(x).flatten()
+
+
+@onnxnumpy_default
+def test_abs_flatten2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.int64]:
+ "onnx numpy flatten"
+ return nxnp.abs(x).flatten(axis=1)
+
+
+@onnxnumpy_default
+def test_abs_set1a(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ temp = nxnp.abs(x).copy()
+ temp[2] = numpy.float32(-1.5)
+ return temp
+
+
+@onnxnumpy_default
+def test_abs_set1b(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ temp = nxnp.abs(x).copy()
+ temp[:4] = numpy.float32(-1.5)
+ return temp
+
+
+@onnxnumpy_default
+def test_abs_set1c(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ temp = nxnp.abs(x).copy()
+ temp[:4:2] = numpy.float32(-1.5)
+ return temp
+
+
+@onnxnumpy_default
+def test_abs_set1d(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ temp = nxnp.abs(x).copy()
+ temp[:4:2] = numpy.array([-1.5, -1.6], dtype=numpy.float32)
+ return temp
+
+
+@onnxnumpy_default
+def test_abs_set1e(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ temp = nxnp.abs(x).copy()
+ temp[2:] = numpy.float32(-1.5)
+ return temp
+
+
+@onnxnumpy_default
+def test_abs_set1f(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ temp = nxnp.abs(x).copy()
+ temp[3:5] = numpy.float32(-1.5)
+ return temp
+
+
+@onnxnumpy_default
+def test_abs_set1g(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ temp = nxnp.abs(x).copy()
+ temp[3:] = numpy.array([-1.5] * 4, dtype=numpy.float32)
+ return temp
+
+
+@onnxnumpy_default
+def test_abs_set1h(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ cp = x.copy()
+ cp[x < numpy.float32(0)] = numpy.array([-1], dtype=numpy.float32)
+ return cp
+
+
+@onnxnumpy_default
+def test_abs_set1i(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy set"
+ cp = x.copy()
+ z = x < numpy.float32(0)
+ cp[z] = -x
+ return cp
+
+
+class TestOnnxVariable(ExtTestCase):
+
+ def test_py_abs(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs(x)
+ self.assertEqualArray(y, numpy.abs(x))
+ self.assertEqual(test_abs.__doc__, "onnx numpy abs")
+ self.assertTrue(hasattr(test_abs, 'compiled'))
+ self.assertIsInstance(test_abs.compiled, ONC)
+ rep = repr(test_abs.compiled)
+ self.assertStartsWith("OnnxNumpyCompiler(", rep)
+
+ def test_py_abs_add(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_add(x)
+ self.assertEqualArray(y, numpy.abs(x) + x)
+
+ def test_py_abs_addm(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_addm(x, x)
+ self.assertEqualArray(y, numpy.abs(x) + x)
+
+ def test_py_abs_add_cst(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_add2(x)
+ self.assertEqualArray(y, numpy.abs(x) + 2)
+
+ def test_py_abs_add4(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_add4(x)
+ text = str(test_abs_add4.compiled.onnx_).split('op_type: "Mul"')
+ self.assertEqual(len(text), 3)
+ self.assertEqualArray(y, (x * x) * (x * x))
+
+ def test_py_abs_sub(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_sub(x)
+ self.assertEqualArray(y, numpy.abs(x) - x)
+
+ def test_py_abs_mul(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_mul(x)
+ self.assertEqualArray(y, numpy.abs(x) * x)
+
+ def test_py_abs_mod(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_mod(x)
+ self.assertEqualArray(y, numpy.abs(x) % 2)
+
+ def test_py_abs_pox(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_pow(x)
+ self.assertEqualArray(y, numpy.abs(x) ** 2)
+
+ def test_py_abs_matmul(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_matmul(x)
+ self.assertEqualArray(y, numpy.abs(x) @ x)
+
+ def test_py_abs_div(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_div(x)
+ self.assertEqualArray(y, numpy.abs(x) / x)
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.int64)
+ y = test_abs_div(x)
+ self.assertEqualArray(y, numpy.abs(x) / x)
+
+ def test_py_abs_idiv(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_idiv(x)
+ self.assertEqualArray(y, numpy.abs(x) // x)
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.int64)
+ y = test_abs_idiv(x)
+ self.assertEqualArray(y, numpy.abs(x) // x)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_equal(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_equal(x)
+ self.assertEqualArray(y, numpy.abs(x) == x)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_not_equal(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_not_equal(x)
+ self.assertEqualArray(y, numpy.abs(x) != x)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_greater(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_greater(x)
+ self.assertEqualArray(y, numpy.abs(x) > x)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_greater_or_equal(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_greater_or_equal(x)
+ self.assertEqualArray(y, numpy.abs(x) >= x)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_less(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_less(x)
+ self.assertEqualArray(y, numpy.abs(x) < x)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_less_or_equal(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_less_or_equal(x)
+ self.assertEqualArray(y, numpy.abs(x) <= x)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_and(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_and(x)
+ self.assertEqualArray(
+ y, (numpy.abs(x) < x) & (numpy.abs(x) < 0))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_and2(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_and2(x)
+ self.assertEqualArray(
+ y, (numpy.abs(x) < x) & (numpy.abs(x) < 0))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_or(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_or(x)
+ self.assertEqualArray(
+ y, (numpy.abs(x) < x) | (numpy.abs(x) < 0))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_or2(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_or2(x)
+ self.assertEqualArray(
+ y, (numpy.abs(x) < x) | (numpy.abs(x) < 0))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_sum1(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_sum1(x)
+ self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=0))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_sum2(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_sum2(x)
+ self.assertEqualArray(y, numpy.sum(numpy.abs(x), axis=1, keepdims=1))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_transpose_t(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_transpose_t(x)
+ self.assertEqualArray(y, numpy.abs(x).T)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_cast(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_cast(x)
+ self.assertEqualArray(y, numpy.abs(x).astype(numpy.int64))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_reshape(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_reshape(x)
+ self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1)))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_reshape_11(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_reshape(x)
+ self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1)))
+ compiled = test_abs_reshape.compiled
+ self.assertNotIn("version: 11", str(compiled.onnx_))
+ y = test_abs_reshape_11(x)
+ self.assertEqualArray(y, numpy.abs(x).reshape((-1, 1)))
+ compiled = test_abs_reshape_11.compiled
+ self.assertIn("version: 11", str(compiled.onnx_))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_slice(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_slice(x)
+ self.assertEqualArray(y, numpy.abs(x)[:, 1])
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_slice23(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_slice23(x)
+ self.assertEqualArray(y, numpy.abs(x)[::2, ::3])
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_slice_end(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_slice_end(x)
+ self.assertEqualArray(y, numpy.abs(x)[1:, :3])
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_gather(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_gather(x)
+ self.assertEqualArray(y, numpy.abs(x)[1])
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_gather2(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_gather2(x)
+ self.assertEqualArray(y, numpy.abs(x)[:, 1])
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_neg(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_neg(x)
+ self.assertEqualArray(y, -numpy.abs(x))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_not(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_not(x)
+ self.assertEqualArray(y, numpy.abs(x) <= 0)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_filter(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_filter(x)
+ self.assertEqualArray(y, numpy.abs(x)[x[:, 0] > 15])
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_log(self):
+ x = numpy.array([[6.1, 5], [3.5, 7.8]], dtype=numpy.float32)
+ y = test_log(x)
+ self.assertEqualArray(y, numpy.log(x))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_log_multi(self):
+ x = numpy.array([[6.1, -5], [-3.5, 7.8]], dtype=numpy.float32)
+ y = test_abs_log_multi(x)
+ self.assertEqualArray(y, numpy.log(numpy.abs(x)))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_log_multi_dtype(self):
+ x = numpy.array([[6.1, -5], [-3.5, 7.8]], dtype=numpy.float32)
+ y = test_abs_log_multi_dtype(x)
+ self.assertEqualArray(y, numpy.log(numpy.abs(x) + 1))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_shape(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_shape(x)
+ self.assertEqualArray(y, numpy.abs(x).shape)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_size(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_size(x)
+ self.assertEqualArray(y, numpy.abs(x).size)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_flatten(self):
+ x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
+ y = test_abs_flatten(x)
+ self.assertEqualArray(y, numpy.abs(x).flatten())
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_flatten2(self):
+ x = numpy.array([[[6.11, -51], [3.51, -7.81]],
+ [[6.1, -5], [3.5, -7.8]]], dtype=numpy.float32)
+ y = test_abs_flatten2(x)
+ self.assertEqualArray(y, numpy.abs(x).flatten().reshape((2, -1)))
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1a(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
+ y = test_abs_set1a(x)
+ temp = numpy.abs(x)
+ temp[2] = -1.5
+ self.assertEqualArray(y, temp)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1b(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
+ y = test_abs_set1b(x)
+ temp = numpy.abs(x)
+ temp[:4] = -1.5
+ self.assertEqualArray(y, temp)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1c(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
+ y = test_abs_set1c(x)
+ temp = numpy.abs(x)
+ temp[:4:2] = -1.5
+ self.assertEqualArray(y, temp)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1d(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0], dtype=numpy.float32)
+ y = test_abs_set1d(x)
+ temp = numpy.abs(x)
+ temp[:4:2] = [-1.5, -1.6]
+ self.assertEqualArray(y, temp)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1e(self):
+ self.assertIn('op_type: "Shape"', str(test_abs_set1e.compiled.onnx_))
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6., -7.],
+ dtype=numpy.float32)
+ y = test_abs_set1e(x)
+ temp = numpy.abs(x)
+ temp[2:] = -1.5
+ self.assertEqualArray(y, temp)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1f(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
+ dtype=numpy.float32)
+ y = test_abs_set1f(x)
+ temp = numpy.abs(x)
+ temp[3:5] = -1.5
+ self.assertEqualArray(y, temp)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1g(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
+ dtype=numpy.float32)
+ y = test_abs_set1g(x)
+ temp = numpy.abs(x)
+ temp[3:] = -1.5
+ self.assertEqualArray(y, temp)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1h(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
+ dtype=numpy.float32)
+ y = test_abs_set1h(x)
+ temp = x.copy()
+ temp[x < 0] = -1
+ self.assertEqualArray(temp, y)
+
+ @ignore_warnings(DeprecationWarning)
+ def test_py_abs_set1i(self):
+ x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0, -6.],
+ dtype=numpy.float32)
+ y = test_abs_set1i(x)
+ temp = numpy.abs(x)
+ self.assertEqualArray(temp, y)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/_unittests/ut_npy/test_onnx_variable_ort.py b/_unittests/ut_npy/test_onnx_variable_ort.py
index abead9145..daf4b76de 100644
--- a/_unittests/ut_npy/test_onnx_variable_ort.py
+++ b/_unittests/ut_npy/test_onnx_variable_ort.py
@@ -236,6 +236,27 @@ def test_abs_slice23(x: NDArray[Any, numpy.float32],
return nxnp.abs(x)[::2, ::3]
+@onnxnumpy(runtime='onnxruntime1')
+def test_abs_slice_end(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy slice end"
+ return nxnp.abs(x)[1:, :3]
+
+
+@onnxnumpy(runtime='onnxruntime1')
+def test_abs_gather(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy gather"
+ return nxnp.abs(x)[1]
+
+
+@onnxnumpy(runtime='onnxruntime1')
+def test_abs_gather2(x: NDArray[Any, numpy.float32],
+ ) -> NDArray[Any, numpy.float32]:
+ "onnx numpy gather"
+ return nxnp.abs(x)[:, 1]
+
+
@onnxnumpy(runtime='onnxruntime1')
def test_abs_neg(x: NDArray[Any, numpy.float32],
) -> NDArray[Any, numpy.float32]:
@@ -530,6 +551,21 @@ def test_ort_abs_slice23(self):
y = test_abs_slice23(x)
self.assertEqualArray(y, numpy.abs(x)[::2, ::3])
+ def test_ort_abs_slice_end(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_slice_end(x)
+ self.assertEqualArray(y, numpy.abs(x)[1:, :3])
+
+ def test_ort_abs_gather(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_gather(x)
+ self.assertEqualArray(y, numpy.abs(x)[1])
+
+ def test_ort_abs_gather2(self):
+ x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
+ y = test_abs_gather2(x)
+ self.assertEqualArray(y, numpy.abs(x)[:, 1])
+
def test_ort_abs_neg(self):
x = numpy.arange(0, 36).reshape((6, 6)).astype(numpy.float32)
y = test_abs_neg(x)
diff --git a/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py b/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py
index e48a61f81..cddd5dc59 100644
--- a/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py
+++ b/_unittests/ut_onnxrt/test_onnxrt_validate_bug.py
@@ -60,7 +60,8 @@ def test_dict_vectorizer_rfr(self):
x = {k: numpy.float32(v) for k, v in x.items()}
oinf = OnnxInference(model_onnx, runtime='python')
- res3 = oinf.run({input_name: numpy.array([x])}) # , verbose=10, fLOG=print)
+ # , verbose=10, fLOG=print)
+ res3 = oinf.run({input_name: numpy.array([x])})
self.assertEqualFloat(res[0][0, 0], res2["variable1"][0, 0])
self.assertEqualFloat(res[0][0, 0], res3["variable1"][0])
diff --git a/mlprodict/npy/numpy_onnx_impl.py b/mlprodict/npy/numpy_onnx_impl.py
index c262b9981..3b72c4233 100644
--- a/mlprodict/npy/numpy_onnx_impl.py
+++ b/mlprodict/npy/numpy_onnx_impl.py
@@ -43,7 +43,7 @@
OnnxSin, OnnxSinh,
OnnxSqrt,
OnnxSqueeze,
- OnnxTan, OnnxTanh, OnnxTopK,
+ OnnxTan, OnnxTanh, OnnxTopK, OnnxTranspose,
OnnxUnsqueeze,
OnnxWhere)
from .onnx_variable import OnnxVar, MultiOnnxVar as xtuple
@@ -360,6 +360,11 @@ def topk(x, k, axis=-1, largest=1, sorted=1):
sorted=sorted)
+def transpose(x, perm=(1, 0)):
+ "See :epkg:`numpy:transpose`."
+ return OnnxVar(x, op=OnnxTranspose, perm=list(perm))
+
+
def unsqueeze(x, axes):
"See :epkg:`numpy:expand_dims`."
return OnnxVar(x, axes, op=OnnxUnsqueeze)
diff --git a/mlprodict/npy/numpy_onnx_pyrt.py b/mlprodict/npy/numpy_onnx_pyrt.py
index 8adade7f3..174bbb44b 100644
--- a/mlprodict/npy/numpy_onnx_pyrt.py
+++ b/mlprodict/npy/numpy_onnx_pyrt.py
@@ -58,6 +58,7 @@
tan as nx_tan,
tanh as nx_tanh,
topk as nx_topk,
+ transpose as nx_transpose,
unsqueeze as nx_unsqueeze,
vstack as nx_vstack,
where as nx_where,
@@ -342,6 +343,12 @@ def topk(x, k, axis=-1, largest=1, sorted=1):
return nx_topk(x, k, axis=axis, largest=largest, sorted=sorted)
+@onnxnumpy_np(signature=NDArraySameType("all"))
+def transpose(x, perm=(1, 0)):
+ "transpose"
+ return nx_transpose(x, perm=perm)
+
+
@onnxnumpy_np(signature=NDArrayType(("all", numpy.int64)))
def unsqueeze(x, axes):
"unsqueeze"
diff --git a/mlprodict/npy/onnx_numpy_compiler.py b/mlprodict/npy/onnx_numpy_compiler.py
index ab66f56cb..414f43884 100644
--- a/mlprodict/npy/onnx_numpy_compiler.py
+++ b/mlprodict/npy/onnx_numpy_compiler.py
@@ -1,397 +1,400 @@
-"""
-@file
-@brief Implements :epkg:`numpy` functions with onnx and a runtime.
-
-.. versionadded:: 0.6
-"""
-import inspect
-from typing import Any
-import numpy
-from skl2onnx.common.data_types import guess_numpy_type
-from skl2onnx import __max_supported_opset__
-from ..tools.ort_wrapper import InferenceSession
-from ..onnxrt import OnnxInference
-from .onnx_version import FctVersion
-from .onnx_numpy_annotation import get_args_kwargs
-from .onnx_variable import OnnxVar
-
-
-class OnnxNumpyFunction:
- """
- Class wrapping a function build with
- @see cl OnnxNumpyCompiler.
-
- .. versionadded:: 0.6
- """
-
- def __init__(self, compiler, rt, inputs, outputs,
- n_optional, n_variables):
- self.compiler = compiler
- self.inputs = inputs
- self.outputs = outputs
- self.rt = rt
- self.n_optional = n_optional
- self.n_variables = n_variables
- if n_optional < 0:
- raise RuntimeError( # pragma: no cover
- "Wrong configuration, n_optional %r must be >= 0."
- "" % n_optional)
- if n_optional >= len(inputs):
- raise RuntimeError( # pragma: no cover
- "Wrong configuration, n_optional %r must be >= %r "
- "the number of inputs." % (n_optional, len(inputs)))
-
- def _check_(self, *args, **kwargs):
- if self.n_variables > 0:
- return
- if (len(args) < len(self.inputs) - self.n_optional or
- len(args) > len(self.inputs)):
- raise RuntimeError( # pragma: no cover
- "Unexpected number of inputs %d. It should be in "
- "[%r, %r] len(args)=%d n_optional=%d n_variables=%d"
- "\nargs=%s\nkwargs=%s\ninputs=%s" % (
- len(args), len(self.inputs) - self.n_optional,
- len(args), self.n_optional, self.n_variables,
- len(self.inputs), args, kwargs, self.inputs))
-
-
-class OnnxNumpyFunctionOnnxInference(OnnxNumpyFunction):
- """
- Overwrites @see cl OnnxNumpyFunction to run an instance of
- @see cl OnnxInference.
-
- .. versionadded:: 0.6
- """
-
- def __call__(self, *args, **kwargs):
- self._check_(*args, **kwargs)
- inp = {k[0]: a for k, a in zip(self.inputs, args)}
- out = self.rt.run(inp, **kwargs)
- if len(out) != len(self.outputs):
- raise RuntimeError( # pragma: no cover
- "Unexpected number of outputs %d instead of %d." % (
- len(out), len(self.outputs)))
- return tuple([out[o[0]] for o in self.outputs])
-
-
-class OnnxNumpyFunctionInferenceSession(OnnxNumpyFunction):
- """
- Overwrites @see cl OnnxNumpyFunction to run an instance of
- `InferenceSession` from :epkg:`onnxruntime`.
-
- .. versionadded:: 0.6
- """
-
- def __call__(self, *args, **kwargs):
- self._check_(*args, **kwargs)
- if len(kwargs) > 0:
- raise RuntimeError( # pragma: no cover
- "kwargs is not used but it is not empty: %r." % kwargs)
- inp = {k[0]: a for k, a in zip(self.inputs, args)}
- out = self.rt.run(None, inp)
-
- if len(out) != len(self.outputs):
- raise RuntimeError( # pragma: no cover
- "Unexpected number of outputs %d instead of %d." % (
- len(out), len(self.outputs)))
- return tuple(out)
-
-
-class OnnxNumpyCompiler:
- """
- Implements a class which runs onnx graph.
-
- :param fct: a function with annotations which returns an ONNX graph,
- it can also be an ONNX graph.
- :param op_version: :epkg:`ONNX` opset to use, None
- for the latest one
- :param runtime: runtime to choose to execute the onnx graph,
- `python`, `onnxruntime`, `onnxruntime1`
- :param signature: used when the function is not annotated
- :param version: the same function can be instantiated with
- different type, this parameter is None or a numpy type
- if the signature allows multiple types, it must an instance
- of type @see cl FctVersion
- :param fctsig: function used to overwrite the fct signature
- in case this one is using `*args, **kwargs`
-
- .. versionadded:: 0.6
- """
-
- def __init__(self, fct, op_version=None, runtime=None, signature=None,
- version=None, fctsig=None):
- if version is not None and not isinstance(version, FctVersion):
- raise TypeError( # pragma: no cover
- "version must be of Type 'FctVersion' not %s - %s"
- "." % (type(version), version))
- self.fctsig = fctsig
- if op_version is None:
- op_version = __max_supported_opset__
- if hasattr(fct, 'SerializeToString'):
- self.fct_ = None
- self.onnx_ = fct
- else:
- self.fct_ = fct
- if not inspect.isfunction(fct):
- raise TypeError( # pragma: no cover
- "Unexpected type for fct=%r, it must be "
- "function." % type(fct))
- self.onnx_ = None
- self.onnx_ = self._to_onnx(
- op_version=op_version, signature=signature,
- version=version)
- self.runtime_ = self._build_runtime(
- op_version=op_version, runtime=runtime,
- signature=signature, version=version)
- ann = self._parse_annotation(signature=signature, version=version)
- inputs, outputs, kwargs, n_optional, n_variables = ann
- n_opt = 0 if signature is None else signature.n_optional
- args, kwargs2 = get_args_kwargs(self.fctsig or self.fct_, n_opt)
- self.meta_ = dict(op_version=op_version, runtime=runtime,
- signature=signature, version=version,
- inputs=inputs, outputs=outputs,
- kwargs=kwargs, n_optional=n_optional,
- n_variables=n_variables,
- args=args, kwargs2=kwargs2,
- annotations=self.fct_.__annotations__)
-
- def __getstate__(self):
- """
- Serializes everything but function `fct_`.
- Function `fct_` is used to build the onnx graph
- and is not needed anymore.
- """
- return dict(onnx_=self.onnx_, meta_=self.meta_)
-
- def __setstate__(self, state):
- """
- Restores serialized data.
- """
- for k, v in state.items():
- setattr(self, k, v)
- self.runtime_ = self._build_runtime(
- op_version=self.meta_['op_version'],
- runtime=self.meta_['runtime'],
- signature=self.meta_['signature'],
- version=self.meta_['version'])
-
- def __repr__(self):
- "usual"
- if self.fct_ is not None:
- return "%s(%s)" % (self.__class__.__name__, repr(self.fct_))
- if self.onnx_ is not None:
- return "%s(%s)" % (self.__class__.__name__, "... ONNX ... ")
- raise NotImplementedError( # pragma: no cover
- "fct_ and onnx_ are empty.")
-
- def _to_onnx_shape(self, shape):
- if shape is Any or shape is Ellipsis:
- shape = None
- elif isinstance(shape, tuple):
- shape = [None if s is Any or s is Ellipsis else s
- for s in shape]
- else:
- raise RuntimeError( # pragma: no cover
- "Unexpected annotated shape %r." % shape)
- return shape
-
- def _to_onnx_dtype(self, dtype, shape):
- from skl2onnx.common.data_types import _guess_numpy_type
- return _guess_numpy_type(dtype, shape)
-
- def _parse_annotation(self, signature, version):
- """
- Returns the annotations for function `fct_`.
-
- :param signature: needed if the annotation is missing,
- then version might be needed to specify which type
- to use if the signature allows many
- :param version: version inside the many signatures possible
- :return: *tuple(inputs, outputs, kwargs)*, each of them
- is a list of tuple with the name and the dtype,
- *kwargs* is the list of additional parameters
- """
- n_opt = 0 if signature is None else signature.n_optional
- if hasattr(self, 'meta_'):
- args, kwargs = self.meta_['args'], self.meta_['kwargs2']
- else:
- args, kwargs = get_args_kwargs(self.fctsig or self.fct_, n_opt)
- if version is not None:
- nv = len(version) - len(args) - n_opt
- if (signature is not None and not
- signature.n_variables and nv > len(kwargs)):
- raise RuntimeError( # pragma: no cover
- "Mismatch (%d - %d - %d ? %d) between version=%r and kwargs=%r for "
- "function %r, optional argument is %d, "
- "signature=%r." % (
- len(version), len(args), n_opt, len(kwargs),
- version, kwargs, self.fct_,
- signature.n_variables, signature))
- vvers = {} if version.kwargs is None else version.kwargs
- up = {}
- for k, v in zip(kwargs, vvers):
- up[k] = v
- kwargs = kwargs.copy()
- kwargs.update(up)
-
- for k, v in kwargs.items():
- if isinstance(v, (type, numpy.dtype)):
- raise RuntimeError( # pragma: no cover
- "Unexpected value for argument %r: %r from %r." % (
- k, v, kwargs))
-
- if signature is not None:
- inputs, kwargs, outputs, n_optional, n_variables = (
- signature.get_inputs_outputs(args, kwargs, version))
- return inputs, outputs, kwargs, n_optional, n_variables
-
- def _possible_names():
- yield 'y'
- yield 'z' # pragma: no cover
- yield 'o' # pragma: no cover
- for i in range(0, 10000): # pragma: no cover
- yield 'o%d' % i
-
- if hasattr(self, 'meta_'):
- annotations = self.meta_['annotations']
- else:
- annotations = self.fct_.__annotations__
- inputs = []
- outputs = []
- for a in args:
- if a == "op_version":
- continue
- if a not in annotations:
- raise RuntimeError( # pragma: no cover
- "Unable to find annotation for argument %r. "
- "You should annotate the arguments and the results "
- "or specify a signature." % a)
- ann = annotations[a]
- shape, dtype = ann.__args__
- shape = self._to_onnx_shape(shape)
- dtype = self._to_onnx_dtype(dtype, shape)
- inputs.append((a, dtype))
-
- ret = annotations['return']
- names_in = set(inp[0] for inp in inputs)
-
- if isinstance(ret, tuple):
- # multiple outputs
- names_none = set()
- for shape_dtype in ret:
- shape, dtype = shape_dtype.__args__
- shape = self._to_onnx_shape(shape)
- dtype = self._to_onnx_dtype(dtype, shape)
- name_out = None
- for name in _possible_names():
- if name not in names_in and name not in names_none:
- name_out = name
- break
- outputs.append((name_out, dtype))
- names_none.add(name_out)
- return (inputs, outputs, kwargs, 0,
- signature.n_variables if signature is not None else False)
-
- # single outputs
- shape, dtype = ret.__args__
- shape = self._to_onnx_shape(shape)
- dtype = self._to_onnx_dtype(dtype, shape)
- name_out = None
- for name in _possible_names():
- if name not in names_in:
- name_out = name
- break
- outputs.append((name_out, dtype))
- return (inputs, outputs, kwargs, 0,
- signature.n_variables if signature is not None else False)
-
- def _to_onnx(self, op_version=None, signature=None, version=None):
- """
- Returns the onnx graph produced by function `fct_`.
- """
- if self.onnx_ is None and self.fct_ is not None:
- inputs, outputs, kwargs, n_optional, n_variables = ( # pylint: disable=W0612
- self._parse_annotation(
- signature=signature, version=version))
- if ((signature is None or not signature.n_variables) and
- isinstance(version, tuple) and
- len(inputs) > len(version)):
- raise NotImplementedError( # pragma: no cover
- "Mismatch between additional parameters %r "
- "(n_optional=%r) and version %r for function %r from %r."
- "" % (kwargs, n_optional, version, self.fct_,
- getattr(self.fct_, '__module__', None)))
- names_in = [oi[0] for oi in inputs]
- names_out = [oi[0] for oi in outputs]
- names_var = [OnnxVar(n, dtype=guess_numpy_type(dt[1]))
- for n, dt in zip(names_in, inputs)]
-
- if 'op_version' in self.fct_.__code__.co_varnames:
- onx_algebra = self.fct_(
- *names_in, op_version=op_version, **kwargs)
- else:
- onx_var = self.fct_(*names_var, **kwargs)
- if not hasattr(onx_var, 'to_algebra'):
- raise TypeError( # pragma: no cover
- "The function %r to convert must return an instance of "
- "OnnxVar but returns type %r." % (self.fct_, type(onx_var)))
- onx_algebra = onx_var.to_algebra(op_version=op_version)
-
- if isinstance(onx_algebra, str):
- raise RuntimeError( # pragma: no cover
- "Unexpected str type %r." % onx_algebra)
- if isinstance(onx_algebra, tuple):
- raise NotImplementedError( # pragma: no cover
- "Not implemented when the function returns multiple results.")
- if hasattr(onx_algebra, 'to_onnx'):
- # skl2onnx algebra
- onx_algebra.output_names = names_out
- onx = onx_algebra.to_onnx(inputs=inputs,
- target_opset=op_version,
- outputs=outputs)
- self.onnx_ = onx
-
- if self.onnx_ is None:
- raise RuntimeError( # pragma: no cover
- "Unable to get the ONNX graph (class %r, fct_=%r)" % (
- type(self), self.fct_))
- return self.onnx_
-
- def _build_runtime(self, op_version=None, runtime=None,
- signature=None, version=None):
- """
- Creates the runtime for the :epkg:`ONNX` graph.
-
- :param op_version: :epkg:`ONNX` opset to use, None
- for the latest one
- :param runtime: runtime to choose to execute the onnx graph,
- `python`, `onnxruntime`, `onnxruntime1`
- :param signature: used when the function is not annotated
- """
- onx = self._to_onnx(op_version=op_version, signature=signature,
- version=version)
- inputs, outputs, _, n_optional, n_variables = self._parse_annotation(
- signature=signature, version=version)
- if runtime != 'onnxruntime':
- rt = OnnxInference(onx, runtime=runtime)
- self.rt_fct_ = OnnxNumpyFunctionOnnxInference(
- self, rt, inputs=inputs, outputs=outputs,
- n_optional=n_optional, n_variables=n_variables)
- else:
- rt = InferenceSession(onx.SerializeToString())
- self.rt_fct_ = OnnxNumpyFunctionInferenceSession(
- self, rt, inputs=inputs, outputs=outputs,
- n_optional=n_optional, n_variables=n_variables)
- return self.rt_fct_
-
- def __call__(self, *args, **kwargs):
- """
- Executes the function and returns the results.
-
- :param args: arguments
- :return: results
- """
- res = self.rt_fct_(*args, **kwargs)
- if len(res) == 1:
- return res[0]
- return res
+"""
+@file
+@brief Implements :epkg:`numpy` functions with onnx and a runtime.
+
+.. versionadded:: 0.6
+"""
+import inspect
+from typing import Any
+import numpy
+from skl2onnx.common.data_types import guess_numpy_type
+from skl2onnx import __max_supported_opset__
+from ..tools.ort_wrapper import InferenceSession
+from ..onnx_tools.optim._main_onnx_optim import onnx_optimisations
+from ..onnxrt import OnnxInference
+from .onnx_version import FctVersion
+from .onnx_numpy_annotation import get_args_kwargs
+from .onnx_variable import OnnxVar
+
+
+class OnnxNumpyFunction:
+ """
+ Class wrapping a function build with
+ @see cl OnnxNumpyCompiler.
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(self, compiler, rt, inputs, outputs,
+ n_optional, n_variables):
+ self.compiler = compiler
+ self.inputs = inputs
+ self.outputs = outputs
+ self.rt = rt
+ self.n_optional = n_optional
+ self.n_variables = n_variables
+ if n_optional < 0:
+ raise RuntimeError( # pragma: no cover
+ "Wrong configuration, n_optional %r must be >= 0."
+ "" % n_optional)
+ if n_optional >= len(inputs):
+ raise RuntimeError( # pragma: no cover
+ "Wrong configuration, n_optional %r must be >= %r "
+ "the number of inputs." % (n_optional, len(inputs)))
+
+ def _check_(self, *args, **kwargs):
+ if self.n_variables > 0:
+ return
+ if (len(args) < len(self.inputs) - self.n_optional or
+ len(args) > len(self.inputs)):
+ raise RuntimeError( # pragma: no cover
+ "Unexpected number of inputs %d. It should be in "
+ "[%r, %r] len(args)=%d n_optional=%d n_variables=%d"
+ "\nargs=%s\nkwargs=%s\ninputs=%s" % (
+ len(args), len(self.inputs) - self.n_optional,
+ len(args), self.n_optional, self.n_variables,
+ len(self.inputs), args, kwargs, self.inputs))
+
+
+class OnnxNumpyFunctionOnnxInference(OnnxNumpyFunction):
+ """
+ Overwrites @see cl OnnxNumpyFunction to run an instance of
+ @see cl OnnxInference.
+
+ .. versionadded:: 0.6
+ """
+
+ def __call__(self, *args, **kwargs):
+ self._check_(*args, **kwargs)
+ inp = {k[0]: a for k, a in zip(self.inputs, args)}
+ out = self.rt.run(inp, **kwargs)
+ if len(out) != len(self.outputs):
+ raise RuntimeError( # pragma: no cover
+ "Unexpected number of outputs %d instead of %d." % (
+ len(out), len(self.outputs)))
+ return tuple([out[o[0]] for o in self.outputs])
+
+
+class OnnxNumpyFunctionInferenceSession(OnnxNumpyFunction):
+ """
+ Overwrites @see cl OnnxNumpyFunction to run an instance of
+ `InferenceSession` from :epkg:`onnxruntime`.
+
+ .. versionadded:: 0.6
+ """
+
+ def __call__(self, *args, **kwargs):
+ self._check_(*args, **kwargs)
+ if len(kwargs) > 0:
+ raise RuntimeError( # pragma: no cover
+ "kwargs is not used but it is not empty: %r." % kwargs)
+ inp = {k[0]: a for k, a in zip(self.inputs, args)}
+ out = self.rt.run(None, inp)
+
+ if len(out) != len(self.outputs):
+ raise RuntimeError( # pragma: no cover
+ "Unexpected number of outputs %d instead of %d." % (
+ len(out), len(self.outputs)))
+ return tuple(out)
+
+
+class OnnxNumpyCompiler:
+ """
+ Implements a class which runs onnx graph.
+
+ :param fct: a function with annotations which returns an ONNX graph,
+ it can also be an ONNX graph.
+ :param op_version: :epkg:`ONNX` opset to use, None
+ for the latest one
+ :param runtime: runtime to choose to execute the onnx graph,
+ `python`, `onnxruntime`, `onnxruntime1`
+ :param signature: used when the function is not annotated
+ :param version: the same function can be instantiated with
+ different type, this parameter is None or a numpy type
+ if the signature allows multiple types, it must an instance
+ of type @see cl FctVersion
+ :param fctsig: function used to overwrite the fct signature
+ in case this one is using `*args, **kwargs`
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(self, fct, op_version=None, runtime=None, signature=None,
+ version=None, fctsig=None):
+ if version is not None and not isinstance(version, FctVersion):
+ raise TypeError( # pragma: no cover
+ "version must be of Type 'FctVersion' not %s - %s"
+ "." % (type(version), version))
+ self.fctsig = fctsig
+ if op_version is None:
+ op_version = __max_supported_opset__
+ if hasattr(fct, 'SerializeToString'):
+ self.fct_ = None
+ self.onnx_ = fct
+ else:
+ self.fct_ = fct
+ if not inspect.isfunction(fct):
+ raise TypeError( # pragma: no cover
+ "Unexpected type for fct=%r, it must be "
+ "function." % type(fct))
+ self.onnx_ = None
+ self.onnx_ = self._to_onnx(
+ op_version=op_version, signature=signature,
+ version=version)
+ self.runtime_ = self._build_runtime(
+ op_version=op_version, runtime=runtime,
+ signature=signature, version=version)
+ ann = self._parse_annotation(signature=signature, version=version)
+ inputs, outputs, kwargs, n_optional, n_variables = ann
+ n_opt = 0 if signature is None else signature.n_optional
+ args, kwargs2 = get_args_kwargs(self.fctsig or self.fct_, n_opt)
+ self.meta_ = dict(op_version=op_version, runtime=runtime,
+ signature=signature, version=version,
+ inputs=inputs, outputs=outputs,
+ kwargs=kwargs, n_optional=n_optional,
+ n_variables=n_variables,
+ args=args, kwargs2=kwargs2,
+ annotations=self.fct_.__annotations__)
+
+ def __getstate__(self):
+ """
+ Serializes everything but function `fct_`.
+ Function `fct_` is used to build the onnx graph
+ and is not needed anymore.
+ """
+ return dict(onnx_=self.onnx_, meta_=self.meta_)
+
+ def __setstate__(self, state):
+ """
+ Restores serialized data.
+ """
+ for k, v in state.items():
+ setattr(self, k, v)
+ self.runtime_ = self._build_runtime(
+ op_version=self.meta_['op_version'],
+ runtime=self.meta_['runtime'],
+ signature=self.meta_['signature'],
+ version=self.meta_['version'])
+
+ def __repr__(self):
+ "usual"
+ if self.fct_ is not None:
+ return "%s(%s)" % (self.__class__.__name__, repr(self.fct_))
+ if self.onnx_ is not None:
+ return "%s(%s)" % (self.__class__.__name__, "... ONNX ... ")
+ raise NotImplementedError( # pragma: no cover
+ "fct_ and onnx_ are empty.")
+
+ def _to_onnx_shape(self, shape):
+ if shape is Any or shape is Ellipsis:
+ shape = None
+ elif isinstance(shape, tuple):
+ shape = [None if s is Any or s is Ellipsis else s
+ for s in shape]
+ else:
+ raise RuntimeError( # pragma: no cover
+ "Unexpected annotated shape %r." % shape)
+ return shape
+
+ def _to_onnx_dtype(self, dtype, shape):
+ from skl2onnx.common.data_types import _guess_numpy_type
+ return _guess_numpy_type(dtype, shape)
+
+ def _parse_annotation(self, signature, version):
+ """
+ Returns the annotations for function `fct_`.
+
+ :param signature: needed if the annotation is missing,
+ then version might be needed to specify which type
+ to use if the signature allows many
+ :param version: version inside the many signatures possible
+ :return: *tuple(inputs, outputs, kwargs)*, each of them
+ is a list of tuple with the name and the dtype,
+ *kwargs* is the list of additional parameters
+ """
+ n_opt = 0 if signature is None else signature.n_optional
+ if hasattr(self, 'meta_'):
+ args, kwargs = self.meta_['args'], self.meta_['kwargs2']
+ else:
+ args, kwargs = get_args_kwargs(self.fctsig or self.fct_, n_opt)
+ if version is not None:
+ nv = len(version) - len(args) - n_opt
+ if (signature is not None and not
+ signature.n_variables and nv > len(kwargs)):
+ raise RuntimeError( # pragma: no cover
+ "Mismatch (%d - %d - %d ? %d) between version=%r and kwargs=%r for "
+ "function %r, optional argument is %d, "
+ "signature=%r." % (
+ len(version), len(args), n_opt, len(kwargs),
+ version, kwargs, self.fct_,
+ signature.n_variables, signature))
+ vvers = {} if version.kwargs is None else version.kwargs
+ up = {}
+ for k, v in zip(kwargs, vvers):
+ up[k] = v
+ kwargs = kwargs.copy()
+ kwargs.update(up)
+
+ for k, v in kwargs.items():
+ if isinstance(v, (type, numpy.dtype)):
+ raise RuntimeError( # pragma: no cover
+ "Unexpected value for argument %r: %r from %r." % (
+ k, v, kwargs))
+
+ if signature is not None:
+ inputs, kwargs, outputs, n_optional, n_variables = (
+ signature.get_inputs_outputs(args, kwargs, version))
+ return inputs, outputs, kwargs, n_optional, n_variables
+
+ def _possible_names():
+ yield 'y'
+ yield 'z' # pragma: no cover
+ yield 'o' # pragma: no cover
+ for i in range(0, 10000): # pragma: no cover
+ yield 'o%d' % i
+
+ if hasattr(self, 'meta_'):
+ annotations = self.meta_['annotations']
+ else:
+ annotations = self.fct_.__annotations__
+ inputs = []
+ outputs = []
+ for a in args:
+ if a == "op_version":
+ continue
+ if a not in annotations:
+ raise RuntimeError( # pragma: no cover
+ "Unable to find annotation for argument %r. "
+ "You should annotate the arguments and the results "
+ "or specify a signature." % a)
+ ann = annotations[a]
+ shape, dtype = ann.__args__
+ shape = self._to_onnx_shape(shape)
+ dtype = self._to_onnx_dtype(dtype, shape)
+ inputs.append((a, dtype))
+
+ ret = annotations['return']
+ names_in = set(inp[0] for inp in inputs)
+
+ if isinstance(ret, tuple):
+ # multiple outputs
+ names_none = set()
+ for shape_dtype in ret:
+ shape, dtype = shape_dtype.__args__
+ shape = self._to_onnx_shape(shape)
+ dtype = self._to_onnx_dtype(dtype, shape)
+ name_out = None
+ for name in _possible_names():
+ if name not in names_in and name not in names_none:
+ name_out = name
+ break
+ outputs.append((name_out, dtype))
+ names_none.add(name_out)
+ return (inputs, outputs, kwargs, 0,
+ signature.n_variables if signature is not None else False)
+
+ # single outputs
+ shape, dtype = ret.__args__
+ shape = self._to_onnx_shape(shape)
+ dtype = self._to_onnx_dtype(dtype, shape)
+ name_out = None
+ for name in _possible_names():
+ if name not in names_in:
+ name_out = name
+ break
+ outputs.append((name_out, dtype))
+ return (inputs, outputs, kwargs, 0,
+ signature.n_variables if signature is not None else False)
+
+ def _to_onnx(self, op_version=None, signature=None, version=None):
+ """
+ Returns the onnx graph produced by function `fct_`.
+ """
+ if self.onnx_ is None and self.fct_ is not None:
+ inputs, outputs, kwargs, n_optional, n_variables = ( # pylint: disable=W0612
+ self._parse_annotation(
+ signature=signature, version=version))
+ if ((signature is None or not signature.n_variables) and
+ isinstance(version, tuple) and
+ len(inputs) > len(version)):
+ raise NotImplementedError( # pragma: no cover
+ "Mismatch between additional parameters %r "
+ "(n_optional=%r) and version %r for function %r from %r."
+ "" % (kwargs, n_optional, version, self.fct_,
+ getattr(self.fct_, '__module__', None)))
+ names_in = [oi[0] for oi in inputs]
+ names_out = [oi[0] for oi in outputs]
+ names_var = [OnnxVar(n, dtype=guess_numpy_type(dt[1]))
+ for n, dt in zip(names_in, inputs)]
+
+ if 'op_version' in self.fct_.__code__.co_varnames:
+ onx_algebra = self.fct_(
+ *names_in, op_version=op_version, **kwargs)
+ else:
+ onx_var = self.fct_(*names_var, **kwargs)
+ if not hasattr(onx_var, 'to_algebra'):
+ raise TypeError( # pragma: no cover
+ "The function %r to convert must return an instance of "
+ "OnnxVar but returns type %r." % (self.fct_, type(onx_var)))
+ onx_algebra = onx_var.to_algebra(op_version=op_version)
+
+ if isinstance(onx_algebra, str):
+ raise RuntimeError( # pragma: no cover
+ "Unexpected str type %r." % onx_algebra)
+ if isinstance(onx_algebra, tuple):
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented when the function returns multiple results.")
+ if hasattr(onx_algebra, 'to_onnx'):
+ # skl2onnx algebra
+ onx_algebra.output_names = names_out
+ onx = onx_algebra.to_onnx(inputs=inputs,
+ target_opset=op_version,
+ outputs=outputs)
+ # optimisation
+ onx_optimized = onnx_optimisations(onx)
+ self.onnx_ = onx_optimized
+
+ if self.onnx_ is None:
+ raise RuntimeError( # pragma: no cover
+ "Unable to get the ONNX graph (class %r, fct_=%r)" % (
+ type(self), self.fct_))
+ return self.onnx_
+
+ def _build_runtime(self, op_version=None, runtime=None,
+ signature=None, version=None):
+ """
+ Creates the runtime for the :epkg:`ONNX` graph.
+
+ :param op_version: :epkg:`ONNX` opset to use, None
+ for the latest one
+ :param runtime: runtime to choose to execute the onnx graph,
+ `python`, `onnxruntime`, `onnxruntime1`
+ :param signature: used when the function is not annotated
+ """
+ onx = self._to_onnx(op_version=op_version, signature=signature,
+ version=version)
+ inputs, outputs, _, n_optional, n_variables = self._parse_annotation(
+ signature=signature, version=version)
+ if runtime != 'onnxruntime':
+ rt = OnnxInference(onx, runtime=runtime)
+ self.rt_fct_ = OnnxNumpyFunctionOnnxInference(
+ self, rt, inputs=inputs, outputs=outputs,
+ n_optional=n_optional, n_variables=n_variables)
+ else:
+ rt = InferenceSession(onx.SerializeToString())
+ self.rt_fct_ = OnnxNumpyFunctionInferenceSession(
+ self, rt, inputs=inputs, outputs=outputs,
+ n_optional=n_optional, n_variables=n_variables)
+ return self.rt_fct_
+
+ def __call__(self, *args, **kwargs):
+ """
+ Executes the function and returns the results.
+
+ :param args: arguments
+ :return: results
+ """
+ res = self.rt_fct_(*args, **kwargs)
+ if len(res) == 1:
+ return res[0]
+ return res
diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py
index bbd44e5ac..4adf2d839 100644
--- a/mlprodict/npy/onnx_variable.py
+++ b/mlprodict/npy/onnx_variable.py
@@ -1,709 +1,759 @@
-"""
-@file
-@brief Intermediate class between :epkg:`numpy` and :epkg:`onnx`.
-
-.. versionadded:: 0.6
-"""
-import numpy
-from onnx.helper import make_tensor
-from skl2onnx.common.data_types import guess_numpy_type
-from skl2onnx.common._topology import Variable # pylint: disable=E0611,E0001
-from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611
- OnnxAdd, OnnxAnd,
- OnnxCast, OnnxConstantOfShape,
- OnnxDiv,
- OnnxEqual,
- OnnxFlatten,
- OnnxGather, OnnxGreater, OnnxGreaterOrEqual,
- OnnxIdentity,
- OnnxLess, OnnxLessOrEqual,
- OnnxMatMul, OnnxMod, OnnxMul,
- OnnxNeg, OnnxNot,
- OnnxOr,
- OnnxPow,
- OnnxReduceSum, OnnxReshape,
- OnnxScatterElements, OnnxShape, OnnxSize, OnnxSlice,
- OnnxSqueeze, OnnxSub,
- OnnxTopK, OnnxTranspose,
- OnnxWhere)
-from skl2onnx.algebra.onnx_operator import OnnxOperatorItem
-from skl2onnx.common.data_types import _guess_numpy_type
-from ..onnx_tools.onnx2py_helper import guess_proto_dtype
-
-
-try:
- numpy_bool = numpy.bool_
-except AttributeError: # pragma: no cover
- numpy_bool = bool
-try:
- numpy_str = numpy.str_
-except AttributeError: # pragma: no cover
- numpy_str = str
-
-
-class OnnxVar:
- """
- Variables used into :epkg:`onnx` computation.
-
- :param inputs: variable name or object
- :param op: :epkg:`ONNX` operator
- :param select_output: if multiple output are returned by
- ONNX operator *op*, it takes only one specifed by this
- argument
- :param dtype: specifies the type of the variable
- held by this class (*op* is None) in that case
- :param kwargs: addition argument to give operator *op*
-
- .. versionadded:: 0.6
- """
-
- def __init__(self, *inputs, op=None, select_output=None,
- dtype=None, **kwargs):
- self.inputs = inputs
- self.select_output = select_output
- self.onnx_op = op
- self.alg_ = None
- self.onnx_op_kwargs = kwargs
- if dtype is not None and (op is not None or len(inputs) != 1):
- raise RuntimeError( # pragma: no cover
- "dtype can only be used if op is None or len(inputs) == 1.")
- for i, inp in enumerate(self.inputs):
- if isinstance(inp, type):
- raise TypeError( # pragma: no cover
- "Unexpected type for input %d - %r." % (i, inp))
- self.dtype = self._guess_dtype(dtype)
-
- def _guess_dtype(self, dtype):
- "Guesses dtype when not specified."
- if dtype is not None:
- return dtype
- dtypes = []
- for i, inp in enumerate(self.inputs):
- if isinstance(inp, str):
- return None
- if isinstance(inp, numpy.ndarray):
- dtypes.append(inp.dtype)
- elif isinstance(inp, Variable):
- dt = guess_numpy_type(inp.type)
- dtypes.append(dt)
- elif isinstance(inp, OnnxVar):
- dtypes.append(inp.dtype)
- elif isinstance(inp, MultiOnnxVar):
- dtypes.append(inp._guess_dtype(dtype))
- elif isinstance(inp, (numpy.float32, numpy.float64, numpy.int32,
- numpy.int64)):
- dtypes.append(inp.dtype)
- elif isinstance(inp, numpy_str):
- dtypes.append(numpy_str)
- elif isinstance(inp, numpy_bool):
- dtypes.append(numpy_bool)
- elif isinstance(inp, int):
- dtypes.append(numpy.int64) # pragma: no cover
- elif isinstance(inp, float):
- dtypes.append(numpy.float64)
- elif hasattr(inp, 'fit'):
- # scikit-learn model
- continue
- else:
- raise TypeError( # pragma: no cover
- "Unexpected type for input %i type=%r." % (i, type(inp)))
- dtypes = [_ for _ in dtypes if _ is not None]
- unique = set(dtypes)
- if len(unique) != 1:
- return None
- return dtypes[0]
-
- def __repr__(self):
- "usual"
- args = []
- for inp in self.inputs:
- args.append(repr(inp))
- if self.onnx_op is not None:
- if isinstance(self.onnx_op, str):
- args.append("op=%r" % self.onnx_op)
- else:
- args.append("op=%s" % self.onnx_op.__name__)
- if self.select_output is not None:
- args.append("select_output=%r" % self.select_output)
- if self.dtype is not None and self.dtype != self._guess_dtype(None):
- args.append("dtype=%r" % self.dtype)
- for k, v in sorted(self.onnx_op_kwargs.items()):
- args.append("%s=%r" % (k, v))
- res = "%s(%s)" % (self.__class__.__name__, ", ".join(args))
- return res
-
- def to_algebra(self, op_version=None):
- """
- Converts the variable into an operator.
- """
- if self.alg_ is None:
- if self.onnx_op is None:
- if len(self.inputs) != 1:
- raise RuntimeError( # pragma: no cover
- "Unexpected number of inputs, 1 expected, "
- "got {} instead.".format(self.inputs))
- if self.dtype is None or hasattr(self.inputs[0], 'onnx_name'):
- self.alg_ = self.inputs[0]
- else:
- self.alg_ = (
- self.inputs[0], _guess_numpy_type(self.dtype, None))
- else:
- if isinstance(self.onnx_op, str):
- var = self._custom_op(*self.inputs, op_version=op_version,
- **self.onnx_op_kwargs)
- alg = var.to_algebra(op_version=op_version)
- if not hasattr(self, 'alg_'):
- raise RuntimeError( # pragma: no cover
- "Missing attribute 'alg_'.")
- self.alg_ = alg
- return alg
-
- new_inputs = []
- for inp in self.inputs:
- if hasattr(inp, 'fit'):
- # scikit-learn model
- new_inputs.append(inp)
- elif isinstance(inp, (
- int, float, str, numpy.ndarray, numpy.int32,
- numpy.int64, numpy.float32, numpy.float64,
- numpy_bool, numpy_str, numpy.int8, numpy.uint8,
- numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)):
- new_inputs.append(inp)
- else:
- new_inputs.append(
- inp.to_algebra(op_version=op_version))
-
- res = self.onnx_op(*new_inputs, op_version=op_version,
- **self.onnx_op_kwargs)
- if self.select_output is None:
- self.alg_ = res
- else:
- self.alg_ = res[self.select_output]
- return self.alg_
-
- def _custom_op(self, *args, op_version=None, runtime=None, **kwargs):
- """
- This could be handled before a call to this method
- but this method can change the conversion of an non-existing
- operator depending on the given opset.
- """
- if self.onnx_op == 'filter':
- return self._custom_op_filter(*args, op_version=op_version,
- runtime=runtime, **kwargs)
- raise NotImplementedError( # pragma: no cover
- "Unexpected custom operator %r." % self.onnx_op)
-
- def _custom_op_filter(self, *args, op_version=None, runtime=None, **kwargs):
- """
- This could be handled before a call to this method
- but this method can change the conversion of an non-existing
- operator depending on the given opset.
- """
- if len(args) != 2:
- raise RuntimeError( # pragma: no cover
- "Custom op 'filter' expects two inputs not %r." % len(args))
- if len(kwargs) != 0:
- raise RuntimeError( # pragma: no cover
- "Custom op 'filter' expects no arguments but got %r." % kwargs)
- mat, index = args
- cast = OnnxVar(index.astype(numpy.int64), op=OnnxSqueeze)
- n1 = OnnxVar(cast, op=OnnxReduceSum, keepdims=1)
- indices = OnnxVar(cast, n1, op=OnnxTopK, select_output=1)
- return OnnxVar(mat, indices, op=OnnxGather)
-
- @property
- def T(self):
- "Transpose."
- return OnnxVar(self, op=OnnxTranspose)
-
- def astype(self, dtype):
- "Cast"
- return OnnxVar(self, op=OnnxCast, to=guess_proto_dtype(dtype))
-
- @property
- def shape(self):
- "Shape"
- return OnnxVar(self, op=OnnxShape)
-
- @property
- def size(self):
- "Size"
- return OnnxVar(self, op=OnnxSize)
-
- def reshape(self, shape):
- "Reshape"
- if isinstance(shape, (tuple, list)):
- shape = numpy.array(shape, dtype=numpy.int64)
- return OnnxVar(self, shape, op=OnnxReshape)
-
- def _make_array(self, y):
- """Converts *y* into an array if not."""
- if hasattr(y, 'dtype') and not isinstance(y, (numpy.ndarray, OnnxVar)):
- return numpy.full((1, ), y, dtype=y.dtype)
- if isinstance(y, (float, int, str)):
- return numpy.array([y])
- return y
-
- def __add__(self, y):
- "Addition."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxAdd)
-
- def __sub__(self, y):
- "Subtraction."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxSub)
-
- def __mul__(self, y):
- "Multiplication."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxMul)
-
- def __pow__(self, y):
- "Power."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxPow)
-
- def __mod__(self, y):
- "Modulo."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxMod)
-
- def __matmul__(self, y):
- "Matrix multiplication."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxMatMul)
-
- def __truediv__(self, y):
- "Division, no difference between `/` and `//`."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxDiv)
-
- def __floordiv__(self, y):
- "Division, no difference between `/` and `//`."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxDiv)
-
- def __eq__(self, y):
- "Equality."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxEqual)
-
- def __ne__(self, y):
- "Difference."
- y = self._make_array(y)
- return OnnxVar(OnnxVar(self, y, op=OnnxEqual), op=OnnxNot)
-
- def __ge__(self, y):
- "Greater or Equal."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxGreaterOrEqual)
-
- def __gt__(self, y):
- "Greater."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxGreater)
-
- def __le__(self, y):
- "Less or Equal."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxLessOrEqual)
-
- def __lt__(self, y):
- "Less."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxLess)
-
- def __and__(self, y):
- "And."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxAnd)
-
- def __or__(self, y):
- "And."
- y = self._make_array(y)
- return OnnxVar(self, y, op=OnnxOr)
-
- def not_(self):
- "Not."
- return OnnxVar(self, op=OnnxNot)
-
- def __neg__(self):
- "Neg."
- return OnnxVar(self, op=OnnxNeg)
-
- def __getitem__(self, index):
- """
- Deals with multiple scenarios.
- * *index* is an integer or a slice, a tuple of integers and slices,
- example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**)
- * *index* is an *ONNX* object (more precisely an instance of
- @see cl OnnxVar), then the method assumes it is an array of
- boolean to select a subset of the tensor along the first axis,
- example: `mat[mat == 0]` (**scenario 2**)
- """
- if isinstance(index, OnnxVar):
- # scenario 2
- return OnnxVar(self, index, op='filter')
-
- if not isinstance(index, tuple):
- index = (index, )
-
- # scenario 1
- starts = []
- ends = []
- axes = []
- steps = []
- axis_squeeze = []
- for i, ind in enumerate(index):
- if isinstance(ind, int):
- starts.append(ind)
- ends.append(ind + 1)
- axes.append(i)
- steps.append(1)
- axis_squeeze.append(i)
- continue
- if isinstance(ind, slice):
- if ind.start is None and ind.stop is None and ind.step is None:
- continue
- start = 0 if ind.start is None else ind.start
- end = -1 if ind.stop is None else ind.stop
- step = 1 if ind.step is None else ind.step
- starts.append(start)
- ends.append(end)
- axes.append(i)
- steps.append(step)
- continue
- raise NotImplementedError( # pragma: no cover
- "Not implemented for type %r." % type(ind))
- if max(steps) == min(steps) == 1:
- steps = None
- else:
- steps = numpy.array(steps, dtype=numpy.int64)
- starts = numpy.array(starts, dtype=numpy.int64)
- ends = numpy.array(ends, dtype=numpy.int64)
- axes = numpy.array(axes, dtype=numpy.int64)
- if steps is None:
- sliced = OnnxVar(self, starts, ends, axes, op=OnnxSlice)
- else:
- sliced = OnnxVar(self, starts, ends, axes, steps, op=OnnxSlice)
- if len(axis_squeeze) > 0:
- return OnnxVar(
- sliced, numpy.array(axis_squeeze, dtype=numpy.int64),
- op=OnnxSqueeze)
- return sliced
-
- def __setitem__(self, index, value):
- """
- Only supports vectors (1D tensor).
- * *index* is an integer or a slice, a tuple of integers and slices,
- example: `[0]`, `[:5]`, `[::2]` (**scenario 1**)
- * *index* is an *ONNX* object (more precisely an instance of
- @see cl OnnxVar), then the method assumes it is an array of
- boolean to select a subset of the tensor along the first axis,
- example: `mat[mat == 0]` (**scenario 2**)
- This processing is applied before the operator it contains.
- A copy should be made (Identity node or copy method).
- """
- if self.onnx_op is not None and self.onnx_op is not OnnxIdentity:
- raise RuntimeError( # pragma: no cover
- "A copy should be made before setting new values on a matrix. "
- "Method copy() would do that.")
-
- if isinstance(index, OnnxVar):
- # scenario 2, example: cp[x < 0] = -1
- return self._setitem2i_(index, value)
- elif not isinstance(index, tuple):
- index = (index, )
-
- for i in index:
- if isinstance(i, OnnxVar):
- raise NotImplementedError( # pragma: no cover
- "Unable to handle case such as cp[0, x < 0] = -1.")
-
- # scenario 1
- if len(index) == 1:
- return self._setitem1i_(index[0], value)
- raise NotImplementedError( # pragma: no cover
- "Indices in %d dimensions are not implemented yet." % len(index))
-
- def _setitem1i_(self, index, value):
- sl = None
- if isinstance(index, slice):
- start = 0 if index.start is None else index.start
- stop = index.stop
- step = index.step
- sl = [start, stop, step]
- elif isinstance(index, int):
- sl = [index, index + 1, 1]
- else:
- raise NotImplementedError( # pragma: no cover
- "Unable to assign new values due to unexpected type %r."
- "" % type(index))
-
- if sl[1] is None and isinstance(value, numpy.ndarray):
- sl[1] = sl[0] + value.size
- if sl[1] is None:
- if sl[2] is not None and sl[2] != 1:
- raise NotImplementedError( # pragma: no cover
- "If the length is not known, step must be 1 not %d." % sl[2])
- value = make_tensor(
- "value", guess_proto_dtype(value.dtype), (1, ), [value]) # pylint: disable=E1101
- inp = self.inputs[0]
- if not isinstance(inp, OnnxVar):
- raise RuntimeError( # pragma: no cover
- "Input must be an instance of OnnxVar not %r." % type(inp))
- cst = OnnxVar(inp.shape, op=OnnxConstantOfShape, value=value)
- ext = inp[:sl[0]]
- indices = numpy.arange(0, sl[0]).astype(numpy.int64)
- add_step = OnnxVar(cst, indices, ext,
- op=OnnxScatterElements, axis=0)
- else:
- indices = numpy.arange(sl[0], sl[1], sl[2]).astype(numpy.int64)
- if isinstance(value, numpy.ndarray):
- values = value
- else:
- values = numpy.full(indices.shape, value)
- add_step = OnnxVar(self.inputs[0], indices, values,
- op=OnnxScatterElements, axis=0)
-
- self.inputs = [add_step]
- return self
-
- def _setitem2i_(self, index, value):
- add_step = OnnxVar(index, value, self.inputs[0], op=OnnxWhere)
- self.inputs = [add_step]
- return self
-
- def copy(self):
- """
- Returns a copy of self (use of Identity node).
- """
- return OnnxVar(self, op=OnnxIdentity)
-
- def flatten(self, axis=0):
- """
- Flattens a matrix (see :epkg:`numpy:ndarray:flatten`).
-
- :param axis: only flatten from axis to the end.
- :return: @see cl OnnxVariable
- """
- fl = OnnxVar(self, op=OnnxFlatten, axis=axis)
- if axis == 0:
- return OnnxVar(fl, numpy.array([0], dtype=numpy.int64),
- op=OnnxSqueeze)
- return fl
-
-
-class TupleOnnxAny:
- """
- Class used to return multiple @see cl OnnxVar
- at the same time.
- """
-
- def __init__(self, first, *args):
- if isinstance(first, (list, tuple)):
- raise TypeError( # pragma: no cover
- "Unexpected type for first %r." % type(first))
- if len(args) > 0:
- self.values = (first,) + args
- self.unique = None
- else:
- self.values = None
- self.unique = first
- if self.values is not None and self.unique is not None:
- raise RuntimeError( # pragma: no cover
- "Unexpected configuration. One member (values or unique) must be "
- "null, unique=%r, values=%r" % (self.unique, self.values))
- if self.values is None and self.unique is None:
- raise RuntimeError( # pragma: no cover
- "Unexpected configuration. One member (values or unique) must be "
- "not null.")
-
- def __len__(self):
- "usual"
- if self.values is None:
- raise NotImplementedError( # pragma: no cover
- "Not yet implemented in this case unique=%r, "
- "values=%r." % (self.unique, self.values))
- return len(self.values)
-
- def __iter__(self):
- "Iterates on the outputs."
- if self.values is None:
- raise NotImplementedError( # pragma: no cover
- "Not yet implemented in this case.")
- for v in self.values:
- yield v
-
- def __getitem__(self, i):
- "usual"
- if self.values is None:
- return self.unique[i]
- return self.values[i]
-
- def get_output_type_inference(self, input_shapes=None):
- """
- Returns the expected output types in a list.
- """
- if self.values is None:
- if hasattr(self.unique, 'get_output_type_inference'):
- return self.unique.get_output_type_inference(input_shapes)
- raise NotImplementedError( # pragma: no cover
- "Not implemented yet unique=%r values=%r." % (
- self.unique, self.values))
-
- @property
- def outputs(self):
- "Returns 'output_names' of attribute 'unique'."
- if self.values is None:
- if hasattr(self.unique, 'to_onnx'):
- return self.unique.outputs
- raise NotImplementedError( # pragma: no cover
- "Not implemented yet unique=%r values=%r." % (
- self.unique, self.values))
-
- @property
- def output_names(self):
- "Returns 'output_names' of attribute 'unique'."
- if self.values is None:
- if hasattr(self.unique, 'to_onnx'):
- return self.unique.output_names
- raise NotImplementedError( # pragma: no cover
- "Not implemented yet unique=%r values=%r." % (
- self.unique, self.values))
-
- @output_names.setter
- def output_names(self, value):
- """
- Updates 'output_names' of attribute 'unique'
- or every output name of attribute 'values'.
- """
- if self.values is None:
- if (hasattr(self.unique, 'to_onnx') or
- hasattr(self.unique, 'add_to')):
- if len(value) > 1:
- self.values = tuple(
- OnnxIdentity(self.unique[i], output_names=value[i:i + 1],
- op_version=self.unique.op_version)
- for i in range(0, len(value)))
- self.unique = None
- return
- self.unique.output_names = value
- return
- raise NotImplementedError( # pragma: no cover
- "Not implemented yet, value=%r, unique=%r values=%r." % (
- value, self.unique, self.values))
- if self.values is not None and len(self.values) == len(value):
- for name, v in zip(value, self.values):
- v.output_names = [name]
- return
- raise NotImplementedError( # pragma: no cover
- "Not implemented yet, value=%r, unique=%r values=%r." % (
- value, self.unique, self.values))
-
- def add_to(self, scope, container, operator=None, run_converters=False):
- """
- Adds outputs to the container if not already added,
- registered the outputs if the node is not final.
-
- :param scope: scope
- :param container: container
- :param operator: overwrite inputs
- :param run_converters: must be True if called from method `to_onnx`
- """
- if self.values is not None:
- for v in self.values:
- v.add_to(scope, container, operator=operator,
- run_converters=run_converters)
- return
- if self.unique is not None:
- self.unique.add_to(scope, container, operator=operator,
- run_converters=run_converters)
- return
- raise RuntimeError( # pragma: no cover
- "Attributes 'unique' and 'values' cannot be both null.")
-
- def to_onnx(self, *args, **kwargs): # pylint: disable=W0222
- "Converts the underlying class into an ONNX graph."
- if self.values is None:
- if hasattr(self.unique, 'to_onnx'):
- return self.unique.to_onnx(*args, **kwargs)
- raise NotImplementedError( # pragma: no cover
- "Not implemented yet unique=%r values=%r args=%r "
- "kwargs=%r." % (self.unique, self.values, args, kwargs))
- if self.values is not None:
- if len(self.values) == len(kwargs.get('outputs', [])):
- return self.values[0].to_onnx(
- *args, other_outputs=self.values[1:], **kwargs)
- raise NotImplementedError( # pragma: no cover
- "Not implemented yet unique=%r values=%r args=%r "
- "kwargs=%r." % (self.unique, self.values, args, kwargs))
-
-
-class MultiOnnxVar:
- """
- Class used to return multiple @see cl OnnxVar
- at the same time.
- """
-
- def __init__(self, *inputs, op=None, dtype=None, **kwargs):
- "constructor"
- self.onxvar = OnnxVar(*inputs, op=op, dtype=None, **kwargs)
- self.alg_ = None
-
- def _guess_dtype(self, dtype):
- "Guesses dtype when not specified."
- return self.onxvar._guess_dtype(dtype)
-
- @property
- def inputs(self):
- "Returns `self.onxvar.inputs`."
- return self.onxvar.inputs
-
- @property
- def onnx_op(self):
- "Returns `self.onxvar.onnx_op`."
- return self.onxvar.onnx_op
-
- @property
- def onnx_op_kwargs(self):
- "Returns `self.onxvar.onnx_op_kwargs`."
- return self.onxvar.onnx_op_kwargs
-
- def to_algebra(self, op_version=None):
- """
- Converts the variable into an operator.
- """
- if self.alg_ is None:
- new_inputs = []
- for inp in self.inputs:
- if isinstance(inp, (
- int, float, str, numpy.ndarray, numpy.int32,
- numpy.int64, numpy.float32, numpy.float64,
- numpy_bool, numpy_str, numpy.int8, numpy.uint8,
- numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)):
- new_inputs.append(inp)
- elif hasattr(inp, 'fit'):
- # scikit-learn models
- new_inputs.append(inp)
- else:
- new_inputs.append(
- inp.to_algebra(op_version=op_version))
-
- if self.onnx_op is None:
- if len(new_inputs) == 1:
- self.alg_ = TupleOnnxAny(new_inputs[0])
- else:
- self.alg_ = TupleOnnxAny(new_inputs[0], *(new_inputs[1:]))
- else:
- res = self.onnx_op( # pylint: disable=E1102
- *new_inputs, op_version=op_version, **self.onnx_op_kwargs)
- self.alg_ = TupleOnnxAny(res)
- return self.alg_
-
- def __getitem__(self, index):
- """
- Returns the ith elements.
- """
- return OnnxVar(self, index=index, op=OnnxOperatorItem)
+"""
+@file
+@brief Intermediate class between :epkg:`numpy` and :epkg:`onnx`.
+
+.. versionadded:: 0.6
+"""
+import numpy
+from onnx.helper import make_tensor
+from skl2onnx.common.data_types import guess_numpy_type
+from skl2onnx.common._topology import Variable # pylint: disable=E0611,E0001
+from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611
+ OnnxAdd, OnnxAnd,
+ OnnxCast, OnnxConcat, OnnxConstantOfShape,
+ OnnxDiv,
+ OnnxEqual,
+ OnnxFlatten,
+ OnnxGather, OnnxGreater, OnnxGreaterOrEqual,
+ OnnxIdentity,
+ OnnxLess, OnnxLessOrEqual,
+ OnnxMatMul, OnnxMod, OnnxMul,
+ OnnxNeg, OnnxNot,
+ OnnxOr,
+ OnnxPow,
+ OnnxReduceSum, OnnxReshape,
+ OnnxScatterElements, OnnxShape, OnnxSize, OnnxSlice,
+ OnnxSqueeze, OnnxSub,
+ OnnxTopK, OnnxTranspose,
+ OnnxUnsqueeze,
+ OnnxWhere)
+from skl2onnx.algebra.onnx_operator import OnnxOperatorItem
+from skl2onnx.common.data_types import _guess_numpy_type
+from ..onnx_tools.onnx2py_helper import guess_proto_dtype
+
+
+try:
+ numpy_bool = numpy.bool_
+except AttributeError: # pragma: no cover
+ numpy_bool = bool
+try:
+ numpy_str = numpy.str_
+except AttributeError: # pragma: no cover
+ numpy_str = str
+
+
+class OnnxVar:
+ """
+ Variables used into :epkg:`onnx` computation.
+
+ :param inputs: variable name or object
+ :param op: :epkg:`ONNX` operator
+ :param select_output: if multiple output are returned by
+ ONNX operator *op*, it takes only one specifed by this
+ argument
+ :param dtype: specifies the type of the variable
+ held by this class (*op* is None) in that case
+ :param kwargs: addition argument to give operator *op*
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(self, *inputs, op=None, select_output=None,
+ dtype=None, **kwargs):
+ self.inputs = inputs
+ self.select_output = select_output
+ self.onnx_op = op
+ self.alg_ = None
+ self.onnx_op_kwargs = kwargs
+ if dtype is not None and (op is not None or len(inputs) != 1):
+ raise RuntimeError( # pragma: no cover
+ "dtype can only be used if op is None or len(inputs) == 1.")
+ for i, inp in enumerate(self.inputs):
+ if isinstance(inp, type):
+ raise TypeError( # pragma: no cover
+ "Unexpected type for input %d - %r." % (i, inp))
+ self.dtype = self._guess_dtype(dtype)
+
+ def _guess_dtype(self, dtype):
+ "Guesses dtype when not specified."
+ if dtype is not None:
+ return dtype
+ dtypes = []
+ for i, inp in enumerate(self.inputs):
+ if isinstance(inp, str):
+ return None
+ if isinstance(inp, numpy.ndarray):
+ dtypes.append(inp.dtype)
+ elif isinstance(inp, Variable):
+ dt = guess_numpy_type(inp.type)
+ dtypes.append(dt)
+ elif isinstance(inp, OnnxVar):
+ dtypes.append(inp.dtype)
+ elif isinstance(inp, MultiOnnxVar):
+ dtypes.append(inp._guess_dtype(dtype))
+ elif isinstance(inp, (numpy.float32, numpy.float64, numpy.int32,
+ numpy.int64)):
+ dtypes.append(inp.dtype)
+ elif isinstance(inp, numpy_str):
+ dtypes.append(numpy_str)
+ elif isinstance(inp, numpy_bool):
+ dtypes.append(numpy_bool)
+ elif isinstance(inp, int):
+ dtypes.append(numpy.int64) # pragma: no cover
+ elif isinstance(inp, float):
+ dtypes.append(numpy.float64)
+ elif hasattr(inp, 'fit'):
+ # scikit-learn model
+ continue
+ else:
+ raise TypeError( # pragma: no cover
+ "Unexpected type for input %i type=%r." % (i, type(inp)))
+ dtypes = [_ for _ in dtypes if _ is not None]
+ unique = set(dtypes)
+ if len(unique) != 1:
+ return None
+ return dtypes[0]
+
+ def __repr__(self):
+ "usual"
+ args = []
+ for inp in self.inputs:
+ args.append(repr(inp))
+ if self.onnx_op is not None:
+ if isinstance(self.onnx_op, str):
+ args.append("op=%r" % self.onnx_op)
+ else:
+ args.append("op=%s" % self.onnx_op.__name__)
+ if self.select_output is not None:
+ args.append("select_output=%r" % self.select_output)
+ if self.dtype is not None and self.dtype != self._guess_dtype(None):
+ args.append("dtype=%r" % self.dtype)
+ for k, v in sorted(self.onnx_op_kwargs.items()):
+ args.append("%s=%r" % (k, v))
+ res = "%s(%s)" % (self.__class__.__name__, ", ".join(args))
+ return res
+
+ def to_algebra(self, op_version=None):
+ """
+ Converts the variable into an operator.
+ """
+ if self.alg_ is None:
+ if self.onnx_op is None:
+ if len(self.inputs) != 1:
+ raise RuntimeError( # pragma: no cover
+ "Unexpected number of inputs, 1 expected, "
+ "got {} instead.".format(self.inputs))
+ if self.dtype is None or hasattr(self.inputs[0], 'onnx_name'):
+ self.alg_ = self.inputs[0]
+ else:
+ self.alg_ = (
+ self.inputs[0], _guess_numpy_type(self.dtype, None))
+ else:
+ if isinstance(self.onnx_op, str):
+ var = self._custom_op(*self.inputs, op_version=op_version,
+ **self.onnx_op_kwargs)
+ alg = var.to_algebra(op_version=op_version)
+ if not hasattr(self, 'alg_'):
+ raise RuntimeError( # pragma: no cover
+ "Missing attribute 'alg_'.")
+ self.alg_ = alg
+ return alg
+
+ new_inputs = []
+ for inp in self.inputs:
+ if hasattr(inp, 'fit'):
+ # scikit-learn model
+ new_inputs.append(inp)
+ elif isinstance(inp, (
+ int, float, str, numpy.ndarray, numpy.int32,
+ numpy.int64, numpy.float32, numpy.float64,
+ numpy_bool, numpy_str, numpy.int8, numpy.uint8,
+ numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)):
+ new_inputs.append(inp)
+ else:
+ new_inputs.append(
+ inp.to_algebra(op_version=op_version))
+
+ res = self.onnx_op(*new_inputs, op_version=op_version,
+ **self.onnx_op_kwargs)
+ if self.select_output is None:
+ self.alg_ = res
+ else:
+ self.alg_ = res[self.select_output]
+ return self.alg_
+
+ def _custom_op(self, *args, op_version=None, runtime=None, **kwargs):
+ """
+ This could be handled before a call to this method
+ but this method can change the conversion of an non-existing
+ operator depending on the given opset.
+ """
+ if self.onnx_op == 'filter':
+ return self._custom_op_filter(*args, op_version=op_version,
+ runtime=runtime, **kwargs)
+ raise NotImplementedError( # pragma: no cover
+ "Unexpected custom operator %r." % self.onnx_op)
+
+ def _custom_op_filter(self, *args, op_version=None, runtime=None, **kwargs):
+ """
+ This could be handled before a call to this method
+ but this method can change the conversion of an non-existing
+ operator depending on the given opset.
+ """
+ if len(args) != 2:
+ raise RuntimeError( # pragma: no cover
+ "Custom op 'filter' expects two inputs not %r." % len(args))
+ if len(kwargs) != 0:
+ raise RuntimeError( # pragma: no cover
+ "Custom op 'filter' expects no arguments but got %r." % kwargs)
+ mat, index = args
+ cast = OnnxVar(index.astype(numpy.int64), op=OnnxSqueeze)
+ n1 = OnnxVar(cast, op=OnnxReduceSum, keepdims=1)
+ indices = OnnxVar(cast, n1, op=OnnxTopK, select_output=1)
+ return OnnxVar(mat, indices, op=OnnxGather)
+
+ @property
+ def T(self):
+ "Transpose."
+ return OnnxVar(self, op=OnnxTranspose)
+
+ def astype(self, dtype):
+ "Cast"
+ return OnnxVar(self, op=OnnxCast, to=guess_proto_dtype(dtype))
+
+ @property
+ def shape(self):
+ "Shape"
+ return OnnxVar(self, op=OnnxShape)
+
+ @property
+ def size(self):
+ "Size"
+ return OnnxVar(self, op=OnnxSize)
+
+ def reshape(self, shape):
+ "Reshape"
+ if isinstance(shape, (tuple, list)):
+ shape = numpy.array(shape, dtype=numpy.int64)
+ return OnnxVar(self, shape, op=OnnxReshape)
+
+ def _make_array(self, y):
+ """Converts *y* into an array if not."""
+ if hasattr(y, 'dtype') and not isinstance(y, (numpy.ndarray, OnnxVar)):
+ return numpy.full((1, ), y, dtype=y.dtype)
+ if isinstance(y, (float, int, str)):
+ return numpy.array([y])
+ return y
+
+ def __add__(self, y):
+ "Addition."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxAdd)
+
+ def __sub__(self, y):
+ "Subtraction."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxSub)
+
+ def __mul__(self, y):
+ "Multiplication."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxMul)
+
+ def __pow__(self, y):
+ "Power."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxPow)
+
+ def __mod__(self, y):
+ "Modulo."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxMod)
+
+ def __matmul__(self, y):
+ "Matrix multiplication."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxMatMul)
+
+ def __truediv__(self, y):
+ "Division, no difference between `/` and `//`."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxDiv)
+
+ def __floordiv__(self, y):
+ "Division, no difference between `/` and `//`."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxDiv)
+
+ def __eq__(self, y):
+ "Equality."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxEqual)
+
+ def __ne__(self, y):
+ "Difference."
+ y = self._make_array(y)
+ return OnnxVar(OnnxVar(self, y, op=OnnxEqual), op=OnnxNot)
+
+ def __ge__(self, y):
+ "Greater or Equal."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxGreaterOrEqual)
+
+ def __gt__(self, y):
+ "Greater."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxGreater)
+
+ def __le__(self, y):
+ "Less or Equal."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxLessOrEqual)
+
+ def __lt__(self, y):
+ "Less."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxLess)
+
+ def __and__(self, y):
+ "And."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxAnd)
+
+ def __or__(self, y):
+ "And."
+ y = self._make_array(y)
+ return OnnxVar(self, y, op=OnnxOr)
+
+ def not_(self):
+ "Not."
+ return OnnxVar(self, op=OnnxNot)
+
+ def __neg__(self):
+ "Neg."
+ return OnnxVar(self, op=OnnxNeg)
+
+ def __getitem__(self, index):
+ """
+ Deals with multiple scenarios.
+ * *index* is an integer or a slice, a tuple of integers and slices,
+ example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**)
+ * *index* is an *ONNX* object (more precisely an instance of
+ @see cl OnnxVar), then the method assumes it is an array of
+ boolean to select a subset of the tensor along the first axis,
+ example: `mat[mat == 0]` (**scenario 2**)
+ """
+ if isinstance(index, OnnxVar):
+ # scenario 2
+ return OnnxVar(self, index, op='filter')
+
+ if isinstance(index, int):
+ # Use Gather instead.
+ return OnnxVar(
+ self, numpy.array(index, dtype=numpy.int64),
+ axis=0, op=OnnxGather)
+
+ if not isinstance(index, tuple):
+ index = (index, )
+
+ # only one integer?
+ ni = None
+ ax = None
+ for i, a in enumerate(index):
+ if isinstance(a, int):
+ if ni is None:
+ ni = i
+ ax = a
+ else:
+ ax = None
+ ni = None
+ break
+ if (isinstance(a, slice) and a.start is None and
+ a.stop is None and a.step is None):
+ continue
+ ax = None
+ ni = None
+ break
+ if ni is not None and ax is not None:
+ # Use Gather instead.
+ return OnnxVar(
+ self, numpy.array(ni, dtype=numpy.int64),
+ axis=ax, op=OnnxGather)
+
+ # scenario 1
+ starts = []
+ ends = []
+ axes = []
+ steps = []
+ axis_squeeze = []
+ needs_shape = []
+ for i, ind in enumerate(index):
+ if isinstance(ind, int):
+ starts.append(ind)
+ ends.append(ind + 1)
+ axes.append(i)
+ steps.append(1)
+ axis_squeeze.append(i)
+ continue
+ if isinstance(ind, slice):
+ if ind.start is None and ind.stop is None and ind.step is None:
+ continue
+ start = 0 if ind.start is None else ind.start
+ end = (None, i) if ind.stop is None else ind.stop
+ step = 1 if ind.step is None else ind.step
+ starts.append(start)
+ ends.append(end)
+ axes.append(i)
+ steps.append(step)
+ if isinstance(end, tuple):
+ needs_shape.append(len(ends) - 1)
+ continue
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented for type %r." % type(ind))
+
+ if max(steps) == min(steps) == 1:
+ steps = None
+ else:
+ steps = numpy.array(steps, dtype=numpy.int64)
+
+ starts = numpy.array(starts, dtype=numpy.int64)
+ axes = numpy.array(axes, dtype=numpy.int64)
+
+ if len(needs_shape) > 0:
+ shape = self.shape
+ conc = []
+ for e in ends:
+ if isinstance(e, tuple):
+ conc.append(
+ OnnxVar(shape[e[1]],
+ numpy.array([0], dtype=numpy.int64),
+ op=OnnxUnsqueeze))
+ else:
+ conc.append(numpy.array([e], dtype=numpy.int64))
+ ends = OnnxVar(*conc, op=OnnxConcat, axis=0)
+ else:
+ ends = numpy.array(ends, dtype=numpy.int64)
+ if steps is None:
+ sliced = OnnxVar(self, starts, ends, axes, op=OnnxSlice)
+ else:
+ sliced = OnnxVar(self, starts, ends, axes, steps, op=OnnxSlice)
+ if len(axis_squeeze) > 0:
+ return OnnxVar(
+ sliced, numpy.array(axis_squeeze, dtype=numpy.int64),
+ op=OnnxSqueeze)
+ return sliced
+
+ def __setitem__(self, index, value):
+ """
+ Only supports vectors (1D tensor).
+ * *index* is an integer or a slice, a tuple of integers and slices,
+ example: `[0]`, `[:5]`, `[::2]` (**scenario 1**)
+ * *index* is an *ONNX* object (more precisely an instance of
+ @see cl OnnxVar), then the method assumes it is an array of
+ boolean to select a subset of the tensor along the first axis,
+ example: `mat[mat == 0]` (**scenario 2**)
+ This processing is applied before the operator it contains.
+ A copy should be made (Identity node or copy method).
+ """
+ if self.onnx_op is not None and self.onnx_op is not OnnxIdentity:
+ raise RuntimeError( # pragma: no cover
+ "A copy should be made before setting new values on a matrix. "
+ "Method copy() would do that.")
+
+ if isinstance(index, OnnxVar):
+ # scenario 2, example: cp[x < 0] = -1
+ return self._setitem2i_(index, value)
+ elif not isinstance(index, tuple):
+ index = (index, )
+
+ for i in index:
+ if isinstance(i, OnnxVar):
+ raise NotImplementedError( # pragma: no cover
+ "Unable to handle case such as cp[0, x < 0] = -1.")
+
+ # scenario 1
+ if len(index) == 1:
+ return self._setitem1i_(index[0], value)
+ raise NotImplementedError( # pragma: no cover
+ "Indices in %d dimensions are not implemented yet." % len(index))
+
+ def _setitem1i_(self, index, value):
+ sl = None
+ if isinstance(index, slice):
+ start = 0 if index.start is None else index.start
+ stop = index.stop
+ step = index.step
+ sl = [start, stop, step]
+ elif isinstance(index, int):
+ sl = [index, index + 1, 1]
+ else:
+ raise NotImplementedError( # pragma: no cover
+ "Unable to assign new values due to unexpected type %r."
+ "" % type(index))
+
+ if sl[1] is None and isinstance(value, numpy.ndarray):
+ sl[1] = sl[0] + value.size
+ if sl[1] is None:
+ if sl[2] is not None and sl[2] != 1:
+ raise NotImplementedError( # pragma: no cover
+ "If the length is not known, step must be 1 not %d." % sl[2])
+ value = make_tensor(
+ "value", guess_proto_dtype(value.dtype), (1, ), [value]) # pylint: disable=E1101
+ inp = self.inputs[0]
+ if not isinstance(inp, OnnxVar):
+ raise RuntimeError( # pragma: no cover
+ "Input must be an instance of OnnxVar not %r." % type(inp))
+ cst = OnnxVar(inp.shape, op=OnnxConstantOfShape, value=value)
+ ext = inp[:sl[0]]
+ indices = numpy.arange(0, sl[0]).astype(numpy.int64)
+ add_step = OnnxVar(cst, indices, ext,
+ op=OnnxScatterElements, axis=0)
+ else:
+ indices = numpy.arange(sl[0], sl[1], sl[2]).astype(numpy.int64)
+ if isinstance(value, numpy.ndarray):
+ values = value
+ else:
+ values = numpy.full(indices.shape, value)
+ add_step = OnnxVar(self.inputs[0], indices, values,
+ op=OnnxScatterElements, axis=0)
+
+ self.inputs = [add_step]
+ return self
+
+ def _setitem2i_(self, index, value):
+ add_step = OnnxVar(index, value, self.inputs[0], op=OnnxWhere)
+ self.inputs = [add_step]
+ return self
+
+ def copy(self):
+ """
+ Returns a copy of self (use of Identity node).
+ """
+ return OnnxVar(self, op=OnnxIdentity)
+
+ def flatten(self, axis=0):
+ """
+ Flattens a matrix (see :epkg:`numpy:ndarray:flatten`).
+
+ :param axis: only flatten from axis to the end.
+ :return: @see cl OnnxVariable
+ """
+ fl = OnnxVar(self, op=OnnxFlatten, axis=axis)
+ if axis == 0:
+ return OnnxVar(fl, numpy.array([0], dtype=numpy.int64),
+ op=OnnxSqueeze)
+ return fl
+
+
+class TupleOnnxAny:
+ """
+ Class used to return multiple @see cl OnnxVar
+ at the same time.
+ """
+
+ def __init__(self, first, *args):
+ if isinstance(first, (list, tuple)):
+ raise TypeError( # pragma: no cover
+ "Unexpected type for first %r." % type(first))
+ if len(args) > 0:
+ self.values = (first,) + args
+ self.unique = None
+ else:
+ self.values = None
+ self.unique = first
+ if self.values is not None and self.unique is not None:
+ raise RuntimeError( # pragma: no cover
+ "Unexpected configuration. One member (values or unique) must be "
+ "null, unique=%r, values=%r" % (self.unique, self.values))
+ if self.values is None and self.unique is None:
+ raise RuntimeError( # pragma: no cover
+ "Unexpected configuration. One member (values or unique) must be "
+ "not null.")
+
+ def __len__(self):
+ "usual"
+ if self.values is None:
+ raise NotImplementedError( # pragma: no cover
+ "Not yet implemented in this case unique=%r, "
+ "values=%r." % (self.unique, self.values))
+ return len(self.values)
+
+ def __iter__(self):
+ "Iterates on the outputs."
+ if self.values is None:
+ raise NotImplementedError( # pragma: no cover
+ "Not yet implemented in this case.")
+ for v in self.values:
+ yield v
+
+ def __getitem__(self, i):
+ "usual"
+ if self.values is None:
+ return self.unique[i]
+ return self.values[i]
+
+ def get_output_type_inference(self, input_shapes=None):
+ """
+ Returns the expected output types in a list.
+ """
+ if self.values is None:
+ if hasattr(self.unique, 'get_output_type_inference'):
+ return self.unique.get_output_type_inference(input_shapes)
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented yet unique=%r values=%r." % (
+ self.unique, self.values))
+
+ @property
+ def outputs(self):
+ "Returns 'output_names' of attribute 'unique'."
+ if self.values is None:
+ if hasattr(self.unique, 'to_onnx'):
+ return self.unique.outputs
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented yet unique=%r values=%r." % (
+ self.unique, self.values))
+
+ @property
+ def output_names(self):
+ "Returns 'output_names' of attribute 'unique'."
+ if self.values is None:
+ if hasattr(self.unique, 'to_onnx'):
+ return self.unique.output_names
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented yet unique=%r values=%r." % (
+ self.unique, self.values))
+
+ @output_names.setter
+ def output_names(self, value):
+ """
+ Updates 'output_names' of attribute 'unique'
+ or every output name of attribute 'values'.
+ """
+ if self.values is None:
+ if (hasattr(self.unique, 'to_onnx') or
+ hasattr(self.unique, 'add_to')):
+ if len(value) > 1:
+ self.values = tuple(
+ OnnxIdentity(self.unique[i], output_names=value[i:i + 1],
+ op_version=self.unique.op_version)
+ for i in range(0, len(value)))
+ self.unique = None
+ return
+ self.unique.output_names = value
+ return
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented yet, value=%r, unique=%r values=%r." % (
+ value, self.unique, self.values))
+ if self.values is not None and len(self.values) == len(value):
+ for name, v in zip(value, self.values):
+ v.output_names = [name]
+ return
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented yet, value=%r, unique=%r values=%r." % (
+ value, self.unique, self.values))
+
+ def add_to(self, scope, container, operator=None, run_converters=False):
+ """
+ Adds outputs to the container if not already added,
+ registered the outputs if the node is not final.
+
+ :param scope: scope
+ :param container: container
+ :param operator: overwrite inputs
+ :param run_converters: must be True if called from method `to_onnx`
+ """
+ if self.values is not None:
+ for v in self.values:
+ v.add_to(scope, container, operator=operator,
+ run_converters=run_converters)
+ return
+ if self.unique is not None:
+ self.unique.add_to(scope, container, operator=operator,
+ run_converters=run_converters)
+ return
+ raise RuntimeError( # pragma: no cover
+ "Attributes 'unique' and 'values' cannot be both null.")
+
+ def to_onnx(self, *args, **kwargs): # pylint: disable=W0222
+ "Converts the underlying class into an ONNX graph."
+ if self.values is None:
+ if hasattr(self.unique, 'to_onnx'):
+ return self.unique.to_onnx(*args, **kwargs)
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented yet unique=%r values=%r args=%r "
+ "kwargs=%r." % (self.unique, self.values, args, kwargs))
+ if self.values is not None:
+ if len(self.values) == len(kwargs.get('outputs', [])):
+ return self.values[0].to_onnx(
+ *args, other_outputs=self.values[1:], **kwargs)
+ raise NotImplementedError( # pragma: no cover
+ "Not implemented yet unique=%r values=%r args=%r "
+ "kwargs=%r." % (self.unique, self.values, args, kwargs))
+
+
+class MultiOnnxVar:
+ """
+ Class used to return multiple @see cl OnnxVar
+ at the same time.
+ """
+
+ def __init__(self, *inputs, op=None, dtype=None, **kwargs):
+ "constructor"
+ self.onxvar = OnnxVar(*inputs, op=op, dtype=None, **kwargs)
+ self.alg_ = None
+
+ def _guess_dtype(self, dtype):
+ "Guesses dtype when not specified."
+ return self.onxvar._guess_dtype(dtype)
+
+ @property
+ def inputs(self):
+ "Returns `self.onxvar.inputs`."
+ return self.onxvar.inputs
+
+ @property
+ def onnx_op(self):
+ "Returns `self.onxvar.onnx_op`."
+ return self.onxvar.onnx_op
+
+ @property
+ def onnx_op_kwargs(self):
+ "Returns `self.onxvar.onnx_op_kwargs`."
+ return self.onxvar.onnx_op_kwargs
+
+ def to_algebra(self, op_version=None):
+ """
+ Converts the variable into an operator.
+ """
+ if self.alg_ is None:
+ new_inputs = []
+ for inp in self.inputs:
+ if isinstance(inp, (
+ int, float, str, numpy.ndarray, numpy.int32,
+ numpy.int64, numpy.float32, numpy.float64,
+ numpy_bool, numpy_str, numpy.int8, numpy.uint8,
+ numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)):
+ new_inputs.append(inp)
+ elif hasattr(inp, 'fit'):
+ # scikit-learn models
+ new_inputs.append(inp)
+ else:
+ new_inputs.append(
+ inp.to_algebra(op_version=op_version))
+
+ if self.onnx_op is None:
+ if len(new_inputs) == 1:
+ self.alg_ = TupleOnnxAny(new_inputs[0])
+ else:
+ self.alg_ = TupleOnnxAny(new_inputs[0], *(new_inputs[1:]))
+ else:
+ res = self.onnx_op( # pylint: disable=E1102
+ *new_inputs, op_version=op_version, **self.onnx_op_kwargs)
+ self.alg_ = TupleOnnxAny(res)
+ return self.alg_
+
+ def __getitem__(self, index):
+ """
+ Returns the ith elements.
+ """
+ return OnnxVar(self, index=index, op=OnnxOperatorItem)
diff --git a/mlprodict/onnxrt/ops_cpu/_op_numpy_helper.py b/mlprodict/onnxrt/ops_cpu/_op_numpy_helper.py
index 5c1dc116e..e151a798a 100644
--- a/mlprodict/onnxrt/ops_cpu/_op_numpy_helper.py
+++ b/mlprodict/onnxrt/ops_cpu/_op_numpy_helper.py
@@ -62,8 +62,12 @@ def numpy_matmul_inplace(inplaces, a, b):
container modifies the results. This part still needs to be
improves.
"""
- if isinstance(a, coo_matrix) or isinstance(b, coo_matrix):
- return numpy.dot(a, b)
- if len(a.shape) <= 2 and len(b.shape) <= 2:
- return numpy_dot_inplace(inplaces, a, b)
- return numpy.matmul(a, b)
+ try:
+ if isinstance(a, coo_matrix) or isinstance(b, coo_matrix):
+ return numpy.dot(a, b)
+ if len(a.shape) <= 2 and len(b.shape) <= 2:
+ return numpy_dot_inplace(inplaces, a, b)
+ return numpy.matmul(a, b)
+ except ValueError as e:
+ raise ValueError(
+ "Unable to multiply shapes %r, %r." % (a.shape, b.shape)) from e
diff --git a/mlprodict/onnxrt/shape_object.py b/mlprodict/onnxrt/shape_object.py
index 9d319b330..2e7d7e8f8 100644
--- a/mlprodict/onnxrt/shape_object.py
+++ b/mlprodict/onnxrt/shape_object.py
@@ -536,6 +536,11 @@ def __init__(self, shape, dtype=None, use_n1=False, name=None):
sh = self._shape[0] if self._shape else None
if isinstance(sh, DimensionObject) and sh._dim is None:
sh._dim = 'n'
+ if self._shape is not None:
+ for s in self._shape:
+ if isinstance(s, int):
+ raise TypeError( # pragma: no cover
+ "Unexpected type int in shape %r." % self)
def reshape(self, shape):
"""
@@ -867,9 +872,14 @@ def concat_columns(self, axis, *shapes):
"""
args = [self] + list(shapes)
dtype = self._infer_merged_type(*args)
- dim_axis = args[0][axis]
+ dim_axis = self[axis]
+ if isinstance(dim_axis, int):
+ dim_axis = DimensionObject(dim_axis)
if dim_axis is None:
return ShapeObject(None, dtype=dtype)
+ if isinstance(dim_axis, int):
+ raise TypeError( # pragma: no cover
+ "Unexpected type for shape %r." % self)
for a in shapes:
if a[axis] is None:
return ShapeObject(None, dtype=dtype)
diff --git a/requirements.txt b/requirements.txt
index 06ad85855..b8b8b6adc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ numpy>=1.19.0
pandas
pillow
scikit-learn>=0.24
-scipy
+scipy>=1.7.0
Sphinx
wheel