diff --git a/_doc/notebooks/onnx_fft.ipynb b/_doc/notebooks/onnx_fft.ipynb index 04fc67471..ce4072f68 100644 --- a/_doc/notebooks/onnx_fft.ipynb +++ b/_doc/notebooks/onnx_fft.ipynb @@ -1,1628 +1,1628 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "51bc89fc", - "metadata": {}, - "source": [ - "# ONNX and FFT\n", - "\n", - "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7b2add97", - "metadata": {}, - "outputs": [ + "cells": [ + { + "cell_type": "markdown", + "id": "51bc89fc", + "metadata": {}, + "source": [ + "# ONNX and FFT\n", + "\n", + "ONNX does not fully support complex yet. It does not have any FFT operators either. What if we need them anyway?" + ] + }, { - "data": { - "text/html": [ - "
run previous cell, wait for 2 seconds
\n", - "" + "cell_type": "code", + "execution_count": 1, + "id": "7b2add97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
run previous cell, wait for 2 seconds
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "from jyquickhelper import add_notebook_menu\n", + "add_notebook_menu()" ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jyquickhelper import add_notebook_menu\n", - "add_notebook_menu()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "acfdc3b0", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext mlprodict" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "abb5fa88", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "'1.21.0'" + "cell_type": "code", + "execution_count": 2, + "id": "acfdc3b0", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext mlprodict" ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "numpy.__version__" - ] - }, - { - "cell_type": "markdown", - "id": "2e4f68e4", - "metadata": {}, - "source": [ - "## Python implementation of RFFT\n", - "\n", - "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "cb1cc910", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[-1.36418152+0.j , 2.3512614 -1.47772773j,\n", - " -3.4774066 +3.40257902j, -1.73059963+1.64505308j],\n", - " [ 0.6441313 +0.j , -0.87221646-1.87952026j,\n", - " 1.0705215 -1.186307j , 1.31619296+5.60515407j],\n", - " [ 1.38915221+0.j , -1.11980049+2.87742877j,\n", - " -0.25900143+0.17339344j, -1.45116622+1.24798734j],\n", - " [-1.86380783+0.j , 2.37798625+1.72008612j,\n", - " 1.42540207+1.57713781j, 0.18057206+1.32039835j],\n", - " [ 4.1150526 +0.j , -3.35634771-0.41940018j,\n", - " -0.38524887-1.39453991j, -0.31538136+1.7538376j ]])" + "cell_type": "code", + "execution_count": 3, + "id": "abb5fa88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.21.0'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "numpy.__version__" ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy\n", - "\n", - "\n", - "def almost_equal(a, b, error=1e-5):\n", - " \"\"\"\n", - " The function compares two matrices, one may be complex. 
In that case,\n", - " this matrix is changed into a new matrix with a new first dimension,\n", - " [0,::] means real part, [1,::] means imaginary part.\n", - " \"\"\"\n", - " if a.dtype in (numpy.complex64, numpy.complex128):\n", - " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", - " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", - " new_a[0] = numpy.real(a)\n", - " new_a[1] = numpy.imag(a)\n", - " return almost_equal(new_a, b, error)\n", - " if b.dtype in (numpy.complex64, numpy.complex128):\n", - " return almost_equal(b, a, error)\n", - " if a.shape != b.shape:\n", - " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", - " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", - " if diff > error:\n", - " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", - "\n", - "\n", - "def dft_real_cst(N, fft_length):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " both = numpy.empty((2,) + M.shape)\n", - " both[0, :, :] = numpy.real(M)\n", - " both[1, :, :] = numpy.imag(M)\n", - " return both\n", - "\n", - "\n", - "def dft_real(x, fft_length=None, transpose=True):\n", - " if len(x.shape) == 1:\n", - " x = x.reshape((1, -1))\n", - " N = 1\n", - " else:\n", - " N = x.shape[0] \n", - " C = x.shape[-1] if transpose else x.shape[-2]\n", - " if fft_length is None:\n", - " fft_length = x.shape[-1]\n", - " size = fft_length // 2 + 1\n", - "\n", - " cst = dft_real_cst(C, fft_length)\n", - " if transpose:\n", - " x = numpy.transpose(x, (1, 0))\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:fft_length]\n", - " res = numpy.matmul(a, b)\n", - " res = res[:, :size, :]\n", - " return numpy.transpose(res, (0, 2, 1))\n", - " else:\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:fft_length]\n", - " return numpy.matmul(a, b)\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft_np = numpy.fft.rfft(rnd)\n", - "fft_cus = dft_real(rnd)\n", - "fft_np" - ] - }, - { - "cell_type": "markdown", - "id": "0c052ea1", - "metadata": {}, - "source": [ - "Function `almost_equal` verifies both functions return the same results." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "3ca040cb", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np, fft_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "7fe77440", - "metadata": {}, - "source": [ - "Let's do the same with `fft_length < shape[1]`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "3a747a4a", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[ 0.50121731+0.j , -1.76725248+1.19033269j],\n", - " [ 1.84486783+0.j , -0.13533521+1.86170961j],\n", - " [-1.49032012+0.j , -0.17000796+0.02887427j],\n", - " [-0.8358376 +0.j , 1.725943 +0.19581766j],\n", - " [ 0.9690519 +0.j , -1.34143379+0.70979425j]])" + "cell_type": "markdown", + "id": "2e4f68e4", + "metadata": {}, + "source": [ + "## Python implementation of RFFT\n", + "\n", + "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)." 
] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", - "fft_cus3 = dft_real(rnd, fft_length=3)\n", - "fft_np3" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "0db6247b", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_np3, fft_cus3)" - ] - }, - { - "cell_type": "markdown", - "id": "31a6ac9c", - "metadata": {}, - "source": [ - "## RFFT in ONNX\n", - "\n", - "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "efb67b9b", - "metadata": { - "scrolled": false - }, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[[-1.3641814 , 2.3512614 , -3.4774065 , -1.7305996 ],\n", - " [ 0.6441313 , -0.87221646, 1.0705215 , 1.3161929 ],\n", - " [ 1.3891523 , -1.1198004 , -0.25900146, -1.4511662 ],\n", - " [-1.8638077 , 2.3779864 , 1.425402 , 0.18057205],\n", - " [ 4.1150527 , -3.3563478 , -0.38524887, -0.31538135]],\n", - "\n", - " [[ 0. , -1.4777277 , 3.402579 , 1.6450533 ],\n", - " [ 0. , -1.8795203 , -1.1863071 , 5.605154 ],\n", - " [ 0. , 2.8774288 , 0.17339343, 1.2479873 ],\n", - " [ 0. , 1.7200862 , 1.5771378 , 1.3203983 ],\n", - " [ 0. , -0.41940016, -1.39454 , 1.7538376 ]]],\n", - " dtype=float32)" + "cell_type": "code", + "execution_count": 4, + "id": "cb1cc910", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-1.36418152+0.j , 2.3512614 -1.47772773j,\n", + " -3.4774066 +3.40257902j, -1.73059963+1.64505308j],\n", + " [ 0.6441313 +0.j , -0.87221646-1.87952026j,\n", + " 1.0705215 -1.186307j , 1.31619296+5.60515407j],\n", + " [ 1.38915221+0.j , -1.11980049+2.87742877j,\n", + " -0.25900143+0.17339344j, -1.45116622+1.24798734j],\n", + " [-1.86380783+0.j , 2.37798625+1.72008612j,\n", + " 1.42540207+1.57713781j, 0.18057206+1.32039835j],\n", + " [ 4.1150526 +0.j , -3.35634771-0.41940018j,\n", + " -0.38524887-1.39453991j, -0.31538136+1.7538376j ]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy\n", + "\n", + "\n", + "def almost_equal(a, b, error=1e-5):\n", + " \"\"\"\n", + " The function compares two matrices, one may be complex. 
In that case,\n", + " this matrix is changed into a new matrix with a new first dimension,\n", + " [0,::] means real part, [1,::] means imaginary part.\n", + " \"\"\"\n", + " if a.dtype in (numpy.complex64, numpy.complex128):\n", + " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", + " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", + " new_a[0] = numpy.real(a)\n", + " new_a[1] = numpy.imag(a)\n", + " return almost_equal(new_a, b, error)\n", + " if b.dtype in (numpy.complex64, numpy.complex128):\n", + " return almost_equal(b, a, error)\n", + " if a.shape != b.shape:\n", + " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", + " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", + " if diff > error:\n", + " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", + "\n", + "\n", + "def dft_real_cst(N, fft_length):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " both = numpy.empty((2,) + M.shape)\n", + " both[0, :, :] = numpy.real(M)\n", + " both[1, :, :] = numpy.imag(M)\n", + " return both\n", + "\n", + "\n", + "def dft_real(x, fft_length=None, transpose=True):\n", + " if len(x.shape) == 1:\n", + " x = x.reshape((1, -1))\n", + " N = 1\n", + " else:\n", + " N = x.shape[0] \n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (1, 0))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :size, :]\n", + " return numpy.transpose(res, (0, 2, 1))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:fft_length]\n", + " return numpy.matmul(a, b)\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft_np = numpy.fft.rfft(rnd)\n", + "fft_cus = dft_real(rnd)\n", + "fft_np" ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import Any\n", - "import mlprodict.npy.numpy_onnx_impl as npnx\n", - "from mlprodict.npy import onnxnumpy_np\n", - "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", - "# from mlprodict.onnxrt import OnnxInference\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft(x, fft_length=None):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - "\n", - "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", - "fft_onx" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c4b6b1a5", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft_cus, fft_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "a8c35327", - "metadata": {}, - "source": [ - "The corresponding ONNX graph is the following:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4d1a85b0", - "metadata": {}, - "outputs": [ + }, + { + "cell_type": "markdown", + "id": "0c052ea1", + "metadata": {}, + "source": [ + "Function `almost_equal` verifies both 
functions return the same results." + ] + }, { - "data": { - "text/html": [ - "
\n", - "" + "cell_type": "code", + "execution_count": 5, + "id": "3ca040cb", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np, fft_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "7fe77440", + "metadata": {}, + "source": [ + "Let's do the same with `fft_length < shape[1]`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3a747a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0.50121731+0.j , -1.76725248+1.19033269j],\n", + " [ 1.84486783+0.j , -0.13533521+1.86170961j],\n", + " [-1.49032012+0.j , -0.17000796+0.02887427j],\n", + " [-0.8358376 +0.j , 1.725943 +0.19581766j],\n", + " [ 0.9690519 +0.j , -1.34143379+0.70979425j]])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "fft_np3 = numpy.fft.rfft(rnd, n=3)\n", + "fft_cus3 = dft_real(rnd, fft_length=3)\n", + "fft_np3" ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft.signed_compiled)[0]\n", - "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "6cf18aca", - "metadata": {}, - "outputs": [], - "source": [ - "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", - "almost_equal(fft_cus3, fft_onx3)" - ] - }, - { - "cell_type": "markdown", - "id": "6b466fd4", - "metadata": {}, - "source": [ - "## FFT 2D\n", - "\n", - "Below the code for complex features." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e0020084", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[ 5.56511808+0.j , 2.40434541-6.58876113j,\n", - " -2.99787318+6.09018702j, 2.95547828-4.78324036j],\n", - " [-1.46413093-9.60314657j, -1.66675694+0.60494258j,\n", - " -2.49781725+0.06244585j, -4.22491665-2.62906297j],\n", - " [-1.06488187-4.9975721j , -7.53624925+2.11727626j,\n", - " -2.93212515+2.35814643j, -1.32906648-6.29456206j],\n", - " [-1.06488187+4.9975721j , 1.28429404-3.79395468j,\n", - " -4.70865157+7.14245256j, 4.69409373-0.6235566j ],\n", - " [-1.46413093+9.60314657j, -3.6701756 -3.08301071j,\n", - " -0.67943963-5.20582724j, 5.05462128+3.35367375j]])" + "cell_type": "code", + "execution_count": 7, + "id": "0db6247b", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_np3, fft_cus3)" ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def _DFT_cst(N, fft_length, trunc=True):\n", - " n = numpy.arange(N)\n", - " k = n.reshape((N, 1)).astype(numpy.float64)\n", - " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", - " return M[:fft_length // 2 + 1] if trunc else M\n", - "\n", - "def DFT(x, fft_length=None, axis=1):\n", - " if axis == 1:\n", - " x = x.T\n", - " if fft_length is None:\n", - " fft_length = x.shape[0]\n", - " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", - " if axis == 1:\n", - " return numpy.matmul(cst, x).T\n", - " return numpy.matmul(cst, x)\n", - "\n", - "def fft2d_(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " res = DFT(res, fft_length[1], axis=1)\n", - " res = DFT(res, fft_length[0], axis=0)\n", - " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", - "\n", - "\n", - "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", - "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - "fft2d_np_" - ] - }, - { - "cell_type": "code", 
- "execution_count": 13, - "id": "777d2775", - "metadata": {}, - "outputs": [], - "source": [ - "almost_equal(fft2d_np_, fft2d_np)" - ] - }, - { - "cell_type": "markdown", - "id": "cfbbe2fd", - "metadata": {}, - "source": [ - "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", - "\n", - "* $y = FFT(x, axis=1)$\n", - "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", - "* $z = z_r + i z_i$\n", - "\n", - "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does like that." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "dd4fc711", - "metadata": {}, - "outputs": [], - "source": [ - "def fft2d(mat, fft_length):\n", - " mat = mat[:fft_length[0], :fft_length[1]]\n", - " res = mat.copy()\n", - " \n", - " # first FFT\n", - " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_np = numpy.fft.rfft2(rnd)\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "almost_equal(fft2d_np, fft2d_cus)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "bb8667e6", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[ 5.56511808+0.j , 2.40434541-6.58876113j,\n", - " -2.99787318+6.09018702j, 2.95547828-4.78324036j],\n", - " [-1.46413093-9.60314657j, -1.66675694+0.60494258j,\n", - " -2.49781725+0.06244585j, -4.22491665-2.62906297j],\n", - " [-1.06488187-4.9975721j , -7.53624925+2.11727626j,\n", - " -2.93212515+2.35814643j, -1.32906648-6.29456206j],\n", - " [-1.06488187+4.9975721j , 1.28429404-3.79395468j,\n", - " -4.70865157+7.14245256j, 4.69409373-0.6235566j ],\n", - " [-1.46413093+9.60314657j, -3.6701756 -3.08301071j,\n", - " -0.67943963-5.20582724j, 5.05462128+3.35367375j]])" + "cell_type": "markdown", + "id": "31a6ac9c", + "metadata": {}, + "source": [ + "## RFFT in ONNX\n", + "\n", + "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant." ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_np" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "56a94d97", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[[ 5.56511808, 2.40434541, -2.99787318, 2.95547828],\n", - " [-1.46413093, -1.66675694, -2.49781725, -4.22491665],\n", - " [-1.06488187, -7.53624925, -2.93212515, -1.32906648],\n", - " [-1.06488187, 1.28429404, -4.70865157, 4.69409373],\n", - " [-1.46413093, -3.6701756 , -0.67943963, 5.05462128]],\n", - "\n", - " [[ 0. 
, -6.58876113, 6.09018702, -4.78324036],\n", - " [-9.60314657, 0.60494258, 0.06244585, -2.62906297],\n", - " [-4.9975721 , 2.11727626, 2.35814643, -6.29456206],\n", - " [ 4.9975721 , -3.79395468, 7.14245256, -0.6235566 ],\n", - " [ 9.60314657, -3.08301071, -5.20582724, 3.35367375]]])" + "cell_type": "code", + "execution_count": 8, + "id": "efb67b9b", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[-1.3641814 , 2.3512614 , -3.4774065 , -1.7305996 ],\n", + " [ 0.6441313 , -0.87221646, 1.0705215 , 1.3161929 ],\n", + " [ 1.3891523 , -1.1198004 , -0.25900146, -1.4511662 ],\n", + " [-1.8638077 , 2.3779864 , 1.425402 , 0.18057205],\n", + " [ 4.1150527 , -3.3563478 , -0.38524887, -0.31538135]],\n", + "\n", + " [[ 0. , -1.4777277 , 3.402579 , 1.6450533 ],\n", + " [ 0. , -1.8795203 , -1.1863071 , 5.605154 ],\n", + " [ 0. , 2.8774288 , 0.17339343, 1.2479873 ],\n", + " [ 0. , 1.7200862 , 1.5771378 , 1.3203983 ],\n", + " [ 0. , -0.41940016, -1.39454 , 1.7538376 ]]],\n", + " dtype=float32)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Any\n", + "import mlprodict.npy.numpy_onnx_impl as npnx\n", + "from mlprodict.npy import onnxnumpy_np\n", + "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", + "# from mlprodict.onnxrt import OnnxInference\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft(x, fft_length=None):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + "\n", + "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", + "fft_onx" ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_cus" - ] - }, - { - "cell_type": "markdown", - "id": "faa21909", - "metadata": {}, - "source": [ - "And with a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "bf98995f", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", - "fft2d_cus = fft2d(rnd, (4, 6))\n", - "almost_equal(fft2d_np[:4, :], fft2d_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "caee1f84", - "metadata": {}, - "source": [ - "## FFT 2D in ONNX\n", - "\n", - "We use again the numpy API for ONNX." 
- ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "ca641274", - "metadata": {}, - "outputs": [], - "source": [ - "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " if transpose:\n", - " xt = npnx.transpose(x, (1, 0))\n", - " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", - " return npnx.transpose(res, (0, 2, 1))\n", - " else:\n", - " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n", - "\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft_2d(x, fft_length=None):\n", - " mat = x[:fft_length[0], :fft_length[1]]\n", - " \n", - " # first FFT\n", - " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :fft_length[0], :size]\n", - "\n", - "\n", - "fft2d_cus = fft2d(rnd, rnd.shape)\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "20fcd8a9", - "metadata": {}, - "source": [ - "The corresponding ONNX graph." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "b1379b06", - "metadata": { - "scrolled": false - }, - "outputs": [ + }, { - "data": { - "text/html": [ - "
\n", - "" + "cell_type": "code", + "execution_count": 9, + "id": "c4b6b1a5", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft_cus, fft_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "a8c35327", + "metadata": {}, + "source": [ + "The corresponding ONNX graph is the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "4d1a85b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "key = list(onnx_rfft.signed_compiled)[0]\n", + "%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_" ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft_2d.signed_compiled)[0]\n", - "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "3034da60", - "metadata": {}, - "outputs": [], - "source": [ - "with open(\"fft2d.onnx\", \"wb\") as f:\n", - " key = list(onnx_rfft_2d.signed_compiled)[0]\n", - " f.write(onnx_rfft_2d.signed_compiled[key].compiled.onnx_.SerializeToString())" - ] - }, - { - "cell_type": "markdown", - "id": "3a747f0c", - "metadata": {}, - "source": [ - "With a different `fft_length`." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "16732cbb", - "metadata": {}, - "outputs": [], - "source": [ - "fft2d_cus = fft2d(rnd, (4, 5))\n", - "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "04924e7d", - "metadata": {}, - "source": [ - "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." - ] - }, - { - "cell_type": "markdown", - "id": "c9da88a0", - "metadata": {}, - "source": [ - "## FFT2D with shape (3,1,4)\n", - "\n", - "Previous implementation expects the input matrix to have two dimensions. It fails with 3." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "66ba70ee", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "(3, 1, 4)" + "cell_type": "code", + "execution_count": 11, + "id": "6cf18aca", + "metadata": {}, + "outputs": [], + "source": [ + "fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", + "almost_equal(fft_cus3, fft_onx3)" ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "shape = (3, 1, 4)\n", - "fft_length = (1, 4)\n", - "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", - "fft2d_numpy.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "a4d123e1", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "array([[[ 1.62552961+0.j , -2.33151346-0.26713149j,\n", - " 1.52621416+0.j , -2.33151346+0.26713149j]],\n", - "\n", - " [[ 1.56267625+0.j , -2.11182106+0.97715026j,\n", - " -1.59615904+0.j , -2.11182106-0.97715026j]],\n", - "\n", - " [[-2.11940277+0.j , 2.92459655+2.19828379j,\n", - " -1.98709261+0.j , 2.92459655-2.19828379j]]])" + "cell_type": "markdown", + "id": "6b466fd4", + "metadata": {}, + "source": [ + "## FFT 2D\n", + "\n", + "Below the code for complex features." 
] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fft2d_numpy" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "4b1bd05b", - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "axes don't match array\n" - ] - } - ], - "source": [ - "try:\n", - " fft2d_cus = fft2d(rnd, fft_length)\n", - "except Exception as e:\n", - " print(e)\n", - "# fft2d_onx = onnx_rfft_2d(rnd, fft_length=fft_length)" - ] - }, - { - "cell_type": "markdown", - "id": "7bd79a00", - "metadata": {}, - "source": [ - "### numpy version\n", - "\n", - "Let's do it again with numpy first. [fft2](https://numpy.org/doc/stable/reference/generated/numpy.fft.fft2.html) performs `fft2` on the last two axis as many times as the first axis. The goal is still to have an implementation which works for any dimension." - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "3b618335", - "metadata": {}, - "outputs": [], - "source": [ - "conc = []\n", - "for i in range(rnd.shape[0]):\n", - " f2 = fft2d(rnd[i], fft_length)\n", - " conc.append(numpy.expand_dims(f2, 0))\n", - "res = numpy.vstack(conc).transpose(1, 0, 2, 3)\n", - "almost_equal(fft2d_numpy[:, :, :3], res)" - ] - }, - { - "cell_type": "markdown", - "id": "7c837e7a", - "metadata": {}, - "source": [ - "It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``." 
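A quick shape check of the batched `matmul` rule quoted above (an illustration added here, not from the notebook; the array names are arbitrary):

```python
import numpy

# matmul(A[I, J, K], B[I, K, L]) -> C[I, J, L]:
# the leading axis indexes independent matrix products.
A = numpy.random.randn(5, 2, 3)
B = numpy.random.randn(5, 3, 4)
assert numpy.matmul(A, B).shape == (5, 2, 4)

# Broadcasting also works: matmul(A[1, J, K], B[I, K, L]) -> C[I, J, L],
# the single matrix in A is reused for every slice of B.
A1 = numpy.random.randn(1, 2, 3)
assert numpy.matmul(A1, B).shape == (5, 2, 4)
```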
- ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "29055cb2", - "metadata": {}, - "outputs": [], - "source": [ - "def dft_real_d3(x, fft_length=None, transpose=True):\n", - " if len(x.shape) != 3:\n", - " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", - " N = x.shape[1]\n", - " C = x.shape[-1] if transpose else x.shape[-2]\n", - " if fft_length is None:\n", - " fft_length = x.shape[-1]\n", - " size = fft_length // 2 + 1\n", - "\n", - " cst = dft_real_cst(C, fft_length)\n", - " if transpose:\n", - " x = numpy.transpose(x, (0, 2, 1))\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = numpy.expand_dims(a, 0)\n", - " b = numpy.expand_dims(b, 1)\n", - " res = numpy.matmul(a, b)\n", - " res = res[:, :, :size, :]\n", - " return numpy.transpose(res, (1, 0, 3, 2))\n", - " else:\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = numpy.expand_dims(a, 0)\n", - " b = numpy.expand_dims(b, 1)\n", - " res = numpy.matmul(a, b)\n", - " return numpy.transpose(res, (1, 0, 2, 3))\n", - "\n", - "\n", - "def fft2d_d3(mat, fft_length):\n", - " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", - " res = mat.copy()\n", - " \n", - " # first FFT\n", - " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", - " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", - " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[-1]//2 + 1\n", - " return res[:, :, :fft_length[-2], :size]\n", - "\n", - "\n", - "def fft2d_any(mat, fft_length):\n", - " new_shape = (-1, ) + mat.shape[-2:]\n", - " mat2 = mat.reshape(new_shape)\n", - " f2 = fft2d_d3(mat2, fft_length)\n", - " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", - " return f2.reshape(new_shape)\n", - "\n", - "\n", - "shape = (3, 1, 4)\n", - "fft_length = (1, 4)\n", - "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", - "fft2d_cus = fft2d_any(rnd, fft_length)\n", - "almost_equal(fft2d_numpy[..., :3], fft2d_cus)" - ] - }, - { - "cell_type": "markdown", - "id": "0128b3f2", - "metadata": {}, - "source": [ - "We check with more shapes to see if the implementation works for all of them." 
- ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "82f5fc78", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 12, + "id": "e0020084", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 5.56511808+0.j , 2.40434541-6.58876113j,\n", + " -2.99787318+6.09018702j, 2.95547828-4.78324036j],\n", + " [-1.46413093-9.60314657j, -1.66675694+0.60494258j,\n", + " -2.49781725+0.06244585j, -4.22491665-2.62906297j],\n", + " [-1.06488187-4.9975721j , -7.53624925+2.11727626j,\n", + " -2.93212515+2.35814643j, -1.32906648-6.29456206j],\n", + " [-1.06488187+4.9975721j , 1.28429404-3.79395468j,\n", + " -4.70865157+7.14245256j, 4.69409373-0.6235566j ],\n", + " [-1.46413093+9.60314657j, -3.6701756 -3.08301071j,\n", + " -0.67943963-5.20582724j, 5.05462128+3.35367375j]])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def _DFT_cst(N, fft_length, trunc=True):\n", + " n = numpy.arange(N)\n", + " k = n.reshape((N, 1)).astype(numpy.float64)\n", + " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", + " return M[:fft_length // 2 + 1] if trunc else M\n", + "\n", + "def DFT(x, fft_length=None, axis=1):\n", + " if axis == 1:\n", + " x = x.T\n", + " if fft_length is None:\n", + " fft_length = x.shape[0]\n", + " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", + " if axis == 1:\n", + " return numpy.matmul(cst, x).T\n", + " return numpy.matmul(cst, x)\n", + "\n", + "def fft2d_(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " res = DFT(res, fft_length[1], axis=1)\n", + " res = DFT(res, fft_length[0], axis=0)\n", + " return res[:fft_length[0], :fft_length[1]//2 + 1]\n", + "\n", + "\n", + "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", + "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_np_" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 1, 2) or (2, 3, 1, 2)\n", - "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 1, 1) or (2, 3, 1, 1)\n", - "OK x.shape=(5, 7) length=(5, 7) output shape=(5, 7) or (2, 5, 4)\n", - "OK x.shape=(5, 7) length=(1, 7) output shape=(1, 7) or (2, 1, 4)\n", - "OK x.shape=(5, 7) length=(2, 7) output shape=(2, 7) or (2, 2, 4)\n", - "OK x.shape=(5, 7) length=(5, 2) output shape=(5, 2) or (2, 5, 2)\n", - "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", - "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 5, 7) or (2, 3, 5, 4)\n", - "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 1, 7) or (2, 3, 1, 4)\n", - "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 2, 7) or (2, 3, 2, 4)\n", - "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 5, 2) or (2, 3, 5, 2)\n", - "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 3, 4) or (2, 3, 3, 3)\n", - "OK x.shape=(7, 5) length=(7, 5) output shape=(7, 5) or (2, 7, 3)\n", - "OK x.shape=(7, 5) length=(1, 5) output shape=(1, 5) or (2, 1, 3)\n", - "OK x.shape=(7, 5) length=(2, 5) output shape=(2, 5) or (2, 2, 3)\n", - "OK x.shape=(7, 5) length=(7, 2) output shape=(7, 2) or (2, 7, 2)\n", - "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" - ] - } - ], - 
"source": [ - "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", - " for fft_length in [shape[-2:], (1, shape[-1]),\n", - " (min(2, shape[-2]), shape[-1]),\n", - " (shape[-2], 2),\n", - " (min(3, shape[-2]), min(4, shape[-2]))]:\n", - " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - " fnp = numpy.fft.fft2(x, fft_length)\n", - " if len(fnp.shape) == 2:\n", - " fn= numpy.expand_dims(fnp, 0)\n", - " try:\n", - " cus = fft2d_any(x, fft_length)\n", - " except IndexError as e:\n", - " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", - " continue\n", - " try:\n", - " almost_equal(fnp[..., :cus.shape[-1]], cus)\n", - " except (AssertionError, IndexError) as e:\n", - " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, e, fnp.shape, cus.shape))\n", - " continue\n", - " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, fnp.shape, cus.shape))" - ] - }, - { - "cell_type": "markdown", - "id": "c5f5229a", - "metadata": {}, - "source": [ - "### ONNX version\n", - "\n", - "Let's look into the differences first." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "025c2d88", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext pyquickhelper" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "8a9d153c", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 13, + "id": "777d2775", + "metadata": {}, + "outputs": [], + "source": [ + "almost_equal(fft2d_np_, fft2d_np)" + ] + }, + { + "cell_type": "markdown", + "id": "cfbbe2fd", + "metadata": {}, + "source": [ + "It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):\n", + "\n", + "* $y = FFT(x, axis=1)$\n", + "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", + "* $z = z_r + i z_i$\n", + "\n", + "*z* is the desired output. The following implementation is probably not the most efficient one. It avoids inplace computation as ONNX does like that." 
+ ] + }, { - "data": { - "text/html": [ - "\n" + "cell_type": "code", + "execution_count": 14, + "id": "dd4fc711", + "metadata": {}, + "outputs": [], + "source": [ + "def fft2d(mat, fft_length):\n", + " mat = mat[:fft_length[0], :fft_length[1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_np = numpy.fft.rfft2(rnd)\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "almost_equal(fft2d_np, fft2d_cus)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "bb8667e6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 5.56511808+0.j , 2.40434541-6.58876113j,\n", + " -2.99787318+6.09018702j, 2.95547828-4.78324036j],\n", + " [-1.46413093-9.60314657j, -1.66675694+0.60494258j,\n", + " -2.49781725+0.06244585j, -4.22491665-2.62906297j],\n", + " [-1.06488187-4.9975721j , -7.53624925+2.11727626j,\n", + " -2.93212515+2.35814643j, -1.32906648-6.29456206j],\n", + " [-1.06488187+4.9975721j , 1.28429404-3.79395468j,\n", + " -4.70865157+7.14245256j, 4.69409373-0.6235566j ],\n", + " [-1.46413093+9.60314657j, -3.6701756 -3.08301071j,\n", + " -0.67943963-5.20582724j, 5.05462128+3.35367375j]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "fft2d_np" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "%%html\n", - "" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "82664bc5", - "metadata": {}, - "outputs": [ + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "56a94d97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 5.56511808, 2.40434541, -2.99787318, 2.95547828],\n", + " [-1.46413093, -1.66675694, -2.49781725, -4.22491665],\n", + " [-1.06488187, -7.53624925, -2.93212515, -1.32906648],\n", + " [-1.06488187, 1.28429404, -4.70865157, 4.69409373],\n", + " [-1.46413093, -3.6701756 , -0.67943963, 5.05462128]],\n", + "\n", + " [[ 0. , -6.58876113, 6.09018702, -4.78324036],\n", + " [-9.60314657, 0.60494258, 0.06244585, -2.62906297],\n", + " [-4.9975721 , 2.11727626, 2.35814643, -6.29456206],\n", + " [ 4.9975721 , -3.79395468, 7.14245256, -0.6235566 ],\n", + " [ 9.60314657, -3.08301071, -5.20582724, 3.35367375]]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_cus" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 24/24 [00:00<00:00, 776.23it/s]\n" - ] + "cell_type": "markdown", + "id": "faa21909", + "metadata": {}, + "source": [ + "And with a different `fft_length`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bf98995f", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", + "fft2d_cus = fft2d(rnd, (4, 6))\n", + "almost_equal(fft2d_np[:4, :], fft2d_cus)" + ] }, { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
[%codediff output: dft_real (2D input) vs. dft_real_d3 (3D input), rendered as a unified diff]
-def dft_real(x, fft_length=None, transpose=True):
-    if len(x.shape) == 1:
-        x = x.reshape((1, -1))
-        N = 1
-    else:
-        N = x.shape[0]
+def dft_real_d3(x, fft_length=None, transpose=True):
+    if len(x.shape) != 3:
+        raise RuntimeError("Not implemented for shape=%r." % x.shape)
+    N = x.shape[1]
     C = x.shape[-1] if transpose else x.shape[-2]
     if fft_length is None:
         fft_length = x.shape[-1]
     size = fft_length // 2 + 1

     cst = dft_real_cst(C, fft_length)
     if transpose:
-        x = numpy.transpose(x, (1, 0))
+        x = numpy.transpose(x, (0, 2, 1))
         a = cst[:, :, :fft_length]
-        b = x[:fft_length]
+        b = x[:, :fft_length, :]
+        a = numpy.expand_dims(a, 0)
+        b = numpy.expand_dims(b, 1)
         res = numpy.matmul(a, b)
-        res = res[:, :size, :]
-        return numpy.transpose(res, (0, 2, 1))
+        res = res[:, :, :size, :]
+        return numpy.transpose(res, (1, 0, 3, 2))
     else:
         a = cst[:, :, :fft_length]
-        b = x[:fft_length]
+        b = x[:, :fft_length, :]
+        a = numpy.expand_dims(a, 0)
+        b = numpy.expand_dims(b, 1)
-        return numpy.matmul(a, b)
+        res = numpy.matmul(a, b)
+        return numpy.transpose(res, (1, 0, 2, 3))
" + "cell_type": "markdown", + "id": "caee1f84", + "metadata": {}, + "source": [ + "## FFT 2D in ONNX\n", + "\n", + "We use again the numpy API for ONNX." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ca641274", + "metadata": {}, + "outputs": [], + "source": [ + "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (1, 0))\n", + " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", + " return npnx.transpose(res, (0, 2, 1))\n", + " else:\n", + " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d(x, fft_length=None):\n", + " mat = x[:fft_length[0], :fft_length[1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :fft_length[0], :size]\n", + "\n", + "\n", + "fft2d_cus = fft2d(rnd, rnd.shape)\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "20fcd8a9", + "metadata": {}, + "source": [ + "The corresponding ONNX graph." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b1379b06", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "key = list(onnx_rfft_2d.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_" ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import inspect\n", - "text1 = inspect.getsource(dft_real)\n", - "text2 = inspect.getsource(dft_real_d3)\n", - "%codediff text1 text2 --verbose 1 --two 1" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "cd7e14d4", - "metadata": {}, - "outputs": [ + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 15/15 [00:00<00:00, 1156.92it/s]\n" - ] + "cell_type": "code", + "execution_count": 20, + "id": "3034da60", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] }, { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
[%codediff output: fft2d (2D input) vs. fft2d_d3 (3D input), rendered as a unified diff]
-def fft2d(mat, fft_length):
-    mat = mat[:fft_length[0], :fft_length[1]]
+def fft2d_d3(mat, fft_length):
+    mat = mat[:, :fft_length[-2], :fft_length[-1]]
     res = mat.copy()

     # first FFT
-    res = dft_real(res, fft_length=fft_length[1], transpose=True)
+    res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)

     # second FFT decomposed on FFT on real part and imaginary part
-    res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)
-    res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False)
+    res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)
+    res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)
     res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])
     res = res2_real + res2_imag2
-    size = fft_length[1]//2 + 1
-    return res[:, :fft_length[0], :size]
+    size = fft_length[-1]//2 + 1
+    return res[:, :, :fft_length[-2], :size]
" + "cell_type": "markdown", + "id": "3a747f0c", + "metadata": {}, + "source": [ + "With a different `fft_length`." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "16732cbb", + "metadata": {}, + "outputs": [], + "source": [ + "fft2d_cus = fft2d(rnd, (4, 5))\n", + "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "04924e7d", + "metadata": {}, + "source": [ + "This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise." + ] + }, + { + "cell_type": "markdown", + "id": "c9da88a0", + "metadata": {}, + "source": [ + "## FFT2D with shape (3,1,4)\n", + "\n", + "Previous implementation expects the input matrix to have two dimensions. It fails with 3." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "66ba70ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 1, 4)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_numpy.shape" ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text1 = inspect.getsource(fft2d)\n", - "text2 = inspect.getsource(fft2d_d3)\n", - "%codediff text1 text2 --verbose 1 --two 1" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "51e7a4f7", - "metadata": { - "scrolled": false - }, - "outputs": [ + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\xavierdupre\\__home_\\GitHub\\mlprodict\\mlprodict\\npy\\onnx_numpy_wrapper.py:27: RuntimeWarning: Class 'onnxnumpy_nb_onnx_rfft_2d_any_None_None' overwritten in\n", - "'onnxnumpy_nb_onnx_rfft_2d_None_None, onnxnumpy_nb_onnx_rfft_2d_any_None_None, onnxnumpy_nb_onnx_rfft_None_None'\n", - "---\n", - "\n", - " warnings.warn( # pragma: no cover\n" - ] - } - ], - "source": [ - "def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", - " if fft_length is None:\n", - " raise RuntimeError(\"fft_length must be specified.\")\n", - " \n", - " size = fft_length // 2 + 1\n", - " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", - " if transpose:\n", - " xt = npnx.transpose(x, (0, 2, 1))\n", - " a = cst[:, :, :fft_length]\n", - " b = xt[:, :fft_length, :]\n", - " a = npnx.expand_dims(a, 0)\n", - " b = npnx.expand_dims(b, 1)\n", - " res = npnx.matmul(a, b)\n", - " res2 = res[:, :size, :]\n", - " return npnx.transpose(res2, (1, 0, 3, 2))\n", - " else:\n", - " a = cst[:, :, :fft_length]\n", - " b = x[:, :fft_length, :]\n", - " a = npnx.expand_dims(a, 0)\n", - " b = npnx.expand_dims(b, 1)\n", - " res = npnx.matmul(a, b)\n", - " return npnx.transpose(res, (1, 0, 2, 3)) \n", - " \n", - "\n", - "def onnx_rfft_3d_2d(x, fft_length=None):\n", - " mat = x[:, :fft_length[-2], :fft_length[-1]]\n", - " \n", - " # first FFT\n", - " res = onnx_rfft_3d_1d(mat, fft_length=fft_length[-1], transpose=True)\n", - " \n", - " # second FFT decomposed on FFT on real part and imaginary part\n", - " res2_real = onnx_rfft_3d_1d(res[0], fft_length=fft_length[0], transpose=False)\n", - " res2_imag = onnx_rfft_3d_1d(res[1], fft_length=fft_length[0], transpose=False) \n", - " res2_imag2 = npnx.vstack(-res2_imag[1:2], 
res2_imag[:1])\n", - " res = res2_real + res2_imag2\n", - " size = fft_length[1]//2 + 1\n", - " return res[:, :, :fft_length[-2], :size]\n", - "\n", - "\n", - "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", - "def onnx_rfft_2d_any(x, fft_length=None):\n", - " new_shape = npnx.concat(\n", - " numpy.array([-1], dtype=numpy.int64), x.shape[-2:], axis=0)\n", - " mat2 = x.reshape(new_shape)\n", - " f2 = onnx_rfft_3d_2d(mat2, fft_length)\n", - " new_shape = npnx.concat(\n", - " numpy.array([2], dtype=numpy.int64), x.shape[:-2], f2.shape[-2:])\n", - " return f2.reshape(new_shape)\n", - "\n", - "\n", - "shape = (3, 1, 4)\n", - "fft_length = (1, 4)\n", - "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - "fft2d_cus = fft2d_any(rnd, fft_length)\n", - "fft2d_onx = onnx_rfft_2d_any(rnd, fft_length=fft_length)\n", - "almost_equal(fft2d_cus, fft2d_onx)" - ] - }, - { - "cell_type": "markdown", - "id": "37c45ae7", - "metadata": {}, - "source": [ - "Let's do the same comparison." - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "11c1e596", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 23, + "id": "a4d123e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 1.62552961+0.j , -2.33151346-0.26713149j,\n", + " 1.52621416+0.j , -2.33151346+0.26713149j]],\n", + "\n", + " [[ 1.56267625+0.j , -2.11182106+0.97715026j,\n", + " -1.59615904+0.j , -2.11182106-0.97715026j]],\n", + "\n", + " [[-2.11940277+0.j , 2.92459655+2.19828379j,\n", + " -1.98709261+0.j , 2.92459655-2.19828379j]]])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fft2d_numpy" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", - "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", - "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=2.9777344341736463e+35 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", - "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", - "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", - "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", - "OK x.shape=(5, 7) length=(5, 2) output shape=(3, 4) or (2, 5, 2)\n", - "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", - "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 4) or (2, 3, 5, 4)\n", - "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 4) or (2, 3, 1, 4)\n", - "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 4) or (2, 3, 2, 4)\n", - "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 4) or (2, 3, 5, 2)\n", - "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3, 3)\n", - "OK x.shape=(7, 5) length=(7, 5) output shape=(3, 4) or (2, 7, 3)\n", - "OK x.shape=(7, 5) length=(1, 5) output shape=(3, 4) or (2, 1, 3)\n", - "OK x.shape=(7, 5) length=(2, 5) output shape=(3, 4) or (2, 2, 3)\n", - "OK x.shape=(7, 5) length=(7, 2) output shape=(3, 4) or (2, 7, 2)\n", - "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" - ] - } - ], - "source": [ - "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", - " for fft_length in [shape[-2:], (1, shape[-1]),\n", - " (min(2, shape[-2]), 
shape[-1]),\n", - " (shape[-2], 2),\n", - " (min(3, shape[-2]), min(4, shape[-2]))]:\n", - " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", - " if len(fnp.shape) == 2:\n", - " fn= numpy.expand_dims(fnp, 0)\n", - " try:\n", - " cus = fft2d_any(x, fft_length)\n", - " except IndexError as e:\n", - " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", - " continue\n", - " try:\n", - " onx = onnx_rfft_2d_any(x, fft_length=fft_length)\n", - " except IndexError as e:\n", - " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", - " continue\n", - " try:\n", - " almost_equal(onx, cus)\n", - " except (AssertionError, IndexError) as e:\n", - " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, e, fnp.shape, cus.shape))\n", - " continue\n", - " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", - " x.shape, fft_length, fnp.shape, cus.shape))" - ] - }, - { - "cell_type": "markdown", - "id": "d197467f", - "metadata": {}, - "source": [ - "There is one issue with ``fft_length=(1, 1)`` but that case is out of scope." - ] - }, - { - "cell_type": "markdown", - "id": "33b5897e", - "metadata": {}, - "source": [ - "### ONNX graph" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "d45e9a99", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 24, + "id": "4b1bd05b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "axes don't match array\n" + ] + } + ], + "source": [ + "try:\n", + " fft2d_cus = fft2d(rnd, fft_length)\n", + "except Exception as e:\n", + " print(e)\n", + "# fft2d_onx = onnx_rfft_2d(rnd, fft_length=fft_length)" + ] + }, + { + "cell_type": "markdown", + "id": "7bd79a00", + "metadata": {}, + "source": [ + "### numpy version\n", + "\n", + "Let's do it again with numpy first. [fft2](https://numpy.org/doc/stable/reference/generated/numpy.fft.fft2.html) performs `fft2` on the last two axis as many times as the first axis. The goal is still to have an implementation which works for any dimension." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3b618335", + "metadata": {}, + "outputs": [], + "source": [ + "conc = []\n", + "for i in range(rnd.shape[0]):\n", + " f2 = fft2d(rnd[i], fft_length)\n", + " conc.append(numpy.expand_dims(f2, 0))\n", + "res = numpy.vstack(conc).transpose(1, 0, 2, 3)\n", + "almost_equal(fft2d_numpy[:, :, :3], res)" + ] + }, + { + "cell_type": "markdown", + "id": "7c837e7a", + "metadata": {}, + "source": [ + "It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``." + ] + }, { - "data": { - "text/html": [ - "
\n", - "" + "cell_type": "code", + "execution_count": 26, + "id": "29055cb2", + "metadata": {}, + "outputs": [], + "source": [ + "def dft_real_d3(x, fft_length=None, transpose=True):\n", + " if len(x.shape) != 3:\n", + " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", + " N = x.shape[1]\n", + " C = x.shape[-1] if transpose else x.shape[-2]\n", + " if fft_length is None:\n", + " fft_length = x.shape[-1]\n", + " size = fft_length // 2 + 1\n", + "\n", + " cst = dft_real_cst(C, fft_length)\n", + " if transpose:\n", + " x = numpy.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " res = res[:, :, :size, :]\n", + " return numpy.transpose(res, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = numpy.expand_dims(a, 0)\n", + " b = numpy.expand_dims(b, 1)\n", + " res = numpy.matmul(a, b)\n", + " return numpy.transpose(res, (1, 0, 2, 3))\n", + "\n", + "\n", + "def fft2d_d3(mat, fft_length):\n", + " mat = mat[:, :fft_length[-2], :fft_length[-1]]\n", + " res = mat.copy()\n", + " \n", + " # first FFT\n", + " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", + " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[-1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "def fft2d_any(mat, fft_length):\n", + " new_shape = (-1, ) + mat.shape[-2:]\n", + " mat2 = mat.reshape(new_shape)\n", + " f2 = fft2d_d3(mat2, fft_length)\n", + " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "almost_equal(fft2d_numpy[..., :3], fft2d_cus)" + ] + }, + { + "cell_type": "markdown", + "id": "0128b3f2", + "metadata": {}, + "source": [ + "We check with more shapes to see if the implementation works for all of them." 
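The numpy facts the implementation above relies on can be checked directly. The following snippet is not part of the original notebook; it is a minimal sanity check of the batching behaviour of ``numpy.fft.fft2``, of the broadcasting rule of ``numpy.matmul`` described above, and of the linearity identity behind the ``vstack([-res2_imag[1:2], res2_imag[:1]])`` recombination used by ``fft2d_d3``.

```python
import numpy

# 1. fft2 on a 3D array applies the 2D FFT over the last two axes,
#    once per index of the first axis.
rnd3 = numpy.random.randn(3, 5, 7)
batched = numpy.fft.fft2(rnd3)
looped = numpy.stack([numpy.fft.fft2(rnd3[i]) for i in range(rnd3.shape[0])])
assert numpy.allclose(batched, looped)

# 2. matmul broadcasts a stack of matrices:
#    matmul(A[1,J,K], B[I,K,L]) -> C[I,J,L].
A = numpy.random.randn(1, 4, 5)
B = numpy.random.randn(3, 5, 2)
C = numpy.matmul(A, B)  # shape (3, 4, 2)
assert numpy.allclose(C, numpy.stack([A[0] @ B[i] for i in range(B.shape[0])]))

# 3. FFT(a + 1j*b) = FFT(a) + 1j*FFT(b): multiplying the pair (real, imag) by 1j
#    gives (-imag, real), which is what the vstack recombination implements.
a, b = numpy.random.randn(4), numpy.random.randn(4)
fa, fb = numpy.fft.fft(a), numpy.fft.fft(b)
direct = numpy.fft.fft(a + 1j * b)
assert numpy.allclose(direct.real, fa.real - fb.imag)
assert numpy.allclose(direct.imag, fa.imag + fb.real)
```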
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "82f5fc78", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 1, 2) or (2, 3, 1, 2)\n", + "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 1, 1) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(5, 7) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(1, 7) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(2, 7) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(5, 2) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 5, 7) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 1, 7) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 2, 7) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 5, 2) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(7, 5) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(1, 5) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(2, 5) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(7, 2) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, shape[-2]))]:\n", + " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " fnp = numpy.fft.fft2(x, fft_length)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(fnp[..., :cus.shape[-1]], cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", - "%onnxview onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "2ab7a3d0", - "metadata": {}, - "outputs": [], - "source": [ - "with open(\"fft2d_any.onnx\", \"wb\") as f:\n", - " key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", - " f.write(onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_.SerializeToString())" - ] - }, - { - "cell_type": "markdown", - "id": "3c17b577", - "metadata": {}, - "source": [ - "Let's check the intermediate results." 
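The comparison keeps only the first ``cus.shape[-1]`` columns of the numpy result because the custom implementation behaves like a real FFT: it returns only the non-redundant half of the last axis, ``fft_length // 2 + 1`` values. A small illustration of that conjugate symmetry, not taken from the notebook:

```python
import numpy

x = numpy.random.randn(5, 8)
full = numpy.fft.fft(x, axis=-1)   # 8 complex coefficients
half = numpy.fft.rfft(x, axis=-1)  # 8 // 2 + 1 = 5 coefficients
assert numpy.allclose(full[:, :5], half)
# the discarded columns are conjugates of the kept ones
assert numpy.allclose(full[:, 5:], numpy.conj(full[:, 1:4])[:, ::-1])
```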
- ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "9e5507f7", - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "FctVersion((numpy.float32,), ((1, 4),))" + "cell_type": "markdown", + "id": "c5f5229a", + "metadata": {}, + "source": [ + "### ONNX version\n", + "\n", + "Let's look into the differences first." ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", - "key" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "376036f8", - "metadata": { - "scrolled": false - }, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "+ki='Un_Unsqueezecst': (2, 1, 1) (dtype=float32 min=0.0 max=1.0)\n", - "+ki='Un_Unsqueezecst1': (1,) (dtype=int64 min=0 max=0)\n", - "+ki='Un_Unsqueezecst2': (2, 4, 4) (dtype=float32 min=-1.0 max=1.0)\n", - "+ki='Co_Concatcst': (1,) (dtype=int64 min=-1 max=-1)\n", - "+ki='Sl_Slicecst': (1,) (dtype=int64 min=-2 max=-2)\n", - "+ki='Sl_Slicecst2': (2,) (dtype=int64 min=0 max=0)\n", - "+ki='Sl_Slicecst3': (2,) (dtype=int64 min=1 max=4)\n", - "+ki='Sl_Slicecst4': (2,) (dtype=int64 min=1 max=2)\n", - "+ki='Sl_Slicecst6': (1,) (dtype=int64 min=4 max=4)\n", - "+ki='Sl_Slicecst7': (1,) (dtype=int64 min=1 max=1)\n", - "+ki='Sl_Slicecst9': (1,) (dtype=int64 min=3 max=3)\n", - "+ki='Ga_Gathercst1': () (dtype=int64 min=0 max=0)\n", - "+ki='Ga_Gathercst2': () (dtype=int64 min=1 max=1)\n", - "+ki='Sl_Slicecst18': (1,) (dtype=int64 min=2 max=2)\n", - "+ki='Sl_Slicecst24': (2,) (dtype=int64 min=1 max=3)\n", - "+ki='Sl_Slicecst25': (2,) (dtype=int64 min=2 max=3)\n", - "-- OnnxInference: run 38 nodes\n", - "Onnx-Unsqueeze(Un_Unsqueezecst, Un_Unsqueezecst1) -> Un_expanded0 (name='Un_Unsqueeze')\n", - "+kr='Un_expanded0': (1, 2, 1, 1) (dtype=float32 min=0.0 max=1.0)\n", - "Onnx-Unsqueeze(Un_Unsqueezecst2, Un_Unsqueezecst1) -> Un_expanded03 (name='Un_Unsqueeze1')\n", - "+kr='Un_expanded03': (1, 2, 4, 4) (dtype=float32 min=-1.0 max=1.0)\n", - "Onnx-Shape(x) -> Sh_shape0 (name='Sh_Shape')\n", - "+kr='Sh_shape0': (3,) (dtype=int64 min=1 max=4)\n", - "Onnx-Shape(Sh_shape0) -> Sh_shape01 (name='Sh_Shape1')\n", - "+kr='Sh_shape01': (1,) (dtype=int64 min=3 max=3)\n", - "Onnx-Gather(Sh_shape01, Un_Unsqueezecst1) -> Ga_output01 (name='Ga_Gather')\n", - "+kr='Ga_output01': (1,) (dtype=int64 min=3 max=3)\n", - "Onnx-Slice(Sh_shape0, Sl_Slicecst, Ga_output01, Un_Unsqueezecst1) -> Sl_output05 (name='Sl_Slice')\n", - "+kr='Sl_output05': (2,) (dtype=int64 min=1 max=4)\n", - "Onnx-Concat(Co_Concatcst, Sl_output05) -> Co_concat_result0 (name='Co_Concat')\n", - "+kr='Co_concat_result0': (3,) (dtype=int64 min=-1 max=4)\n", - "Onnx-Reshape(x, Co_concat_result0) -> Re_reshaped0 (name='Re_Reshape')\n", - "+kr='Re_reshaped0': (3, 1, 4) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", - "Onnx-Slice(Re_reshaped0, Sl_Slicecst2, Sl_Slicecst3, Sl_Slicecst4) -> Sl_output04 (name='Sl_Slice1')\n", - "+kr='Sl_output04': (3, 1, 4) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", - "Onnx-Transpose(Sl_output04) -> Tr_transposed02 (name='Tr_Transpose')\n", - "+kr='Tr_transposed02': (3, 4, 1) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", - "Onnx-Slice(Tr_transposed02, Un_Unsqueezecst1, Sl_Slicecst6, Sl_Slicecst7) -> Sl_output03 (name='Sl_Slice2')\n", - "+kr='Sl_output03': (3, 4, 1) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", - "Onnx-Unsqueeze(Sl_output03, 
Sl_Slicecst7) -> Un_expanded04 (name='Un_Unsqueeze2')\n", - "+kr='Un_expanded04': (3, 1, 4, 1) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", - "Onnx-MatMul(Un_expanded03, Un_expanded04) -> Ma_Y01 (name='Ma_MatMul')\n", - "+kr='Ma_Y01': (3, 2, 4, 1) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Slice(Ma_Y01, Un_Unsqueezecst1, Sl_Slicecst9, Sl_Slicecst7) -> Sl_output02 (name='Sl_Slice3')\n", - "+kr='Sl_output02': (3, 2, 4, 1) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Transpose(Sl_output02) -> Tr_transposed01 (name='Tr_Transpose1')\n", - "+kr='Tr_transposed01': (2, 3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Gather(Tr_transposed01, Ga_Gathercst1) -> Ga_output0 (name='Ga_Gather1')\n", - "+kr='Ga_output0': (3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Slice(Ga_output0, Un_Unsqueezecst1, Sl_Slicecst7, Sl_Slicecst7) -> Sl_output01 (name='Sl_Slice4')\n", - "+kr='Sl_output01': (3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Unsqueeze(Sl_output01, Sl_Slicecst7) -> Un_expanded02 (name='Un_Unsqueeze3')\n", - "+kr='Un_expanded02': (3, 1, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-MatMul(Un_expanded0, Un_expanded02) -> Ma_Y0 (name='Ma_MatMul1')\n", - "+kr='Ma_Y0': (3, 2, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Transpose(Ma_Y0) -> Tr_transposed0 (name='Tr_Transpose2')\n", - "+kr='Tr_transposed0': (2, 3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Gather(Tr_transposed01, Ga_Gathercst2) -> Ga_output03 (name='Ga_Gather2')\n", - "+kr='Ga_output03': (3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", - "Onnx-Slice(Ga_output03, Un_Unsqueezecst1, Sl_Slicecst7, Sl_Slicecst7) -> Sl_output07 (name='Sl_Slice5')\n", - "+kr='Sl_output07': (3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", - "Onnx-Unsqueeze(Sl_output07, Sl_Slicecst7) -> Un_expanded06 (name='Un_Unsqueeze5')\n", - "+kr='Un_expanded06': (3, 1, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", - "Onnx-MatMul(Un_expanded0, Un_expanded06) -> Ma_Y03 (name='Ma_MatMul2')\n", - "+kr='Ma_Y03': (3, 2, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", - "Onnx-Transpose(Ma_Y03) -> Tr_transposed04 (name='Tr_Transpose3')\n", - "+kr='Tr_transposed04': (2, 3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", - "Onnx-Slice(Tr_transposed04, Sl_Slicecst7, Sl_Slicecst18, Un_Unsqueezecst1) -> Sl_output06 (name='Sl_Slice6')\n", - "+kr='Sl_output06': (1, 3, 1, 4) (dtype=float32 min=0.0 max=0.0)\n", - "Onnx-Neg(Sl_output06) -> Ne_Y0 (name='Ne_Neg')\n", - "+kr='Ne_Y0': (1, 3, 1, 4) (dtype=float32 min=-0.0 max=-0.0)\n", - "Onnx-Slice(Tr_transposed04, Un_Unsqueezecst1, Sl_Slicecst7, Un_Unsqueezecst1) -> Sl_output08 (name='Sl_Slice7')\n", - "+kr='Sl_output08': (1, 3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", - "Onnx-Concat(Ne_Y0, Sl_output08) -> Co_concat_result02 (name='Co_Concat1')\n", - "+kr='Co_concat_result02': (2, 3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", - "Onnx-Add(Tr_transposed0, Co_concat_result02) -> Ad_C0 (name='Ad_Add')\n", - "+kr='Ad_C0': (2, 3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Slice(Ad_C0, Sl_Slicecst2, Sl_Slicecst24, Sl_Slicecst25) -> Sl_output0 
(name='Sl_Slice8')\n", - "+kr='Sl_output0': (2, 3, 1, 3) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", - "Onnx-Slice(Sh_shape0, Un_Unsqueezecst1, Sl_Slicecst, Un_Unsqueezecst1) -> Sl_output010 (name='Sl_Slice9')\n", - "+kr='Sl_output010': (1,) (dtype=int64 min=3 max=3)\n", - "Onnx-Shape(Sl_output0) -> Sh_shape03 (name='Sh_Shape3')\n", - "+kr='Sh_shape03': (4,) (dtype=int64 min=1 max=3)\n", - "Onnx-Shape(Sh_shape03) -> Sh_shape04 (name='Sh_Shape4')\n", - "+kr='Sh_shape04': (1,) (dtype=int64 min=4 max=4)\n", - "Onnx-Gather(Sh_shape04, Un_Unsqueezecst1) -> Ga_output04 (name='Ga_Gather3')\n", - "+kr='Ga_output04': (1,) (dtype=int64 min=4 max=4)\n", - "Onnx-Slice(Sh_shape03, Sl_Slicecst, Ga_output04, Un_Unsqueezecst1) -> Sl_output012 (name='Sl_Slice10')\n", - "+kr='Sl_output012': (2,) (dtype=int64 min=1 max=3)\n", - "Onnx-Concat(Sl_Slicecst18, Sl_output010, Sl_output012) -> Co_concat_result03 (name='Co_Concat2')\n", - "+kr='Co_concat_result03': (4,) (dtype=int64 min=1 max=3)\n", - "Onnx-Reshape(Sl_output0, Co_concat_result03) -> y (name='Re_Reshape1')\n", - "+kr='y': (2, 3, 1, 3) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n" - ] + "cell_type": "code", + "execution_count": 28, + "id": "025c2d88", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext pyquickhelper" + ] }, { - "data": { - "text/plain": [ - "{'y': array([[[[ 1.0642704e+00, 7.8808188e-02, 3.1808631e+00]],\n", - " \n", - " [[-1.7878022e+00, -2.5084746e+00, 5.4854429e-01]],\n", - " \n", - " [[-2.2876425e+00, 8.1763226e-01, 4.4160408e-01]]],\n", - " \n", - " \n", - " [[[ 0.0000000e+00, -2.1299846e+00, 7.7034396e-16]],\n", - " \n", - " [[ 0.0000000e+00, 1.2344277e-01, 5.0231944e-16]],\n", - " \n", - " [[ 0.0000000e+00, 1.0373981e+00, -5.9766380e-18]]]],\n", - " dtype=float32)}" + "cell_type": "code", + "execution_count": 29, + "id": "8a9d153c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%html\n", + "" ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "82664bc5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:00<00:00, 776.23it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
00def dft_real(x, fft_length=None, transpose=True):def dft_real_d3(x, fft_length=None, transpose=True):
11 if len(x.shape) == 1: if len(x.shape) != 3:
22 x = x.reshape((1, -1)) raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)
3 N = 1
4 else:
53 N = x.shape[0] N = x.shape[1]
64 C = x.shape[-1] if transpose else x.shape[-2] C = x.shape[-1] if transpose else x.shape[-2]
75 if fft_length is None: if fft_length is None:
86 fft_length = x.shape[-1] fft_length = x.shape[-1]
97 size = fft_length // 2 + 1 size = fft_length // 2 + 1
108
119 cst = dft_real_cst(C, fft_length) cst = dft_real_cst(C, fft_length)
1210 if transpose: if transpose:
1311 x = numpy.transpose(x, (1, 0)) x = numpy.transpose(x, (0, 2, 1))
1412 a = cst[:, :, :fft_length] a = cst[:, :, :fft_length]
1513 b = x[:fft_length] b = x[:, :fft_length, :]
14 a = numpy.expand_dims(a, 0)
15 b = numpy.expand_dims(b, 1)
1616 res = numpy.matmul(a, b) res = numpy.matmul(a, b)
1717 res = res[:, :size, :] res = res[:, :, :size, :]
1818 return numpy.transpose(res, (0, 2, 1)) return numpy.transpose(res, (1, 0, 3, 2))
1919 else: else:
2020 a = cst[:, :, :fft_length] a = cst[:, :, :fft_length]
2121 b = x[:fft_length] b = x[:, :fft_length, :]
22 a = numpy.expand_dims(a, 0)
23 b = numpy.expand_dims(b, 1)
2224 return numpy.matmul(a, b) res = numpy.matmul(a, b)
25 return numpy.transpose(res, (1, 0, 2, 3))
2326
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import inspect\n", + "text1 = inspect.getsource(dft_real)\n", + "text2 = inspect.getsource(dft_real_d3)\n", + "%codediff text1 text2 --verbose 1 --two 1" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "cd7e14d4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 15/15 [00:00<00:00, 1156.92it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
00def fft2d(mat, fft_length):def fft2d_d3(mat, fft_length):
11 mat = mat[:fft_length[0], :fft_length[1]] mat = mat[:, :fft_length[-2], :fft_length[-1]]
22 res = mat.copy() res = mat.copy()
33
44 # first FFT # first FFT
55 res = dft_real(res, fft_length=fft_length[1], transpose=True) res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)
66
77 # second FFT decomposed on FFT on real part and imaginary part # second FFT decomposed on FFT on real part and imaginary part
88 res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False) res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)
99 res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)
1010 res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]]) res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])
1111 res = res2_real + res2_imag2 res = res2_real + res2_imag2
1212 size = fft_length[1]//2 + 1 size = fft_length[-1]//2 + 1
1313 return res[:, :fft_length[0], :size] return res[:, :, :fft_length[-2], :size]
1414
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text1 = inspect.getsource(fft2d)\n", + "text2 = inspect.getsource(fft2d_d3)\n", + "%codediff text1 text2 --verbose 1 --two 1" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "51e7a4f7", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\xavierdupre\\__home_\\GitHub\\mlprodict\\mlprodict\\npy\\onnx_numpy_wrapper.py:27: RuntimeWarning: Class 'onnxnumpy_nb_onnx_rfft_2d_any_None_None' overwritten in\n", + "'onnxnumpy_nb_onnx_rfft_2d_None_None, onnxnumpy_nb_onnx_rfft_2d_any_None_None, onnxnumpy_nb_onnx_rfft_None_None'\n", + "---\n", + "\n", + " warnings.warn( # pragma: no cover\n" + ] + } + ], + "source": [ + "def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", + " if fft_length is None:\n", + " raise RuntimeError(\"fft_length must be specified.\")\n", + " \n", + " size = fft_length // 2 + 1\n", + " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", + " if transpose:\n", + " xt = npnx.transpose(x, (0, 2, 1))\n", + " a = cst[:, :, :fft_length]\n", + " b = xt[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " res2 = res[:, :size, :]\n", + " return npnx.transpose(res2, (1, 0, 3, 2))\n", + " else:\n", + " a = cst[:, :, :fft_length]\n", + " b = x[:, :fft_length, :]\n", + " a = npnx.expand_dims(a, 0)\n", + " b = npnx.expand_dims(b, 1)\n", + " res = npnx.matmul(a, b)\n", + " return npnx.transpose(res, (1, 0, 2, 3)) \n", + " \n", + "\n", + "def onnx_rfft_3d_2d(x, fft_length=None):\n", + " mat = x[:, :fft_length[-2], :fft_length[-1]]\n", + " \n", + " # first FFT\n", + " res = onnx_rfft_3d_1d(mat, fft_length=fft_length[-1], transpose=True)\n", + " \n", + " # second FFT decomposed on FFT on real part and imaginary part\n", + " res2_real = onnx_rfft_3d_1d(res[0], fft_length=fft_length[0], transpose=False)\n", + " res2_imag = onnx_rfft_3d_1d(res[1], fft_length=fft_length[0], transpose=False) \n", + " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", + " res = res2_real + res2_imag2\n", + " size = fft_length[1]//2 + 1\n", + " return res[:, :, :fft_length[-2], :size]\n", + "\n", + "\n", + "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", + "def onnx_rfft_2d_any(x, fft_length=None):\n", + " new_shape = npnx.concat(\n", + " numpy.array([-1], dtype=numpy.int64), x.shape[-2:], axis=0)\n", + " mat2 = x.reshape(new_shape)\n", + " f2 = onnx_rfft_3d_2d(mat2, fft_length)\n", + " new_shape = npnx.concat(\n", + " numpy.array([2], dtype=numpy.int64), x.shape[:-2], f2.shape[-2:])\n", + " return f2.reshape(new_shape)\n", + "\n", + "\n", + "shape = (3, 1, 4)\n", + "fft_length = (1, 4)\n", + "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + "fft2d_cus = fft2d_any(rnd, fft_length)\n", + "fft2d_onx = onnx_rfft_2d_any(rnd, fft_length=fft_length)\n", + "almost_equal(fft2d_cus, fft2d_onx)" + ] + }, + { + "cell_type": "markdown", + "id": "37c45ae7", + "metadata": {}, + "source": [ + "Let's do the same comparison." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "11c1e596", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", + "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", + "DIS x.shape=(3, 1, 4) length=(1, 1) error=AssertionError('Mismatch max diff=2.9777344341736463e+35 > 1e-05.') output shape=(3, 4) or (2, 3, 1, 1)\n", + "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", + "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", + "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", + "OK x.shape=(5, 7) length=(5, 2) output shape=(3, 4) or (2, 5, 2)\n", + "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", + "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 4) or (2, 3, 5, 4)\n", + "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 4) or (2, 3, 1, 4)\n", + "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 4) or (2, 3, 2, 4)\n", + "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 4) or (2, 3, 5, 2)\n", + "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3, 3)\n", + "OK x.shape=(7, 5) length=(7, 5) output shape=(3, 4) or (2, 7, 3)\n", + "OK x.shape=(7, 5) length=(1, 5) output shape=(3, 4) or (2, 1, 3)\n", + "OK x.shape=(7, 5) length=(2, 5) output shape=(3, 4) or (2, 2, 3)\n", + "OK x.shape=(7, 5) length=(7, 2) output shape=(3, 4) or (2, 7, 2)\n", + "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n" + ] + } + ], + "source": [ + "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", + " for fft_length in [shape[-2:], (1, shape[-1]),\n", + " (min(2, shape[-2]), shape[-1]),\n", + " (shape[-2], 2),\n", + " (min(3, shape[-2]), min(4, shape[-2]))]:\n", + " x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", + " if len(fnp.shape) == 2:\n", + " fn= numpy.expand_dims(fnp, 0)\n", + " try:\n", + " cus = fft2d_any(x, fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " onx = onnx_rfft_2d_any(x, fft_length=fft_length)\n", + " except IndexError as e:\n", + " print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", + " continue\n", + " try:\n", + " almost_equal(onx, cus)\n", + " except (AssertionError, IndexError) as e:\n", + " print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, e, fnp.shape, cus.shape))\n", + " continue\n", + " print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", + " x.shape, fft_length, fnp.shape, cus.shape))" + ] + }, + { + "cell_type": "markdown", + "id": "d197467f", + "metadata": {}, + "source": [ + "There is one issue with ``fft_length=(1, 1)`` but that case is out of scope." + ] + }, + { + "cell_type": "markdown", + "id": "33b5897e", + "metadata": {}, + "source": [ + "### ONNX graph" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "d45e9a99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + "%onnxview onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "2ab7a3d0", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"fft2d_any.onnx\", \"wb\") as f:\n", + " key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + " f.write(onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_.SerializeToString())" + ] + }, + { + "cell_type": "markdown", + "id": "3c17b577", + "metadata": {}, + "source": [ + "Let's check the intermediate results." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "9e5507f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FctVersion((numpy.float32,), ((1, 4),))" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", + "key" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "376036f8", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+ki='Un_Unsqueezecst': (2, 1, 1) (dtype=float32 min=0.0 max=1.0)\n", + "+ki='Un_Unsqueezecst1': (1,) (dtype=int64 min=0 max=0)\n", + "+ki='Un_Unsqueezecst2': (2, 4, 4) (dtype=float32 min=-1.0 max=1.0)\n", + "+ki='Co_Concatcst': (1,) (dtype=int64 min=-1 max=-1)\n", + "+ki='Sl_Slicecst': (1,) (dtype=int64 min=-2 max=-2)\n", + "+ki='Sl_Slicecst2': (2,) (dtype=int64 min=0 max=0)\n", + "+ki='Sl_Slicecst3': (2,) (dtype=int64 min=1 max=4)\n", + "+ki='Sl_Slicecst4': (2,) (dtype=int64 min=1 max=2)\n", + "+ki='Sl_Slicecst6': (1,) (dtype=int64 min=4 max=4)\n", + "+ki='Sl_Slicecst7': (1,) (dtype=int64 min=1 max=1)\n", + "+ki='Sl_Slicecst9': (1,) (dtype=int64 min=3 max=3)\n", + "+ki='Ga_Gathercst1': () (dtype=int64 min=0 max=0)\n", + "+ki='Ga_Gathercst2': () (dtype=int64 min=1 max=1)\n", + "+ki='Sl_Slicecst18': (1,) (dtype=int64 min=2 max=2)\n", + "+ki='Sl_Slicecst24': (2,) (dtype=int64 min=1 max=3)\n", + "+ki='Sl_Slicecst25': (2,) (dtype=int64 min=2 max=3)\n", + "-- OnnxInference: run 38 nodes\n", + "Onnx-Unsqueeze(Un_Unsqueezecst, Un_Unsqueezecst1) -> Un_expanded0 (name='Un_Unsqueeze')\n", + "+kr='Un_expanded0': (1, 2, 1, 1) (dtype=float32 min=0.0 max=1.0)\n", + "Onnx-Unsqueeze(Un_Unsqueezecst2, Un_Unsqueezecst1) -> Un_expanded03 (name='Un_Unsqueeze1')\n", + "+kr='Un_expanded03': (1, 2, 4, 4) (dtype=float32 min=-1.0 max=1.0)\n", + "Onnx-Shape(x) -> Sh_shape0 (name='Sh_Shape')\n", + "+kr='Sh_shape0': (3,) (dtype=int64 min=1 max=4)\n", + "Onnx-Shape(Sh_shape0) -> Sh_shape01 (name='Sh_Shape1')\n", + "+kr='Sh_shape01': (1,) (dtype=int64 min=3 max=3)\n", + "Onnx-Gather(Sh_shape01, Un_Unsqueezecst1) -> Ga_output01 (name='Ga_Gather')\n", + "+kr='Ga_output01': (1,) (dtype=int64 min=3 max=3)\n", + "Onnx-Slice(Sh_shape0, Sl_Slicecst, Ga_output01, Un_Unsqueezecst1) -> Sl_output05 (name='Sl_Slice')\n", + "+kr='Sl_output05': (2,) (dtype=int64 min=1 max=4)\n", + "Onnx-Concat(Co_Concatcst, Sl_output05) -> Co_concat_result0 (name='Co_Concat')\n", + "+kr='Co_concat_result0': (3,) (dtype=int64 min=-1 max=4)\n", + "Onnx-Reshape(x, Co_concat_result0) -> Re_reshaped0 (name='Re_Reshape')\n", + "+kr='Re_reshaped0': (3, 1, 4) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", + "Onnx-Slice(Re_reshaped0, Sl_Slicecst2, Sl_Slicecst3, 
Sl_Slicecst4) -> Sl_output04 (name='Sl_Slice1')\n", + "+kr='Sl_output04': (3, 1, 4) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", + "Onnx-Transpose(Sl_output04) -> Tr_transposed02 (name='Tr_Transpose')\n", + "+kr='Tr_transposed02': (3, 4, 1) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", + "Onnx-Slice(Tr_transposed02, Un_Unsqueezecst1, Sl_Slicecst6, Sl_Slicecst7) -> Sl_output03 (name='Sl_Slice2')\n", + "+kr='Sl_output03': (3, 4, 1) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", + "Onnx-Unsqueeze(Sl_output03, Sl_Slicecst7) -> Un_expanded04 (name='Un_Unsqueeze2')\n", + "+kr='Un_expanded04': (3, 1, 4, 1) (dtype=float32 min=-1.5941405296325684 max=1.1006875038146973)\n", + "Onnx-MatMul(Un_expanded03, Un_expanded04) -> Ma_Y01 (name='Ma_MatMul')\n", + "+kr='Ma_Y01': (3, 2, 4, 1) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Slice(Ma_Y01, Un_Unsqueezecst1, Sl_Slicecst9, Sl_Slicecst7) -> Sl_output02 (name='Sl_Slice3')\n", + "+kr='Sl_output02': (3, 2, 4, 1) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Transpose(Sl_output02) -> Tr_transposed01 (name='Tr_Transpose1')\n", + "+kr='Tr_transposed01': (2, 3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Gather(Tr_transposed01, Ga_Gathercst1) -> Ga_output0 (name='Ga_Gather1')\n", + "+kr='Ga_output0': (3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Slice(Ga_output0, Un_Unsqueezecst1, Sl_Slicecst7, Sl_Slicecst7) -> Sl_output01 (name='Sl_Slice4')\n", + "+kr='Sl_output01': (3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Unsqueeze(Sl_output01, Sl_Slicecst7) -> Un_expanded02 (name='Un_Unsqueeze3')\n", + "+kr='Un_expanded02': (3, 1, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-MatMul(Un_expanded0, Un_expanded02) -> Ma_Y0 (name='Ma_MatMul1')\n", + "+kr='Ma_Y0': (3, 2, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Transpose(Ma_Y0) -> Tr_transposed0 (name='Tr_Transpose2')\n", + "+kr='Tr_transposed0': (2, 3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Gather(Tr_transposed01, Ga_Gathercst2) -> Ga_output03 (name='Ga_Gather2')\n", + "+kr='Ga_output03': (3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", + "Onnx-Slice(Ga_output03, Un_Unsqueezecst1, Sl_Slicecst7, Sl_Slicecst7) -> Sl_output07 (name='Sl_Slice5')\n", + "+kr='Sl_output07': (3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", + "Onnx-Unsqueeze(Sl_output07, Sl_Slicecst7) -> Un_expanded06 (name='Un_Unsqueeze5')\n", + "+kr='Un_expanded06': (3, 1, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", + "Onnx-MatMul(Un_expanded0, Un_expanded06) -> Ma_Y03 (name='Ma_MatMul2')\n", + "+kr='Ma_Y03': (3, 2, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", + "Onnx-Transpose(Ma_Y03) -> Tr_transposed04 (name='Tr_Transpose3')\n", + "+kr='Tr_transposed04': (2, 3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", + "Onnx-Slice(Tr_transposed04, Sl_Slicecst7, Sl_Slicecst18, Un_Unsqueezecst1) -> Sl_output06 (name='Sl_Slice6')\n", + "+kr='Sl_output06': (1, 3, 1, 4) (dtype=float32 min=0.0 max=0.0)\n", + "Onnx-Neg(Sl_output06) -> Ne_Y0 (name='Ne_Neg')\n", + "+kr='Ne_Y0': (1, 3, 1, 4) (dtype=float32 min=-0.0 max=-0.0)\n", + "Onnx-Slice(Tr_transposed04, Un_Unsqueezecst1, Sl_Slicecst7, Un_Unsqueezecst1) -> 
Sl_output08 (name='Sl_Slice7')\n", + "+kr='Sl_output08': (1, 3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", + "Onnx-Concat(Ne_Y0, Sl_output08) -> Co_concat_result02 (name='Co_Concat1')\n", + "+kr='Co_concat_result02': (2, 3, 1, 4) (dtype=float32 min=-2.1299846172332764 max=2.1299846172332764)\n", + "Onnx-Add(Tr_transposed0, Co_concat_result02) -> Ad_C0 (name='Ad_Add')\n", + "+kr='Ad_C0': (2, 3, 1, 4) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Slice(Ad_C0, Sl_Slicecst2, Sl_Slicecst24, Sl_Slicecst25) -> Sl_output0 (name='Sl_Slice8')\n", + "+kr='Sl_output0': (2, 3, 1, 3) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n", + "Onnx-Slice(Sh_shape0, Un_Unsqueezecst1, Sl_Slicecst, Un_Unsqueezecst1) -> Sl_output010 (name='Sl_Slice9')\n", + "+kr='Sl_output010': (1,) (dtype=int64 min=3 max=3)\n", + "Onnx-Shape(Sl_output0) -> Sh_shape03 (name='Sh_Shape3')\n", + "+kr='Sh_shape03': (4,) (dtype=int64 min=1 max=3)\n", + "Onnx-Shape(Sh_shape03) -> Sh_shape04 (name='Sh_Shape4')\n", + "+kr='Sh_shape04': (1,) (dtype=int64 min=4 max=4)\n", + "Onnx-Gather(Sh_shape04, Un_Unsqueezecst1) -> Ga_output04 (name='Ga_Gather3')\n", + "+kr='Ga_output04': (1,) (dtype=int64 min=4 max=4)\n", + "Onnx-Slice(Sh_shape03, Sl_Slicecst, Ga_output04, Un_Unsqueezecst1) -> Sl_output012 (name='Sl_Slice10')\n", + "+kr='Sl_output012': (2,) (dtype=int64 min=1 max=3)\n", + "Onnx-Concat(Sl_Slicecst18, Sl_output010, Sl_output012) -> Co_concat_result03 (name='Co_Concat2')\n", + "+kr='Co_concat_result03': (4,) (dtype=int64 min=1 max=3)\n", + "Onnx-Reshape(Sl_output0, Co_concat_result03) -> y (name='Re_Reshape1')\n", + "+kr='y': (2, 3, 1, 3) (dtype=float32 min=-2.508474588394165 max=3.18086314201355)\n" + ] + }, + { + "data": { + "text/plain": [ + "{'y': array([[[[ 1.0642704e+00, 7.8808188e-02, 3.1808631e+00]],\n", + " \n", + " [[-1.7878022e+00, -2.5084746e+00, 5.4854429e-01]],\n", + " \n", + " [[-2.2876425e+00, 8.1763226e-01, 4.4160408e-01]]],\n", + " \n", + " \n", + " [[[ 0.0000000e+00, -2.1299846e+00, 7.7034396e-16]],\n", + " \n", + " [[ 0.0000000e+00, 1.2344277e-01, 5.0231944e-16]],\n", + " \n", + " [[ 0.0000000e+00, 1.0373981e+00, -5.9766380e-18]]]],\n", + " dtype=float32)}" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from mlprodict.onnxrt import OnnxInference\n", + "\n", + "x = numpy.random.randn(3, 1, 4).astype(numpy.float32)\n", + "onx = onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_\n", + "oinf = OnnxInference(onx)\n", + "oinf.run({'x': x}, verbose=1, fLOG=print)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "3843308e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" } - ], - "source": [ - "from mlprodict.onnxrt import OnnxInference\n", - "\n", - "x = numpy.random.randn(3, 1, 4).astype(numpy.float32)\n", - "onx = onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_\n", - "oinf = OnnxInference(onx)\n", - "oinf.run({'x': x}, verbose=1, fLOG=print)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3843308e", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { 
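The graph saved earlier as ``fft2d_any.onnx`` can also be checked with onnxruntime. This is a hypothetical follow-up, not a cell from the notebook; it assumes onnxruntime is installed and relies on the input and output names (``x`` and ``y``) visible in the verbose trace above.

```python
import numpy
import onnxruntime

sess = onnxruntime.InferenceSession(
    "fft2d_any.onnx", providers=["CPUExecutionProvider"])
x = numpy.random.randn(3, 1, 4).astype(numpy.float32)
# the first axis of the result stacks the real and imaginary parts
y = sess.run(None, {"x": x})[0]
print(y.shape)  # (2, 3, 1, 3)
```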
- "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/_doc/sphinxdoc/source/api/sklearn.rst b/_doc/sphinxdoc/source/api/sklearn.rst deleted file mode 100644 index 8b1378917..000000000 --- a/_doc/sphinxdoc/source/api/sklearn.rst +++ /dev/null @@ -1 +0,0 @@ - diff --git a/_unittests/ut_onnx_conv/test_lightgbm_tree_structure.py b/_unittests/ut_onnx_conv/test_lightgbm_tree_structure.py index 5fa683a0b..fba63f33b 100644 --- a/_unittests/ut_onnx_conv/test_lightgbm_tree_structure.py +++ b/_unittests/ut_onnx_conv/test_lightgbm_tree_structure.py @@ -4,12 +4,23 @@ import unittest from logging import getLogger import copy +import json +import base64 +import lzma import numpy from pandas import DataFrame from pyquickhelper.pycode import ExtTestCase + +try: + from pyquickhelper.pycode.unittest_cst import decompress_cst +except ImportError: + decompress_cst = lambda d: json.loads( + lzma.decompress(base64.b64decode(b"".join(d)))) + from skl2onnx.common.data_types import FloatTensorType from sklearn.datasets import load_iris -from mlprodict.onnx_conv.operator_converters.conv_lightgbm import modify_tree_for_rule_in_set +from mlprodict.onnx_conv.helpers.lgbm_helper import ( + modify_tree_for_rule_in_set, restore_lgbm_info) from mlprodict.onnx_conv.parsers.parse_lightgbm import MockWrappedLightGbmBoosterClassifier from mlprodict.onnx_conv import register_converters, to_onnx from mlprodict.onnxrt import OnnxInference @@ -30,6 +41,30 @@ def count_nodes(tree, done=None): return nb +def clean_tree(tree): + def walk_through(tree): + if 'tree_structure' in tree: + for w in walk_through(tree['tree_structure']): + yield w + yield tree + if 'left_child' in tree: + for w in walk_through(tree['left_child']): + yield w + if 'right_child' in tree: + for w in walk_through(tree['right_child']): + yield w + + nodes = list(walk_through(tree3)) + for node in nodes: + for k in ['split_gain', 'split_feature', 'split_index', 'leaf_count', + 'internal_value', 'internal_weight', 'internal_count', 'leaf_weight']: + if k in node: + del node[k] + for k in ['leaf_value', 'leaf_value']: + if k in node: + node[k] = 0 + + tree2 = {'average_output': False, 'feature_names': ['c1', 'c2', 'c3', 'c4'], 'label_index': 0, @@ -140,6 +175,29 @@ def count_nodes(tree, done=None): 'version': 'v2'} +# This constant by built by appling function pyquickhelper.pycode.unittest_cst.compress_cst. 
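# A rough sketch of the compression side (hypothetical helper, the real function is
# pyquickhelper.pycode.unittest_cst.compress_cst): the constant is dumped to JSON,
# compressed with lzma, base64-encoded and split into short byte strings, the exact
# inverse of the decompress_cst fallback defined at the top of this file. It only
# uses the json, base64 and lzma modules already imported above.
def _compress_cst_sketch(obj, chunk=70):
    data = base64.b64encode(lzma.compress(json.dumps(obj).encode("utf-8")))
    return [data[i:i + chunk] for i in range(0, len(data), chunk)]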
+ +tree3 = decompress_cst([ + b'/Td6WFoAAATm1rRGAgAhARYAAAB0L+Wj4Ck9A2tdAD2IiodjqVNsvcJJI6C9h2Y0CbG5b7', + b'OaqsqxvLBzg7BltxogYoUzxj35qbUETbBAyJeMccezEDeIKOT1GB+I50txUuc8zkWDcp/n', + b'kx2YhORZxAyj55pXJF/xW5aySLknuTn/5cRfSL9AGF7dHdW9k8RqP5GONWx3YvvnP0tCW0', + b'lGKd5caxoNFaB5tg+je6f0s6N6QQo8wqrBPtjJ7bQf50vFrpYgkQNAEZIVutpzgE9c4o1L', + b'Uv/vJgnhQXOpk/4hOCV2q8VG+jD9oIjPINOOZ642k2QmsdWC+l3XagJnbN9dqT/4C9ehfM', + b'nf6Bw5XcRXD4rtmOyUq/ocuh1WfPinlKd/Jn0YOydq1FpH+VNSUjjPUGJbJal4Pa6jcx/Y', + b'9mcBjp9kP1pM5wkCJ52Kv12UQ/+2j+n0rUQbbqs10iFJo4h4KB/Ie/bugBLNItmNhNhDP4', + b'36Q6jCsLsXlu0gTiWZfGQapR+DJIsVKHh9GeagotXpHTwYX72KrTFwIdxgf9Y2X1EUqiJV', + b'wXdP7GprCs9QsIvCkqW59hPNStt2tyWtlSsXsnjU5e0Jn3USVHOcbwCBSpCtFlpg8tiS9m', + b'Zv1TIGj9cvEk1Ke9p6bZelvtXqHJRISJ8fCVjrqTnEjyUdPaG1wmqCyz7NFEkngrBinY7e', + b'ZMHmO1y6IhLI1zN0kq8zBHIQeqUruYgBatPI6jI585wQ6mYCobgQc7B6Ae6XlgOthATrr2', + b'oDdnIeAPeUKVMXPIq9NnwlwsyNEoTddI42NiMde8jVzVm4wwwnqrmbKlJsi5LJhRQlaEFX', + b'etzNn7llkCSwv88gYhcaDWP3Ewchse2iQDkJ0dPZhx0FB18X6wvEcwkt/H+dzTgAYOCSkr', + b'T3thNkPCvQ4keiRzHiWNzLc+NAhz5NX8BXsVQFkEyf4oUkKHjy053LBmXpHM75LBhdJmFH', + b'vqRENHF6QgiPLAjc/1NHatYLcY0VRetr55Bp2jWU+z75P2TrMkTHFnjbOEQ3p13USzVmnq', + b'3d0EUvp5Q5dUPDFAIhkH+oUkgK4lX2xlyEGh+23EqQtmkjOyKj7HPHoPZo2AjASlRTc78u', + b'1c9nWkTbwBGbZUsMmWzyjbDe/h2Yi2GvkSkIh8UKtYDlTzpT62G9Chf5N9HEfFjQWcdCEi', + b'7Y3Hx86ee03jpP42ssAADRqUIMvx3yYwABhwe+UgAA2u9V4LHEZ/sCAAAAAARZWg==']) + + class TestLightGbmTreeStructure(ExtTestCase): def setUp(self): @@ -223,7 +281,6 @@ def test_onnxrt_python_lightgbm_categorical2(self): self.assertEqual(nb2, 18) def test_mock_lightgbm(self): - tree = copy.deepcopy(tree2) nb1 = sum(count_nodes(t['tree_structure']) for t in tree['tree_info']) model = MockWrappedLightGbmBoosterClassifier(tree) @@ -258,6 +315,30 @@ def test_mock_lightgbm(self): prob = DataFrame(pred["output_probability"]).values self.assertEqual(prob.shape, (row, 2)) + def test_mock_lightgbm_info(self): + tree = copy.deepcopy(tree3) + info = restore_lgbm_info(tree) + modify_tree_for_rule_in_set(tree, info=info) + expected = tree + tree = copy.deepcopy(tree3) + info = restore_lgbm_info(tree) + modify_tree_for_rule_in_set(tree, info=info) + self.assertEqual(expected, tree) + + def test_mock_lightgbm_profile(self): + tree = copy.deepcopy(tree3) + info = restore_lgbm_info(tree) + self.assertIsInstance(info, list) + self.assertGreater(len(info), 1) + + def g(): + for _ in range(0, 100): + modify_tree_for_rule_in_set(tree, info=info) + p2 = self.profile(g)[1] + self.assertIn('cumtime', p2) + if __name__ == "__main__": + print(p2) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py index c0db8a11f..9ee1a07e6 100644 --- a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm.py @@ -1,5 +1,5 @@ """ -@brief test log(time=3s) +@brief test log(time=6s) """ import sys import unittest @@ -200,6 +200,116 @@ def test_onnxrt_python_lightgbm_categorical_iris(self): values = pandas.DataFrame(got['output_probability']).values self.assertEqualArray(exp, values[:, 1], decimal=5) + @skipif_circleci('stuck') + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + @ignore_warnings((RuntimeWarning, UserWarning)) + def test_onnxrt_python_lightgbm_categorical_iris_booster3(self): + from lightgbm import LGBMClassifier, Dataset, train as lgb_train + + iris = load_iris() + X, y = iris.data, iris.target + X = (X * 10).astype(numpy.int32) + 
X_train, X_test, y_train, _ = train_test_split( + X, y, random_state=11) + other_x = numpy.random.randint( + 0, high=10, size=(1500, X_train.shape[1])) + X_train = numpy.vstack([X_train, other_x]).astype(dtype=numpy.int32) + y_train = numpy.hstack( + [y_train, numpy.zeros(500) + 3, numpy.zeros(500) + 4, + numpy.zeros(500) + 5]).astype(dtype=numpy.int32) + self.assertEqual(y_train.shape, (X_train.shape[0], )) + + # Classic + gbm = LGBMClassifier() + gbm.fit(X_train, y_train) + exp = gbm.predict_proba(X_test) + onx = to_onnx(gbm, initial_types=[ + ('X', Int64TensorType([None, X_train.shape[1]]))]) + self.assertIn('ZipMap', str(onx)) + oif = OnnxInference(onx) + got = oif.run({'X': X_test}) + values = pandas.DataFrame(got['output_probability']).values + self.assertEqualArray(exp, values, decimal=5) + + # categorical_feature=[0, 1] + train_data = Dataset( + X_train, label=y_train, + feature_name=['c1', 'c2', 'c3', 'c4'], + categorical_feature=['c1', 'c2']) + + params = { + "boosting_type": "gbdt", + "learning_rate": 0.05, + "n_estimators": 2, + "objective": "binary", + "max_bin": 5, + "min_child_samples": 100, + 'verbose': -1, + } + + booster = lgb_train(params, train_data) + exp = booster.predict(X_test) + + onx = to_onnx(booster, initial_types=[ + ('X', Int64TensorType([None, X_train.shape[1]]))]) + self.assertIn('ZipMap', str(onx)) + oif = OnnxInference(onx) + got = oif.run({'X': X_test}) + values = pandas.DataFrame(got['output_probability']).values + self.assertEqualArray(exp, values[:, 1], decimal=5) + + @skipif_circleci('stuck') + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + @ignore_warnings((RuntimeWarning, UserWarning)) + def test_onnxrt_python_lightgbm_categorical_iris_booster3_real(self): + from lightgbm import LGBMClassifier, Dataset, train as lgb_train + + iris = load_iris() + X, y = iris.data, iris.target + X = (X * 10).astype(numpy.float32) + X_train, X_test, y_train, _ = train_test_split( + X, y, random_state=11) + + # Classic + gbm = LGBMClassifier() + gbm.fit(X_train, y_train) + exp = gbm.predict_proba(X_test) + onx = to_onnx(gbm.booster_, initial_types=[ + ('X', FloatTensorType([None, X_train.shape[1]]))]) + self.assertIn('ZipMap', str(onx)) + oif = OnnxInference(onx) + got = oif.run({'X': X_test}) + values = pandas.DataFrame(got['output_probability']).values + self.assertEqualArray(exp, values, decimal=5) + + # categorical_feature=[0, 1] + train_data = Dataset( + X_train, label=y_train, + feature_name=['c1', 'c2', 'c3', 'c4'], + categorical_feature=['c1', 'c2']) + + params = { + "boosting_type": "gbdt", + "learning_rate": 0.05, + "n_estimators": 2, + "objective": "multiclass", + "max_bin": 5, + "min_child_samples": 100, + 'verbose': -1, + 'num_class': 3, + } + + booster = lgb_train(params, train_data) + exp = booster.predict(X_test) + + onx = to_onnx(booster, initial_types=[ + ('X', FloatTensorType([None, X_train.shape[1]]))]) + self.assertIn('ZipMap', str(onx)) + oif = OnnxInference(onx) + got = oif.run({'X': X_test}) + values = pandas.DataFrame(got['output_probability']).values + self.assertEqualArray(exp, values, decimal=5) + @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') @ignore_warnings((RuntimeWarning, UserWarning)) @@ -286,9 +396,10 @@ def test_lightgbm_booster_classifier(self): 'subsample_freq': 1, 'bagging_fraction': 0.5, 'feature_fraction': 0.5}, data) - model_onnx = to_onnx(model, X) + model_onnx = to_onnx(model, X, verbose=2, rewrite_ops=True) self.assertNotEmpty(model_onnx) if __name__ == "__main__": + # 
TestOnnxrtRuntimeLightGbm().test_lightgbm_booster_classifier() unittest.main() diff --git a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py index 04110dcaa..5417fc202 100644 --- a/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py +++ b/_unittests/ut_onnx_conv/test_onnxrt_runtime_lightgbm_bug.py @@ -1,5 +1,5 @@ """ -@brief test log(time=3s) +@brief test log(time=8s) """ import sys import unittest @@ -17,29 +17,35 @@ def setUp(self): logger = getLogger('skl2onnx') logger.disabled = True register_converters() + X = numpy.abs(numpy.random.randn(10, 200)).astype(numpy.float32) + for i in range(X.shape[1]): + X[:, i] *= (i + 1) * 10 + y = X.sum(axis=1) / 1e3 + numpy.random.randn( + X.shape[0]).astype(numpy.float32) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + self.data_X, self.data_y = X, y @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') - def test_lightgbm_regressor(self): - from lightgbm import LGBMRegressor + def test_xgboost_regressor(self): + from xgboost import XGBRegressor try: - from onnxmltools.convert import convert_lightgbm + from onnxmltools.convert import convert_xgboost except ImportError: - convert_lightgbm = None + convert_xgboost = None - X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) - y = X.sum(axis=1) + numpy.random.randn( - X.shape[0]).astype(numpy.float32) / 10 - model = LGBMRegressor( - max_depth=8, n_estimators=100, min_child_samples=1, - learning_rate=0.0000001) + X, y = self.data_X, self.data_y + model = XGBRegressor( + max_depth=8, n_estimators=100, + learning_rate=0.000001) model.fit(X, y) expected = model.predict(X) model_onnx = to_onnx(model, X) - if convert_lightgbm is not None: - model_onnx2 = convert_lightgbm( - model, initial_types=[('X', FloatTensorType([None, 227]))]) + if convert_xgboost is not None: + model_onnx2 = convert_xgboost( + model, initial_types=[('X', FloatTensorType([None, X.shape[1]]))]) else: model_onnx2 = None @@ -52,82 +58,113 @@ def test_lightgbm_regressor(self): got = oinf.run({'X': X})['variable'] diff = numpy.abs(got.ravel() - expected.ravel()).max() if __name__ == "__main__": - print("lgb", i, rt, diff) - self.assertLess(diff, 1e-3) + print("xgb32", "mlprod" if i == + 0 else "mltool", rt, diff) + self.assertLess(diff, 1e-5) @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') - def test_lightgbm_regressor_double(self): + def test_missing_values(self): from lightgbm import LGBMRegressor - - X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) - y = X.sum(axis=1) + numpy.random.randn( - X.shape[0]).astype(numpy.float32) / 10 - model = LGBMRegressor( - max_depth=8, n_estimators=100, min_child_samples=1, - learning_rate=0.0000001) - model.fit(X, y) - expected = model.predict(X) - - model_onnx = to_onnx(model, X, rewrite_ops=True) - model_onnx2 = to_onnx(model, X.astype(numpy.float64), - rewrite_ops=True) - - for i, mo in enumerate([model_onnx, model_onnx2]): - for rt in ['python', 'onnxruntime1']: - if "TreeEnsembleRegressorDouble" in str(mo): - x = X.astype(numpy.float64) - if rt == 'onnxruntime1': - continue - else: - x = X - with self.subTest(i=i, rt=rt): - oinf = OnnxInference(mo, runtime=rt) - got = oinf.run({'X': x})['variable'] - diff = numpy.abs(got.ravel() - expected.ravel()).max() - if __name__ == "__main__": - print("lgbd", i, rt, diff) - if i == 1 and rt == 'python': - self.assertLess(diff, 1e-5) - else: - self.assertLess(diff, 1e-3) + regressor = 
LGBMRegressor( + objective="regression", min_data_in_bin=1, min_data_in_leaf=1, + n_estimators=1, learning_rate=1) + + y = numpy.array([0, 0, 1, 1, 1]) + X_train = numpy.array( + [[1.0, 0.0], [1.0, -1.0], + [1.0, -1.0], [2.0, -1.0], [2.0, -1.0]], + dtype=numpy.float32) + X_test = numpy.array([[1.0, numpy.nan]], dtype=numpy.float32) + + regressor.fit(X_train, y) + model_onnx = to_onnx(regressor, X_train[:1]) + y_pred = regressor.predict(X_test) + oinf = OnnxInference(model_onnx) + y_pred_onnx = oinf.run({"X": X_test})['variable'] + self.assertEqualArray(y_pred, y_pred_onnx, decimal=4) @skipif_circleci('stuck') @unittest.skipIf(sys.platform == 'darwin', 'stuck') - def test_xgboost_regressor(self): - from xgboost import XGBRegressor + def test_lightgbm_regressor(self): + from lightgbm import LGBMRegressor try: - from onnxmltools.convert import convert_xgboost + from onnxmltools.convert import convert_lightgbm except ImportError: - convert_xgboost = None + convert_lightgbm = None + X, y = self.data_X, self.data_y + + for ne in [1, 2, 10, 50, 100, 200]: + for mx in [1, 10]: + if __name__ != "__main__" and mx > 5: + break + model = LGBMRegressor( + max_depth=mx, n_estimators=ne, min_child_samples=1, + learning_rate=0.0000001) + model.fit(X, y) + expected = model.predict(X) + + model_onnx = to_onnx(model, X) + if convert_lightgbm is not None: + model_onnx2 = convert_lightgbm( + model, initial_types=[('X', FloatTensorType([None, X.shape[1]]))]) + else: + model_onnx2 = None - X = numpy.abs(numpy.random.randn(7, 227)).astype(numpy.float32) - y = X.sum(axis=1) + numpy.random.randn( - X.shape[0]).astype(numpy.float32) / 10 - model = XGBRegressor( - max_depth=8, n_estimators=100, - learning_rate=0.000001) - model.fit(X, y) - expected = model.predict(X) + for i, mo in enumerate([model_onnx, model_onnx2]): + if mo is None: + continue + for rt in ['python', 'onnxruntime1']: + with self.subTest(i=i, rt=rt, max_depth=mx, n_est=ne): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': X})['variable'] + diff = numpy.abs( + got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("lgb1 mx=%d ne=%d" % (mx, ne), + "mlprod" if i == 0 else "mltool", rt[:6], diff) + self.assertLess(diff, 1e-3) - model_onnx = to_onnx(model, X) - if convert_xgboost is not None: - model_onnx2 = convert_xgboost( - model, initial_types=[('X', FloatTensorType([None, 227]))]) - else: - model_onnx2 = None + @skipif_circleci('stuck') + @unittest.skipIf(sys.platform == 'darwin', 'stuck') + def test_lightgbm_regressor_double(self): + from lightgbm import LGBMRegressor - for i, mo in enumerate([model_onnx, model_onnx2]): - if mo is None: - continue - for rt in ['python', 'onnxruntime1']: - with self.subTest(i=i, rt=rt): - oinf = OnnxInference(mo, runtime=rt) - got = oinf.run({'X': X})['variable'] - diff = numpy.abs(got.ravel() - expected.ravel()).max() - if __name__ == "__main__": - print("xgb", i, rt, diff) - self.assertLess(diff, 1e-5) + X, y = self.data_X, self.data_y + + for ne in [1, 2, 10, 50, 100, 200]: # pylint: disable=R1702 + for mx in [1, 10]: + if __name__ != "__main__" and mx > 5: + break + model = LGBMRegressor( + max_depth=mx, n_estimators=ne, min_child_samples=1, + learning_rate=0.0000001) + model.fit(X, y) + expected = model.predict(X) + model_onnx = to_onnx(model, X, rewrite_ops=True) + model_onnx2 = to_onnx(model, X.astype(numpy.float64), + rewrite_ops=True) + + for i, mo in enumerate([model_onnx, model_onnx2]): + for rt in ['python', 'onnxruntime1']: + if "TreeEnsembleRegressorDouble" in str(mo): 
+ x = X.astype(numpy.float64) + if rt == 'onnxruntime1': + continue + else: + x = X + with self.subTest(i=i, rt=rt, max_depth=mx, n_est=ne): + oinf = OnnxInference(mo, runtime=rt) + got = oinf.run({'X': x})['variable'] + diff = numpy.abs( + got.ravel() - expected.ravel()).max() + if __name__ == "__main__": + print("lgb2 mx=%d ne=%d" % (mx, ne), + i * 32 + 32, rt[:6], diff) + if i == 1 and rt == 'python': + self.assertLess(diff, 1e-5) + else: + self.assertLess(diff, 1e-3) if __name__ == "__main__": diff --git a/_unittests/ut_tools/test_graphs.py b/_unittests/ut_tools/test_graphs.py index 401eb7508..52472b20f 100644 --- a/_unittests/ut_tools/test_graphs.py +++ b/_unittests/ut_tools/test_graphs.py @@ -106,6 +106,17 @@ def test_bug_graph_infinite(self): text = oinf.to_text(distance=8) self.assertIn("slice_end", text) + def test_pipe_graph_simplified(self): + model = self.fit( + make_pipeline(StandardScaler(), LogisticRegression())) + onx = to_onnx(model, numpy.zeros((3, 4), dtype=numpy.float64)) + bigraph = onnx2bigraph(onx, graph_type='simplified') + text = str(bigraph) + self.assertEqual(text, "BiGraph(19 v., 12 v., 30 edges)") + disp = bigraph.summarize() + self.assertIn("B('Cast', '042434366f', 'Cast1')", disp) + self.assertIn("B('Div', '', 'Di_Div'", disp) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/onnx_conv/convert.py b/mlprodict/onnx_conv/convert.py index 681795cb3..ec7416231 100644 --- a/mlprodict/onnx_conv/convert.py +++ b/mlprodict/onnx_conv/convert.py @@ -29,7 +29,8 @@ def convert_scorer(fct, initial_types, name=None, custom_conversion_functions=None, custom_shape_calculators=None, custom_parsers=None, white_op=None, - black_op=None, final_types=None): + black_op=None, final_types=None, + verbose=0): """ Converts a scorer into :epkg:`ONNX` assuming there exists a converter associated to it. @@ -37,36 +38,33 @@ def convert_scorer(fct, initial_types, name=None, transformer, then calls function *convert_sklearn* from :epkg:`sklearn-onnx`. - @param fct function to convert (or a scorer from - :epkg:`scikit-learn`) - @param initial_types types information - @param name name of the produced model - @param target_opset to do it with a different target opset - @param options additional parameters for the conversion - @param custom_conversion_functions a dictionary for specifying the user customized - conversion function, it takes precedence over - registered converters - @param custom_shape_calculators a dictionary for specifying the user - customized shape calculator - it takes precedence over registered - shape calculators. - @param custom_parsers parsers determine which outputs is expected - for which particular task, default parsers are - defined for classifiers, regressors, pipeline but - they can be rewritten, *custom_parsers* is a dictionary - ``{ type: fct_parser(scope, model, inputs, - custom_parsers=None) }`` - @param white_op white list of ONNX nodes allowed - while converting a pipeline, if empty, - all are allowed - @param black_op black list of ONNX nodes allowed - while converting a pipeline, if empty, - none are blacklisted - @param final_types a python list. Works the same way as - initial_types but not mandatory, it is used - to overwrites the type (if type is not None) - and the name of every output. 
- @return :epkg:`ONNX` graph + :param fct: function to convert (or a scorer from :epkg:`scikit-learn`) + :param initial_types: types information + :param name: name of the produced model + :param target_opset: to do it with a different target opset + :param options: additional parameters for the conversion + :param custom_conversion_functions: a dictionary for specifying the user + customized conversion function, it takes precedence over + registered converters + :param custom_shape_calculators: a dictionary for specifying the user + customized shape calculator it takes precedence over registered + shape calculators. + :param custom_parsers: parsers determine which outputs is expected + for which particular task, default parsers are + defined for classifiers, regressors, pipeline but + they can be rewritten, *custom_parsers* is a dictionary + ``{ type: fct_parser(scope, model, inputs, + custom_parsers=None) }`` + :param white_op: white list of ONNX nodes allowed + while converting a pipeline, if empty, all are allowed + :param black_op: black list of ONNX nodes allowed + while converting a pipeline, if empty, none are blacklisted + :param final_types: a python list. Works the same way as + initial_types but not mandatory, it is used + to overwrites the type (if type is not None) + and the name of every output. + :param verbose: displays information while converting + :return: :epkg:`ONNX` graph """ if hasattr(fct, '_score_func'): kwargs = fct._kwargs @@ -82,7 +80,8 @@ def convert_scorer(fct, initial_types, name=None, custom_conversion_functions=custom_conversion_functions, custom_shape_calculators=custom_shape_calculators, custom_parsers=custom_parsers, white_op=white_op, - black_op=black_op, final_types=final_types) + black_op=black_op, final_types=final_types, + verbose=verbose) def guess_initial_types(X, initial_types): @@ -243,34 +242,35 @@ def guess_schema_from_model(model, tensor_type=None, schema=None): def to_onnx(model, X=None, name=None, initial_types=None, target_opset=None, options=None, rewrite_ops=False, - white_op=None, black_op=None, final_types=None): + white_op=None, black_op=None, final_types=None, + verbose=0): """ Converts a model using on :epkg:`sklearn-onnx`. - @param model model to convert or a function - wrapped into :epkg:`_PredictScorer` with - function :epkg:`make_scorer` - @param X training set (at least one row), - can be None, it is used to infered the - input types (*initial_types*) - @param initial_types if *X* is None, then *initial_types* must be - defined - @param name name of the produced model - @param target_opset to do it with a different target opset - @param options additional parameters for the conversion - @param rewrite_ops rewrites some existing converters, - the changes are permanent - @param white_op white list of ONNX nodes allowed - while converting a pipeline, if empty, - all are allowed - @param black_op black list of ONNX nodes allowed - while converting a pipeline, if empty, - none are blacklisted - @param final_types a python list. Works the same way as - initial_types but not mandatory, it is used - to overwrites the type (if type is not None) - and the name of every output. 
- @return converted model + :param model: model to convert or a function + wrapped into :epkg:`_PredictScorer` with + function :epkg:`make_scorer` + :param X: training set (at least one row), + can be None, it is used to infered the + input types (*initial_types*) + :param initial_types: if *X* is None, then *initial_types* + must be defined + :param name: name of the produced model + :param target_opset: to do it with a different target opset + :param options: additional parameters for the conversion + :param rewrite_ops: rewrites some existing converters, + the changes are permanent + :param white_op: white list of ONNX nodes allowed + while converting a pipeline, if empty, all are allowed + :param black_op: black list of ONNX nodes allowed + while converting a pipeline, if empty, + none are blacklisted + :param final_types: a python list. Works the same way as + initial_types but not mandatory, it is used + to overwrites the type (if type is not None) + and the name of every output. + :param verbose: display information while converting the model + :return: converted model The function rewrites function *to_onnx* from :epkg:`sklearn-onnx` but may changes a few converters if *rewrite_ops* is True. @@ -356,7 +356,8 @@ def to_onnx(model, X=None, name=None, initial_types=None, type(model))) return model.to_onnx( X=X, name=name, options=options, black_op=black_op, - white_op=white_op, final_types=final_types) + white_op=white_op, final_types=final_types, + verbose=verbose) if rewrite_ops: old_values, old_shapes = register_rewritten_operators() @@ -422,7 +423,7 @@ def _guess_type_(X, itype, dtype): res = convert_scorer(model, initial_types, name=name, target_opset=target_opset, options=options, black_op=black_op, white_op=white_op, - final_types=final_types) + final_types=final_types, verbose=verbose) else: if name is None: name = "mlprodict_ONNX(%s)" % model.__class__.__name__ @@ -431,7 +432,7 @@ def _guess_type_(X, itype, dtype): res = convert_sklearn(model, initial_types=initial_types, name=name, target_opset=target_opset, options=options, black_op=black_op, white_op=white_op, - final_types=final_types) + final_types=final_types, verbose=verbose) register_rewritten_operators(old_values, old_shapes) return res diff --git a/mlprodict/onnx_conv/helpers/__init__.py b/mlprodict/onnx_conv/helpers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mlprodict/onnx_conv/helpers/lgbm_helper.py b/mlprodict/onnx_conv/helpers/lgbm_helper.py new file mode 100644 index 000000000..4449de818 --- /dev/null +++ b/mlprodict/onnx_conv/helpers/lgbm_helper.py @@ -0,0 +1,366 @@ +""" +@file +@brief Helpers to speed up the conversion of Lightgbm models or transform it. +""" +from collections import deque +import ctypes +import json +import re + + +def restore_lgbm_info(tree): + """ + Restores speed up information to help + modifying the structure of the tree. 
+    """
+
+    def walk_through(t):
+        if 'tree_info' in t:
+            yield None
+        elif 'tree_structure' in t:
+            for w in walk_through(t['tree_structure']):
+                yield w
+        else:
+            yield t
+            if 'left_child' in t:
+                for w in walk_through(t['left_child']):
+                    yield w
+            if 'right_child' in t:
+                for w in walk_through(t['right_child']):
+                    yield w
+
+    nodes = []
+    if 'tree_info' in tree:
+        for node in walk_through(tree):
+            if node is None:
+                nodes.append([])
+            elif 'right_child' in node or 'left_child' in node:
+                nodes[-1].append(node)
+    else:
+        for node in walk_through(tree):
+            if 'right_child' in node or 'left_child' in node:
+                nodes.append(node)
+    return nodes
+
+
+def dump_booster_model(self, num_iteration=None, start_iteration=0,
+                       importance_type='split', verbose=0):
+    """
+    Dumps Booster to JSON format.
+
+    Parameters
+    ----------
+    self: booster
+    num_iteration : int or None, optional (default=None)
+        Index of the iteration that should be dumped.
+        If None, if the best iteration exists, it is dumped; otherwise,
+        all iterations are dumped.
+        If <= 0, all iterations are dumped.
+    start_iteration : int, optional (default=0)
+        Start index of the iteration that should be dumped.
+    importance_type : string, optional (default="split")
+        What type of feature importance should be dumped.
+        If "split", result contains numbers of times the feature is used in a model.
+        If "gain", result contains total gains of splits which use the feature.
+    verbose: displays progress (useful for big trees)
+
+    Returns
+    -------
+    json_repr : dict
+        JSON format of Booster.
+
+    .. note::
+        This function is inspired by
+        :epkg:`lightgbm` (`dump_model
+        `_).
+        It creates an intermediate structure to speed up the conversion
+        of such a model into ONNX. The function overrides
+        `json.load` to extract the nodes faster.
+    """
+    if getattr(self, 'is_mock', False):
+        return self.dump_model(), None
+    from lightgbm.basic import (
+        _LIB, FEATURE_IMPORTANCE_TYPE_MAPPER, _safe_call,
+        json_default_with_numpy)
+    if num_iteration is None:
+        num_iteration = self.best_iteration
+    importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
+    buffer_len = 1 << 20
+    tmp_out_len = ctypes.c_int64(0)
+    string_buffer = ctypes.create_string_buffer(buffer_len)
+    ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
+    if verbose >= 2:
+        print("[dump_booster_model] call CAPI: LGBM_BoosterDumpModel")
+    _safe_call(_LIB.LGBM_BoosterDumpModel(
+        self.handle,
+        ctypes.c_int(start_iteration),
+        ctypes.c_int(num_iteration),
+        ctypes.c_int(importance_type_int),
+        ctypes.c_int64(buffer_len),
+        ctypes.byref(tmp_out_len),
+        ptr_string_buffer))
+    actual_len = tmp_out_len.value
+    # if buffer length is not long enough, reallocate a buffer
+    if actual_len > buffer_len:
+        string_buffer = ctypes.create_string_buffer(actual_len)
+        ptr_string_buffer = ctypes.c_char_p(
+            *[ctypes.addressof(string_buffer)])
+        _safe_call(_LIB.LGBM_BoosterDumpModel(
+            self.handle,
+            ctypes.c_int(start_iteration),
+            ctypes.c_int(num_iteration),
+            ctypes.c_int(importance_type_int),
+            ctypes.c_int64(actual_len),
+            ctypes.byref(tmp_out_len),
+            ptr_string_buffer))
+
+    WHITESPACE = re.compile(
+        r'[ \t\n\r]*', re.VERBOSE | re.MULTILINE | re.DOTALL)
+
+    class Hook(json.JSONDecoder):
+        """
+        Keep track of the progress, stores a copy of all objects with
+        a decision into a different container in order to walk through
+        all nodes in a much faster way than going through the architecture.
+        """
+
+        def __init__(self, *args, info=None, n_trees=None, verbose=0,
+                     **kwargs):
+            json.JSONDecoder.__init__(
+                self, object_hook=self.hook, *args, **kwargs)
+            self.nodes = []
+            self.buffer = []
+            self.info = info
+            self.n_trees = n_trees
+            self.verbose = verbose
+            self.stored = 0
+            if verbose >= 2 and n_trees is not None:
+                from tqdm import tqdm
+                self.loop = tqdm(total=n_trees)
+                self.loop.set_description("dump_booster")
+            else:
+                self.loop = None
+
+        def decode(self, s, _w=WHITESPACE.match):
+            return json.JSONDecoder.decode(self, s, _w=_w)
+
+        def raw_decode(self, s, idx=0):
+            return json.JSONDecoder.raw_decode(self, s, idx=idx)
+
+        def hook(self, obj):
+            """
+            Hook called every time a JSON object is created.
+            Keep track of the progress, stores a copy of all objects with
+            a decision into a different container.
+            """
+            # Every obj goes through this function from the leaves to the root.
+            if 'tree_info' in obj:
+                self.info['decision_nodes'] = self.nodes
+                if self.n_trees is not None and len(self.nodes) != self.n_trees:
+                    raise RuntimeError(
+                        "Unexpected number of trees %d (expecting %d)." % (
+                            len(self.nodes), self.n_trees))
+                self.nodes = []
+                if self.loop is not None:
+                    self.loop.close()
+            if 'tree_structure' in obj:
+                self.nodes.append(self.buffer)
+                if self.loop is not None:
+                    self.loop.update(len(self.nodes))
+                    if len(self.nodes) % 10 == 0:
+                        self.loop.set_description(
+                            "dump_booster: %d/%d trees, %d nodes" % (
+                                len(self.nodes), self.n_trees, self.stored))
+                self.buffer = []
+            if "decision_type" in obj:
+                self.buffer.append(obj)
+                self.stored += 1
+            return obj
+
+    if verbose >= 2:
+        print("[dump_booster_model] to_json")
+    info = {}
+    ret = json.loads(string_buffer.value.decode('utf-8'), cls=Hook,
+                     info=info, n_trees=self.num_trees(), verbose=verbose)
+    ret['pandas_categorical'] = json.loads(
+        json.dumps(self.pandas_categorical,
+                   default=json_default_with_numpy))
+    if verbose >= 2:
+        print("[dump_booster_model] end.")
+    return ret, info
+
+
+def dump_lgbm_booster(booster, verbose=0):
+    """
+    Dumps a Lightgbm booster into JSON.
+
+    :param booster: Lightgbm booster
+    :param verbose: verbosity
+    :return: json, dictionary with more information
+    """
+    js, info = dump_booster_model(booster, verbose=verbose)
+    return js, info
+
+
+def modify_tree_for_rule_in_set(gbm, use_float=False, verbose=0, count=0,  # pylint: disable=R1710
+                                info=None):
+    """
+    LightGBM sometimes produces a tree with a node set
+    to use rule ``==`` against a set of values (= in set),
+    the values being separated by ``||``.
+    This function unfolds these nodes.
+
+    :param gbm: a tree coming from lightgbm dump
+    :param use_float: use float, otherwise try int first,
+        then float if it does not work
+    :param verbose: verbosity, use :epkg:`tqdm` to show progress
+    :param count: number of nodes already changed (origin) before this call
+    :param info: additional information to speed up this search
+    :return: number of changed nodes (including *count*)
+
+    A child looks like the following:
+
+    .. 
runpython:: + :showcode: + :warningout: DeprecationWarning + + import pprint + from mlprodict.onnx_conv.operator_converters.conv_lightgbm import modify_tree_for_rule_in_set + + tree = {'decision_type': '==', + 'default_left': True, + 'internal_count': 6805, + 'internal_value': 0.117558, + 'left_child': {'leaf_count': 4293, + 'leaf_index': 18, + 'leaf_value': 0.003519117642745049}, + 'missing_type': 'None', + 'right_child': {'leaf_count': 2512, + 'leaf_index': 25, + 'leaf_value': 0.012305307958365394}, + 'split_feature': 24, + 'split_gain': 12.233599662780762, + 'split_index': 24, + 'threshold': '10||12||13'} + + modify_tree_for_rule_in_set(tree) + + pprint.pprint(tree) + """ + if 'tree_info' in gbm: + if info is not None: + dec_nodes = info['decision_nodes'] + else: + dec_nodes = None + if verbose >= 2: + from tqdm import tqdm + loop = tqdm(gbm['tree_info']) + for i, tree in enumerate(loop): + loop.set_description("rules tree %d c=%d" % (i, count)) + count = modify_tree_for_rule_in_set( + tree, use_float=use_float, count=count, + info=None if dec_nodes is None else dec_nodes[i]) + else: + for i, tree in enumerate(gbm['tree_info']): + count = modify_tree_for_rule_in_set( + tree, use_float=use_float, count=count, + info=None if dec_nodes is None else dec_nodes[i]) + return count + + if 'tree_structure' in gbm: + return modify_tree_for_rule_in_set( + gbm['tree_structure'], use_float=use_float, count=count, + info=info) + + if 'decision_type' not in gbm: + return count + + def str2number(val): + if use_float: + return float(val) + else: + try: + return int(val) + except ValueError: # pragma: no cover + return float(val) + + if info is None: + + def recursive_call(this, c): + if 'left_child' in this: + c = process_node(this['left_child'], count=c) + if 'right_child' in this: + c = process_node(this['right_child'], count=c) + return c + + def process_node(node, count): + if 'decision_type' not in node: + return count + if node['decision_type'] != '==': + return recursive_call(node, count) + th = node['threshold'] + if not isinstance(th, str): + return recursive_call(node, count) + pos = th.find('||') + if pos == -1: + return recursive_call(node, count) + th1 = str2number(th[:pos]) + + def doit(): + rest = th[pos + 2:] + if '||' not in rest: + rest = str2number(rest) + + node['threshold'] = th1 + new_node = node.copy() + node['right_child'] = new_node + new_node['threshold'] = rest + + doit() + return recursive_call(node, count + 1) + + return process_node(gbm, count) + + # when info is used + + def split_node(node, th, pos): + th1 = str2number(th[:pos]) + + rest = th[pos + 2:] + if '||' not in rest: + rest = str2number(rest) + app = False + else: + app = True + + node['threshold'] = th1 + new_node = node.copy() + node['right_child'] = new_node + new_node['threshold'] = rest + return new_node, app + + stack = deque(info) + while len(stack) > 0: + node = stack.pop() + + if 'decision_type' not in node: + continue # leave + + if node['decision_type'] != '==': + continue + + th = node['threshold'] + if not isinstance(th, str): + continue + + pos = th.find('||') + if pos == -1: + continue + + new_node, app = split_node(node, th, pos) + count += 1 + if app: + stack.append(new_node) + + return count diff --git a/mlprodict/onnx_conv/operator_converters/conv_lightgbm.py b/mlprodict/onnx_conv/operator_converters/conv_lightgbm.py index 8ba0e380f..6541aaf4d 100644 --- a/mlprodict/onnx_conv/operator_converters/conv_lightgbm.py +++ b/mlprodict/onnx_conv/operator_converters/conv_lightgbm.py @@ -15,6 +15,9 @@ 
calculate_linear_regressor_output_shapes, calculate_linear_classifier_output_shapes) from skl2onnx.common.data_types import guess_numpy_type +from skl2onnx.common.tree_ensemble import sklearn_threshold +from ..helpers.lgbm_helper import ( + dump_lgbm_booster, modify_tree_for_rule_in_set) def calculate_lightgbm_output_shapes(operator): @@ -25,8 +28,13 @@ def calculate_lightgbm_output_shapes(operator): op = operator.raw_operator if hasattr(op, "_model_dict"): objective = op._model_dict['objective'] - else: + elif hasattr(op, 'objective_'): objective = op.objective_ + else: + raise RuntimeError( # pragma: no cover + "Unable to find attributes '_model_dict' or 'objective_' in " + "instance of type %r (list of attributes=%r)." % ( + type(op), dir(op))) if objective.startswith('binary') or objective.startswith('multiclass'): return calculate_linear_classifier_output_shapes(operator) if objective.startswith('regression'): # pragma: no cover @@ -101,12 +109,12 @@ def _parse_tree_structure(tree_id, class_id, learning_rate, tree_structure, attr attrs['nodes_nodeids'].append(node_id) attrs['nodes_featureids'].append(tree_structure['split_feature']) - attrs['nodes_modes'].append( - _translate_split_criterion(tree_structure['decision_type'])) + mode = _translate_split_criterion(tree_structure['decision_type']) + attrs['nodes_modes'].append(mode) + if isinstance(tree_structure['threshold'], str): try: # pragma: no cover - attrs['nodes_values'].append( # pragma: no cover - float(tree_structure['threshold'])) + th = float(tree_structure['threshold']) # pragma: no cover except ValueError as e: # pragma: no cover import pprint text = pprint.pformat(tree_structure) @@ -115,13 +123,24 @@ def _parse_tree_structure(tree_id, class_id, learning_rate, tree_structure, attr raise TypeError("threshold must be a number not '{}'" "\n{}".format(tree_structure['threshold'], text)) from e else: - attrs['nodes_values'].append(tree_structure['threshold']) + th = tree_structure['threshold'] + if mode == 'BRANCH_LEQ': + th2 = sklearn_threshold(th, numpy.float32, mode) + else: + # other decision criteria are not implemented + th2 = th + attrs['nodes_values'].append(th2) # Assume left is the true branch and right is the false branch attrs['nodes_truenodeids'].append(left_id) attrs['nodes_falsenodeids'].append(right_id) if tree_structure['default_left']: - attrs['nodes_missing_value_tracks_true'].append(1) + # attrs['nodes_missing_value_tracks_true'].append(1) + if (tree_structure["missing_type"] in ('None', None) and + float(tree_structure['threshold']) < 0.0): + attrs['nodes_missing_value_tracks_true'].append(0) + else: + attrs['nodes_missing_value_tracks_true'].append(1) else: attrs['nodes_missing_value_tracks_true'].append(0) attrs['nodes_hitrates'].append(1.) @@ -138,8 +157,8 @@ def _parse_node(tree_id, class_id, node_id, node_id_pool, node_pyid_pool, """ Parses nodes. 
""" - if (hasattr(node, 'left_child') and hasattr(node, 'right_child')) or \ - ('left_child' in node and 'right_child' in node): + if ((hasattr(node, 'left_child') and hasattr(node, 'right_child')) or + ('left_child' in node and 'right_child' in node)): left_pyid = id(node['left_child']) right_pyid = id(node['right_child']) @@ -184,7 +203,12 @@ def _parse_node(tree_id, class_id, node_id, node_id_pool, node_pyid_pool, attrs['nodes_truenodeids'].append(left_id) attrs['nodes_falsenodeids'].append(right_id) if node['default_left']: - attrs['nodes_missing_value_tracks_true'].append(1) + # attrs['nodes_missing_value_tracks_true'].append(1) + if (node['missing_type'] in ('None', None) and + float(node['threshold']) < 0.0): + attrs['nodes_missing_value_tracks_true'].append(0) + else: + attrs['nodes_missing_value_tracks_true'].append(1) else: attrs['nodes_missing_value_tracks_true'].append(0) attrs['nodes_hitrates'].append(1.) @@ -230,9 +254,18 @@ def convert_lightgbm(scope, operator, container): some modifications. It implements converters for models in :epkg:`lightgbm`. """ + verbose = container.verbose gbm_model = operator.raw_operator - gbm_text = gbm_model.booster_.dump_model() - modify_tree_for_rule_in_set(gbm_text, use_float=True) + if hasattr(gbm_model, '_model_dict_info'): + gbm_text, info = gbm_model._model_dict_info + else: + if verbose >= 2: + print("[convert_lightgbm] dump_model") + gbm_text, info = dump_lgbm_booster(gbm_model.booster_, verbose=verbose) + if verbose >= 2: + print("[convert_lightgbm] modify_tree_for_rule_in_set") + modify_tree_for_rule_in_set(gbm_text, use_float=True, verbose=verbose, + info=info) attrs = get_default_tree_classifier_attribute_pairs() attrs['name'] = operator.full_name @@ -254,7 +287,13 @@ def convert_lightgbm(scope, operator, container): gbm_text['objective'])) # Use the same algorithm to parse the tree - for i, tree in enumerate(gbm_text['tree_info']): + if verbose >= 2: + from tqdm import tqdm + loop = tqdm(gbm_text['tree_info']) + loop.set_description("parse") + else: + loop = gbm_text['tree_info'] + for i, tree in enumerate(loop): tree_id = i class_id = tree_id % n_classes # tree['shrinkage'] --> LightGbm provides figures with it already. @@ -262,6 +301,9 @@ def convert_lightgbm(scope, operator, container): _parse_tree_structure( tree_id, class_id, learning_rate, tree['tree_structure'], attrs) + if verbose >= 2: + print("[convert_lightgbm] onnx") + # Sort nodes_* attributes. For one tree, its node indexes should appear in an ascent order in nodes_nodeids. Nodes # from a tree with a smaller tree index should appear before trees with larger indexes in nodes_nodeids. node_numbers_per_tree = Counter(attrs['nodes_treeids']) @@ -404,88 +446,5 @@ def convert_lightgbm(scope, operator, container): operator.output_full_names, name=scope.get_unique_operator_name('Identity')) - -def modify_tree_for_rule_in_set(gbm, use_float=False): # pylint: disable=R1710 - """ - LightGBM produces sometimes a tree with a node set - to use rule ``==`` to a set of values (= in set), - the values are separated by ``||``. - This function unfold theses nodes. A child looks - like the following: - - .. 
runpython:: - :showcode: - :warningout: DeprecationWarning - - import pprint - from mlprodict.onnx_conv.operator_converters.conv_lightgbm import modify_tree_for_rule_in_set - - tree = {'decision_type': '==', - 'default_left': True, - 'internal_count': 6805, - 'internal_value': 0.117558, - 'left_child': {'leaf_count': 4293, - 'leaf_index': 18, - 'leaf_value': 0.003519117642745049}, - 'missing_type': 'None', - 'right_child': {'leaf_count': 2512, - 'leaf_index': 25, - 'leaf_value': 0.012305307958365394}, - 'split_feature': 24, - 'split_gain': 12.233599662780762, - 'split_index': 24, - 'threshold': '10||12||13'} - - modify_tree_for_rule_in_set(tree) - - pprint.pprint(tree) - """ - if 'tree_info' in gbm: - for tree in gbm['tree_info']: - modify_tree_for_rule_in_set(tree, use_float=use_float) - return - - if 'tree_structure' in gbm: - modify_tree_for_rule_in_set(gbm['tree_structure'], use_float=use_float) - return - - if 'decision_type' not in gbm: - return - - def recursive_call(this): - if 'left_child' in this: - modify_tree_for_rule_in_set( - this['left_child'], use_float=use_float) - if 'right_child' in this: - modify_tree_for_rule_in_set( - this['right_child'], use_float=use_float) - - def str2number(val): - if use_float: - return float(val) - else: - try: - return int(val) - except ValueError: # pragma: no cover - return float(val) - - dec = gbm['decision_type'] - if dec != '==': - return recursive_call(gbm) - - th = gbm['threshold'] - if not isinstance(th, str) or '||' not in th: - return recursive_call(gbm) - - pos = th.index('||') - th1 = str2number(th[:pos]) - - rest = th[pos + 2:] - if '||' not in rest: - rest = str2number(rest) - - gbm['threshold'] = th1 - new_node = gbm.copy() - gbm['right_child'] = new_node - new_node['threshold'] = rest - return recursive_call(gbm) + if verbose >= 2: + print("[convert_lightgbm] end") diff --git a/mlprodict/onnx_conv/parsers/parse_lightgbm.py b/mlprodict/onnx_conv/parsers/parse_lightgbm.py index 1dbfe5dff..2245543e7 100644 --- a/mlprodict/onnx_conv/parsers/parse_lightgbm.py +++ b/mlprodict/onnx_conv/parsers/parse_lightgbm.py @@ -18,17 +18,21 @@ class WrappedLightGbmBooster: def __init__(self, booster): self.booster_ = booster - self._model_dict = self.booster_.dump_model() - self.classes_ = self._generate_classes(self._model_dict) - self.n_features_ = len(self._model_dict['feature_names']) - if self._model_dict['objective'].startswith('binary'): + self.n_features_ = self.booster_.feature_name() + self.objective_ = self.get_objective() + if self.objective_.startswith('binary'): self.operator_name = 'LgbmClassifier' - elif self._model_dict['objective'].startswith('regression'): # pragma: no cover + self.classes_ = self._generate_classes(booster) + elif self.objective_.startswith('multiclass'): + self.operator_name = 'LgbmClassifier' + self.classes_ = self._generate_classes(booster) + elif self.objective_.startswith('regression'): # pragma: no cover self.operator_name = 'LgbmRegressor' else: # pragma: no cover - raise NotImplementedError('Unsupported LightGbm objective: {}'.format( - self._model_dict['objective'])) - if self._model_dict.get('average_output', False): + raise NotImplementedError( + 'Unsupported LightGbm objective: %r.' % self.objective_) + average_output = self.booster_.attr('average_output') + if average_output: self.boosting_type = 'rf' else: # Other than random forest, other boosting types do not affect later conversion. 
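The wrapper above no longer dumps the whole model just to discover what the
booster optimizes: it reads the booster attribute first and only falls back to
a one-iteration dump. A minimal sketch of that lookup against a raw
:epkg:`lightgbm` ``Booster`` (the helper name ``booster_objective`` and the toy
dataset are illustrative, not part of the package)::

    import numpy
    import lightgbm


    def booster_objective(booster):
        # same order as WrappedLightGbmBooster.get_objective: cheap attribute
        # lookup first, one-iteration dump only as a fallback
        objective = booster.attr('objective')
        if objective is not None:
            return objective
        return booster.dump_model(num_iteration=1)['objective']


    # tiny binary problem, just enough to obtain a Booster
    X = numpy.random.rand(100, 4)
    y = (X[:, 0] + X[:, 1] > 1).astype(int)
    booster = lightgbm.train({'objective': 'binary', 'verbose': -1},
                             lightgbm.Dataset(X, label=y),
                             num_boost_round=3)
    print(booster_objective(booster))  # e.g. 'binary sigmoid:1'
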
@@ -36,10 +40,27 @@ def __init__(self, booster): self.boosting_type = 'gbdt' @staticmethod - def _generate_classes(model_dict): - if model_dict['num_class'] == 1: + def _generate_classes(booster): + if isinstance(booster, dict): + num_class = booster['num_class'] + else: + num_class = booster.attr('num_class') + if num_class is None: + dp = booster.dump_model(num_iteration=1) + num_class = dp['num_class'] + if num_class == 1: return numpy.asarray([0, 1]) - return numpy.arange(model_dict['num_class']) + return numpy.arange(num_class) + + def get_objective(self): + "Returns the objective." + if hasattr(self, 'objective_') and self.objective_ is not None: + return self.objective_ + objective = self.booster_.attr('objective') + if objective is not None: + return objective + dp = self.booster_.dump_model(num_iteration=1) + return dp['objective'] class WrappedLightGbmBoosterClassifier(ClassifierMixin): @@ -48,9 +69,11 @@ class WrappedLightGbmBoosterClassifier(ClassifierMixin): """ def __init__(self, wrapped): # pylint: disable=W0231 - for k in {'boosting_type', '_model_dict', 'operator_name', - 'classes_', 'booster_', 'n_features_'}: - setattr(self, k, getattr(wrapped, k)) + for k in {'boosting_type', '_model_dict', '_model_dict_info', + 'operator_name', 'classes_', 'booster_', 'n_features_', + 'objective_', 'boosting_type', 'n_features_'}: + if hasattr(wrapped, k): + setattr(self, k, getattr(wrapped, k)) class MockWrappedLightGbmBoosterClassifier(WrappedLightGbmBoosterClassifier): @@ -60,12 +83,28 @@ class MockWrappedLightGbmBoosterClassifier(WrappedLightGbmBoosterClassifier): def __init__(self, tree): # pylint: disable=W0231 self.dumped_ = tree + self.is_mock = True def dump_model(self): "mock dump_model method" self.visited = True return self.dumped_ + def feature_name(self): + "Returns binary features names." + return [0, 1] + + def attr(self, key): + "Returns default values for common attributes." + if key == 'objective': + return "binary" + if key == 'num_class': + return 1 + if key == 'average_output': + return None + raise KeyError( # pragma: no cover + "No response for %r." % key) + def lightgbm_parser(scope, model, inputs, custom_parsers=None): """ @@ -78,16 +117,20 @@ def lightgbm_parser(scope, model, inputs, custom_parsers=None): if len(inputs) == 1: wrapped = WrappedLightGbmBooster(model) - if wrapped._model_dict['objective'].startswith('binary'): + objective = wrapped.get_objective() + if objective.startswith('binary'): + wrapped = WrappedLightGbmBoosterClassifier(wrapped) + return _parse_sklearn_classifier( + scope, wrapped, inputs, custom_parsers=custom_parsers) + if objective.startswith('multiclass'): wrapped = WrappedLightGbmBoosterClassifier(wrapped) return _parse_sklearn_classifier( scope, wrapped, inputs, custom_parsers=custom_parsers) - if wrapped._model_dict['objective'].startswith('regression'): # pragma: no cover + if objective.startswith('regression'): # pragma: no cover return _parse_sklearn_simple_model( scope, wrapped, inputs, custom_parsers=custom_parsers) raise NotImplementedError( # pragma: no cover - "Objective '{}' is not implemented yet.".format( - wrapped._model_dict['objective'])) + "Objective '{}' is not implemented yet.".format(objective)) # Multiple columns this_operator = scope.declare_local_operator('LightGBMConcat') diff --git a/mlprodict/onnx_tools/onnx_export.py b/mlprodict/onnx_tools/onnx_export.py index b7ef73067..fe22fca5b 100644 --- a/mlprodict/onnx_tools/onnx_export.py +++ b/mlprodict/onnx_tools/onnx_export.py @@ -5,7 +5,6 @@ .. 
versionadded:: 0.7 """ -from textwrap import dedent import numpy from jinja2 import Template import autopep8 @@ -13,275 +12,8 @@ from onnx import numpy_helper from .onnx2py_helper import ( _var_as_dict, guess_proto_dtype, guess_proto_dtype_name) - - -_onnx_templates = dedent(""" - import numpy - from onnx import numpy_helper, TensorProto - from onnx.helper import ( - make_model, make_node, set_model_props, make_tensor, make_graph, - make_tensor_value_info) - - - def create_model(): - ''' - Converted ``{{ name }}``. - - * producer: {{ producer_name }} - * version: {{ model_version }} - * description: {{ doc_string }} - {%- for key, val in sorted(metadata.items()): -%} - * {{ key }}: {{ val }} - {%- endfor %} - ''' - # containers - print('[containers]') # verbose - initializers = [] - nodes = [] - inputs = [] - outputs = [] - - # opsets - print('[opsets]') # verbose - opsets = {{ opsets }} - target_opset = {{ target_opset }} - - # initializers - print('[initializers]') # verbose - {% for name, value in initializers: %} - {% if len(value.shape) == 0: %} - value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) - {% else %} - list_value = {{ value.ravel().tolist() }} - value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} - {% endif %} - tensor = numpy_helper.from_array(value, name='{{ name }}') - initializers.append(tensor) - {% endfor %} - - # inputs - print('[inputs]') # verbose - {% for name, type, shape in inputs: %} - value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) - inputs.append(value) - {% endfor %} - - # outputs - print('[outputs]') # verbose - {% for name, type, shape in outputs: %} - value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) - outputs.append(value) - {% endfor %} - - # nodes - print('[nodes]') # verbose - {% for node in nodes: %} - node = make_node( - '{{ node['op_type'] }}', - {{ node['inputs'] }}, - {{ node['outputs'] }}, - {% if node['name']: %}name='{{ node['name'] }}',{% endif %} - {%- for name, value in node['attributes']: -%} - {{ name }}={{ value }}, - {%- endfor -%} - domain='{{ node['domain'] }}') - nodes.append(node) - {% endfor %} - - # graph - print('[graph]') # verbose - graph = make_graph(nodes, '{{ name }}', inputs, outputs, initializers) - onnx_model = make_model(graph) - onnx_model.ir_version = {{ ir_version }} - onnx_model.producer_name = '{{ producer_name }}' - onnx_model.producer_version = '{{ producer_version }}' - onnx_model.domain = '{{ domain }}' - onnx_model.model_version = {{ model_version }} - onnx_model.doc_string = '{{ doc_string }}' - set_model_props(onnx_model, {{ metadata }}) - - # opsets - print('[opset]') # verbose - del onnx_model.opset_import[:] # pylint: disable=E1101 - for dom, value in opsets.items(): - op_set = onnx_model.opset_import.add() - op_set.domain = dom - op_set.version = value - - return onnx_model - - - onnx_model = create_model() -""") - - -_tf2onnx_templates = dedent(""" - import inspect - import collections - import numpy - from onnx import AttributeProto, TensorProto - from onnx.helper import ( - make_model, make_node, set_model_props, make_tensor, make_graph, - make_tensor_value_info) - # from tf2onnx.utils import make_name, make_sure, map_onnx_to_numpy_type - from mlprodict.onnx_tools.exports.tf2onnx_helper import ( - make_name, make_sure, map_onnx_to_numpy_type) - # from tf2onnx.handler import tf_op - # from tf2onnx.graph_builder import GraphBuilder - from 
mlprodict.onnx_tools.exports.tf2onnx_helper import ( - tf_op, Tf2OnnxConvert, GraphBuilder) - - - @tf_op("{{ name }}") - class Convert{{ name }}Op: - - supported_dtypes = [ - numpy.float32, - ] - - @classmethod - def any_version(cls, opset, ctx, node, **kwargs): - ''' - Converter for ``{{ name }}``. - - * producer: {{ producer_name }} - * version: {{ model_version }} - * description: {{ doc_string }} - {%- for key, val in sorted(metadata.items()): -%} - * {{ key }}: {{ val }} - {%- endfor %} - ''' - oldnode = node - input_name = node.input[0] - onnx_dtype = ctx.get_dtype(input_name) - np_dtype = map_onnx_to_numpy_type(onnx_dtype) - make_sure(np_dtype in Convert{{ name }}Op.supported_dtypes, "Unsupported input type.") - shape = ctx.get_shape(input_name) - varx = {x: x for x in node.input} - - # initializers - if getattr(ctx, 'verbose', False): - print('[initializers] %r' % cls) - {% for name, value in initializers: %} - {% if len(value.shape) == 0: -%} - value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) - {%- else -%} - {% if value.size > 5: -%} - list_value = {{ value.ravel().tolist() }} - value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} - {%- else -%} - value = numpy.array({{ value.ravel().tolist() }}, dtype=numpy.{{ value.dtype }}){%- - if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} - {%- endif -%}{%- endif %} - varx['{{ name }}'] = ctx.make_const(name=make_name('init_{{ name }}'), np_val=value).name - {% endfor %} - - # nodes - if getattr(ctx, 'verbose', False): - print('[nodes] %r' % cls) - {% for node in nodes: %} - {{ make_tf2onnx_code(target_opset, **node) }} - {% endfor %} - - # finalize - if getattr(ctx, 'verbose', False): - print('[replace_all_inputs] %r' % cls) - ctx.replace_all_inputs(oldnode.output[0], node.output[0]) - ctx.remove_node(oldnode.name) - - @classmethod - def version_13(cls, ctx, node, **kwargs): - return cls.any_version(13, ctx, node, **kwargs) - - - def create_model(): - inputs = [] - outputs = [] - - # inputs - print('[inputs]') # verbose - {% for name, type, shape in inputs: %} - value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) - inputs.append(value) - {% endfor %} - - # outputs - print('[outputs]') # verbose - {% for name, type, shape in outputs: %} - value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) - outputs.append(value) - {% endfor %} - - inames = [i.name for i in inputs] - onames = [i.name for i in outputs] - node = make_node('{{ name }}', inames, onames, name='{{ name }}') - - # graph - print('[graph]') # verbose - graph = make_graph([node], '{{ name }}', inputs, outputs) - onnx_model = make_model(graph) - onnx_model.ir_version = {{ ir_version }} - onnx_model.producer_name = '{{ producer_name }}' - onnx_model.producer_version = '{{ producer_version }}' - onnx_model.domain = '{{ domain }}' - onnx_model.model_version = {{ model_version }} - onnx_model.doc_string = '{{ doc_string }}' - set_model_props(onnx_model, {{ metadata }}) - - # opsets - print('[opset]') # verbose - opsets = {{ opsets }} - del onnx_model.opset_import[:] # pylint: disable=E1101 - for dom, value in opsets.items(): - op_set = onnx_model.opset_import.add() - op_set.domain = dom - op_set.version = value - - return onnx_model - - - onnx_raw = create_model() - onnx_model = Tf2OnnxConvert(onnx_raw, tf_op, target_opset={{ opsets }}).run() -""") - - -_numpy_templates = dedent(""" - import numpy - from 
mlprodict.onnx_tools.exports.numpy_helper import ( - argmin_use_numpy_select_last_index, - make_slice) - - def numpy_{{name}}({{ inputs[0][0] }}{% for i in inputs[1:]: %}, {{ i[0] }}{% endfor %}): - ''' - Numpy function for ``{{ name }}``. - - * producer: {{ producer_name }} - * version: {{ model_version }} - * description: {{ doc_string }} - {%- for key, val in sorted(metadata.items()): -%} - * {{ key }}: {{ val }} - {%- endfor %} - ''' - # initializers - {% for name, value in initializers: -%} - {% if name not in skip_inits: -%} - {% if len(value.shape) == 0: -%} - {{ name }} = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) - {%- else %}{% if value.size < 10: %} - {{ name }} = numpy.array({{ value.ravel().tolist() }}, dtype=numpy.{{ value.dtype }}) - {%- if len(value.shape) > 1: -%}.reshape({{ value.shape }}){%- endif %} - {% else %} - list_value = {{ value.ravel().tolist() }} - {{ name }} = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} - {% endif %}{% endif %}{% endif %} - {%- endfor %} - - # nodes - {% for node in nodes: %} - {{ make_numpy_code(target_opset, **node) }}{% endfor %} - - return {{ outputs[0][0] }}{% for o in outputs[1:]: %}, {{ o[0] }}{% endfor %} -""") +from .onnx_export_templates import ( + _onnx_templates, _tf2onnx_templates, _numpy_templates) def make_tf2onnx_code(opset, name=None, op_type=None, domain='', diff --git a/mlprodict/onnx_tools/onnx_export_templates.py b/mlprodict/onnx_tools/onnx_export_templates.py new file mode 100644 index 000000000..1454b5094 --- /dev/null +++ b/mlprodict/onnx_tools/onnx_export_templates.py @@ -0,0 +1,276 @@ +""" +@file +@brief Templates to export an ONNX graph in a way it can we created again +with a python script. + +.. versionadded:: 0.7 +""" +from textwrap import dedent + +_onnx_templates = dedent(""" + import numpy + from onnx import numpy_helper, TensorProto + from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) + + + def create_model(): + ''' + Converted ``{{ name }}``. 
+ + * producer: {{ producer_name }} + * version: {{ model_version }} + * description: {{ doc_string }} + {%- for key, val in sorted(metadata.items()): -%} + * {{ key }}: {{ val }} + {%- endfor %} + ''' + # containers + print('[containers]') # verbose + initializers = [] + nodes = [] + inputs = [] + outputs = [] + + # opsets + print('[opsets]') # verbose + opsets = {{ opsets }} + target_opset = {{ target_opset }} + + # initializers + print('[initializers]') # verbose + {% for name, value in initializers: %} + {% if len(value.shape) == 0: %} + value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) + {% else %} + list_value = {{ value.ravel().tolist() }} + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {% endif %} + tensor = numpy_helper.from_array(value, name='{{ name }}') + initializers.append(tensor) + {% endfor %} + + # inputs + print('[inputs]') # verbose + {% for name, type, shape in inputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + inputs.append(value) + {% endfor %} + + # outputs + print('[outputs]') # verbose + {% for name, type, shape in outputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + outputs.append(value) + {% endfor %} + + # nodes + print('[nodes]') # verbose + {% for node in nodes: %} + node = make_node( + '{{ node['op_type'] }}', + {{ node['inputs'] }}, + {{ node['outputs'] }}, + {% if node['name']: %}name='{{ node['name'] }}',{% endif %} + {%- for name, value in node['attributes']: -%} + {{ name }}={{ value }}, + {%- endfor -%} + domain='{{ node['domain'] }}') + nodes.append(node) + {% endfor %} + + # graph + print('[graph]') # verbose + graph = make_graph(nodes, '{{ name }}', inputs, outputs, initializers) + onnx_model = make_model(graph) + onnx_model.ir_version = {{ ir_version }} + onnx_model.producer_name = '{{ producer_name }}' + onnx_model.producer_version = '{{ producer_version }}' + onnx_model.domain = '{{ domain }}' + onnx_model.model_version = {{ model_version }} + onnx_model.doc_string = '{{ doc_string }}' + set_model_props(onnx_model, {{ metadata }}) + + # opsets + print('[opset]') # verbose + del onnx_model.opset_import[:] # pylint: disable=E1101 + for dom, value in opsets.items(): + op_set = onnx_model.opset_import.add() + op_set.domain = dom + op_set.version = value + + return onnx_model + + + onnx_model = create_model() +""") + + +_tf2onnx_templates = dedent(""" + import inspect + import collections + import numpy + from onnx import AttributeProto, TensorProto + from onnx.helper import ( + make_model, make_node, set_model_props, make_tensor, make_graph, + make_tensor_value_info) + # from tf2onnx.utils import make_name, make_sure, map_onnx_to_numpy_type + from mlprodict.onnx_tools.exports.tf2onnx_helper import ( + make_name, make_sure, map_onnx_to_numpy_type) + # from tf2onnx.handler import tf_op + # from tf2onnx.graph_builder import GraphBuilder + from mlprodict.onnx_tools.exports.tf2onnx_helper import ( + tf_op, Tf2OnnxConvert, GraphBuilder) + + + @tf_op("{{ name }}") + class Convert{{ name }}Op: + + supported_dtypes = [ + numpy.float32, + ] + + @classmethod + def any_version(cls, opset, ctx, node, **kwargs): + ''' + Converter for ``{{ name }}``. 
+ + * producer: {{ producer_name }} + * version: {{ model_version }} + * description: {{ doc_string }} + {%- for key, val in sorted(metadata.items()): -%} + * {{ key }}: {{ val }} + {%- endfor %} + ''' + oldnode = node + input_name = node.input[0] + onnx_dtype = ctx.get_dtype(input_name) + np_dtype = map_onnx_to_numpy_type(onnx_dtype) + make_sure(np_dtype in Convert{{ name }}Op.supported_dtypes, "Unsupported input type.") + shape = ctx.get_shape(input_name) + varx = {x: x for x in node.input} + + # initializers + if getattr(ctx, 'verbose', False): + print('[initializers] %r' % cls) + {% for name, value in initializers: %} + {% if len(value.shape) == 0: -%} + value = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) + {%- else -%} + {% if value.size > 5: -%} + list_value = {{ value.ravel().tolist() }} + value = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {%- else -%} + value = numpy.array({{ value.ravel().tolist() }}, dtype=numpy.{{ value.dtype }}){%- + if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {%- endif -%}{%- endif %} + varx['{{ name }}'] = ctx.make_const(name=make_name('init_{{ name }}'), np_val=value).name + {% endfor %} + + # nodes + if getattr(ctx, 'verbose', False): + print('[nodes] %r' % cls) + {% for node in nodes: %} + {{ make_tf2onnx_code(target_opset, **node) }} + {% endfor %} + + # finalize + if getattr(ctx, 'verbose', False): + print('[replace_all_inputs] %r' % cls) + ctx.replace_all_inputs(oldnode.output[0], node.output[0]) + ctx.remove_node(oldnode.name) + + @classmethod + def version_13(cls, ctx, node, **kwargs): + return cls.any_version(13, ctx, node, **kwargs) + + + def create_model(): + inputs = [] + outputs = [] + + # inputs + print('[inputs]') # verbose + {% for name, type, shape in inputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + inputs.append(value) + {% endfor %} + + # outputs + print('[outputs]') # verbose + {% for name, type, shape in outputs: %} + value = make_tensor_value_info('{{ name }}', {{ type }}, {{ shape }}) + outputs.append(value) + {% endfor %} + + inames = [i.name for i in inputs] + onames = [i.name for i in outputs] + node = make_node('{{ name }}', inames, onames, name='{{ name }}') + + # graph + print('[graph]') # verbose + graph = make_graph([node], '{{ name }}', inputs, outputs) + onnx_model = make_model(graph) + onnx_model.ir_version = {{ ir_version }} + onnx_model.producer_name = '{{ producer_name }}' + onnx_model.producer_version = '{{ producer_version }}' + onnx_model.domain = '{{ domain }}' + onnx_model.model_version = {{ model_version }} + onnx_model.doc_string = '{{ doc_string }}' + set_model_props(onnx_model, {{ metadata }}) + + # opsets + print('[opset]') # verbose + opsets = {{ opsets }} + del onnx_model.opset_import[:] # pylint: disable=E1101 + for dom, value in opsets.items(): + op_set = onnx_model.opset_import.add() + op_set.domain = dom + op_set.version = value + + return onnx_model + + + onnx_raw = create_model() + onnx_model = Tf2OnnxConvert(onnx_raw, tf_op, target_opset={{ opsets }}).run() +""") + + +_numpy_templates = dedent(""" + import numpy + from mlprodict.onnx_tools.exports.numpy_helper import ( + argmin_use_numpy_select_last_index, + make_slice) + + def numpy_{{name}}({{ inputs[0][0] }}{% for i in inputs[1:]: %}, {{ i[0] }}{% endfor %}): + ''' + Numpy function for ``{{ name }}``. 
+ + * producer: {{ producer_name }} + * version: {{ model_version }} + * description: {{ doc_string }} + {%- for key, val in sorted(metadata.items()): -%} + * {{ key }}: {{ val }} + {%- endfor %} + ''' + # initializers + {% for name, value in initializers: -%} + {% if name not in skip_inits: -%} + {% if len(value.shape) == 0: -%} + {{ name }} = numpy.array({{ value }}, dtype=numpy.{{ value.dtype }}) + {%- else %}{% if value.size < 10: %} + {{ name }} = numpy.array({{ value.ravel().tolist() }}, dtype=numpy.{{ value.dtype }}) + {%- if len(value.shape) > 1: -%}.reshape({{ value.shape }}){%- endif %} + {% else %} + list_value = {{ value.ravel().tolist() }} + {{ name }} = numpy.array(list_value, dtype=numpy.{{ value.dtype }}){% if len(value.shape) > 1: %}.reshape({{ value.shape }}){% endif %} + {% endif %}{% endif %}{% endif %} + {%- endfor %} + + # nodes + {% for node in nodes: %} + {{ make_numpy_code(target_opset, **node) }}{% endfor %} + + return {{ outputs[0][0] }}{% for o in outputs[1:]: %}, {{ o[0] }}{% endfor %} +""") diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_.cpp b/mlprodict/onnxrt/ops_cpu/op_conv_.cpp index e68a455bb..722386b7b 100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_conv_.cpp @@ -35,8 +35,10 @@ class Conv : public ConvPoolCommon { protected: - void compute_gil_free(py::array_t X, py::array_t W, - py::array_t B, py::array_t& Y, + void compute_gil_free(py::array_t X, + py::array_t W, + py::array_t B, + py::array_t& Y, const std::vector& input_shape, const std::vector& output_shape, const std::vector& kernel_shape, @@ -90,7 +92,7 @@ py::array_t Conv::compute(py::array_t Y(y_dims); + py::array_t Y(y_dims); { py::gil_scoped_release release; compute_gil_free(X, W, B, Y, @@ -104,7 +106,10 @@ py::array_t Conv::compute(py::array_t void Conv::compute_gil_free( - py::array_t X, py::array_t W, py::array_t B, py::array_t& Y, + py::array_t X, + py::array_t W, + py::array_t B, + py::array_t& Y, const std::vector& input_shape, const std::vector& output_shape, const std::vector& kernel_shape, diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp b/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp index 71a38a661..f2658b8be 100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp @@ -34,13 +34,13 @@ class ConvTranspose : ConvPoolCommon { ConvTranspose(); void init(const std::string &auto_pad, - py::array_t dilations, + py::array_t dilations, int64_t group, - py::array_t kernel_shape, - py::array_t pads, - py::array_t strides, - py::array_t output_padding, - py::array_t output_shape); + py::array_t kernel_shape, + py::array_t pads, + py::array_t strides, + py::array_t output_padding, + py::array_t output_shape); py::array_t compute(py::array_t X, py::array_t W, @@ -51,8 +51,10 @@ class ConvTranspose : ConvPoolCommon { void compute_kernel_shape(const std::vector& weight_shape, std::vector& kernel_shape) const; - void compute_gil_free(py::array_t X, py::array_t W, - py::array_t B, py::array_t& Y, + void compute_gil_free(py::array_t X, + py::array_t W, + py::array_t B, + py::array_t& Y, const std::vector& input_shape, const std::vector& output_shape, const std::vector& kernel_shape, @@ -84,13 +86,13 @@ ConvTranspose::ConvTranspose() : ConvPoolCommon() { template void ConvTranspose::init( const std::string &auto_pad, - py::array_t dilations, + py::array_t dilations, int64_t group, - py::array_t kernel_shape, - py::array_t pads, - py::array_t strides, - py::array_t output_padding, - py::array_t 
output_shape) { + py::array_t kernel_shape, + py::array_t pads, + py::array_t strides, + py::array_t output_padding, + py::array_t output_shape) { ConvPoolCommon::init(auto_pad, dilations, group, kernel_shape, pads, strides); array2vector(output_padding_, output_padding, int64_t); array2vector(output_shape_, output_shape, int64_t); @@ -160,7 +162,7 @@ py::array_t ConvTranspose::compute(py::array_t output_shape(y_dims.begin() + 2, y_dims.end()); - py::array_t Y(y_dims); + py::array_t Y(y_dims); { py::gil_scoped_release release; compute_gil_free(X, W, B, Y, @@ -227,7 +229,10 @@ void ConvTranspose::infer_output_shape( template void ConvTranspose::compute_gil_free( - py::array_t X, py::array_t W, py::array_t B, py::array_t& Y, + py::array_t X, + py::array_t W, + py::array_t B, + py::array_t& Y, const std::vector& input_shape, const std::vector& output_shape, const std::vector& kernel_shape, diff --git a/mlprodict/onnxrt/ops_cpu/op_max_pool_.cpp b/mlprodict/onnxrt/ops_cpu/op_max_pool_.cpp index 2cc620fbf..5fa2a3b0c 100644 --- a/mlprodict/onnxrt/ops_cpu/op_max_pool_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_max_pool_.cpp @@ -35,12 +35,12 @@ class MaxPool : ConvPoolCommon { MaxPool(); void init(const std::string &auto_pad, - py::array_t dilations, + py::array_t dilations, int64_t ceil_mode, int64_t storage_order, - py::array_t kernel_shape, - py::array_t pads, - py::array_t strides); + py::array_t kernel_shape, + py::array_t pads, + py::array_t strides); py::tuple compute(py::array_t X) const; @@ -74,7 +74,9 @@ class MaxPool : ConvPoolCommon { int64_t pad_needed, int64_t dilation) const; - void compute_gil_free(py::array_t X, py::array_t& Y, py::array_t* I, + void compute_gil_free(py::array_t X, + py::array_t& Y, + py::array_t* I, const std::vector& kernel_shape, const std::vector& pads, const std::vector& strides, @@ -93,12 +95,12 @@ MaxPool::MaxPool() : ConvPoolCommon() { template void MaxPool::init( const std::string &auto_pad, - py::array_t dilations, + py::array_t dilations, int64_t ceil_mode, int64_t storage_order, - py::array_t kernel_shape, - py::array_t pads, - py::array_t strides) { + py::array_t kernel_shape, + py::array_t pads, + py::array_t strides) { ConvPoolCommon::init(auto_pad, dilations, 0, kernel_shape, pads, strides); ceil_mode_ = ceil_mode; storage_order_ = storage_order; @@ -239,8 +241,8 @@ py::tuple MaxPool::compute(py::array_t output_dims = SetOutputSize(x_dims, x_dims[1], &pads, &strides, &kernel_shape, &dilations); - py::array_t Y(output_dims); - py::array_t I(output_dims); + py::array_t Y(output_dims); + py::array_t I(output_dims); { py::gil_scoped_release release; compute_gil_free(X, Y, &I, kernel_shape, pads, strides, dilations, x_dims, output_dims); @@ -468,7 +470,9 @@ struct MaxPool3DTask { template void MaxPool::compute_gil_free( - py::array_t X, py::array_t& Y, py::array_t* I, + py::array_t X, + py::array_t& Y, + py::array_t* I, const std::vector& kernel_shape, const std::vector& pads, const std::vector& strides, diff --git a/mlprodict/onnxrt/ops_cpu/op_svm_classifier_.cpp b/mlprodict/onnxrt/ops_cpu/op_svm_classifier_.cpp index 081d94e8a..667122fc5 100644 --- a/mlprodict/onnxrt/ops_cpu/op_svm_classifier_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_svm_classifier_.cpp @@ -25,28 +25,30 @@ class RuntimeSVMClassifier : public RuntimeSVMCommon ~RuntimeSVMClassifier(); void init( - py::array_t classlabels_int64s, + py::array_t classlabels_int64s, const std::vector& classlabels_strings, - py::array_t coefficients, - py::array_t kernel_params, + py::array_t coefficients, + 
py::array_t kernel_params, const std::string& kernel_type, const std::string& post_transform, - py::array_t prob_a, - py::array_t prob_b, - py::array_t rho, - py::array_t support_vectors, - py::array_t vectors_per_class + py::array_t prob_a, + py::array_t prob_b, + py::array_t rho, + py::array_t support_vectors, + py::array_t vectors_per_class ); - py::tuple compute(py::array_t X) const; + py::tuple compute(py::array_t X) const; private: void Initialize(); void compute_gil_free(const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Y, - py::array_t& Z, int64_t z_stride) const; + const py::array_t& X, + py::array_t& Y, + py::array_t& Z, + int64_t z_stride) const; void compute_gil_free_loop(const NTYPE * x_data, int64_t* y_data, NTYPE * z_data) const; @@ -66,17 +68,17 @@ RuntimeSVMClassifier::~RuntimeSVMClassifier() { template void RuntimeSVMClassifier::init( - py::array_t classlabels_int64s, + py::array_t classlabels_int64s, const std::vector& classlabels_strings, - py::array_t coefficients, - py::array_t kernel_params, + py::array_t coefficients, + py::array_t kernel_params, const std::string& kernel_type, const std::string& post_transform, - py::array_t prob_a, - py::array_t prob_b, - py::array_t rho, - py::array_t support_vectors, - py::array_t vectors_per_class + py::array_t prob_a, + py::array_t prob_b, + py::array_t rho, + py::array_t support_vectors, + py::array_t vectors_per_class ) { RuntimeSVMCommon::init( coefficients, kernel_params, kernel_type, @@ -155,7 +157,7 @@ int _set_score_svm(int64_t* output_data, NTYPE max_weight, const int64_t maxclas template -py::tuple RuntimeSVMClassifier::compute(py::array_t X) const { +py::tuple RuntimeSVMClassifier::compute(py::array_t X) const { // const Tensor& X = *context->Input(0); // const TensorShape& x_shape = X.Shape(); std::vector x_dims; @@ -175,8 +177,8 @@ py::tuple RuntimeSVMClassifier::compute(py::array_t X) const { std::vector dims{N, nb_columns}; - py::array_t Y(N); // one target only - py::array_t Z(N * nb_columns); // one target only + py::array_t Y(N); // one target only + py::array_t Z(N * nb_columns); // one target only { py::gil_scoped_release release; compute_gil_free(x_dims, N, stride, X, Y, Z, nb_columns); @@ -368,8 +370,9 @@ void RuntimeSVMClassifier::compute_gil_free_loop( template void RuntimeSVMClassifier::compute_gil_free( const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, - py::array_t& Y, py::array_t& Z, + const py::array_t& X, + py::array_t& Y, + py::array_t& Z, int64_t z_stride) const { auto Y_ = Y.mutable_unchecked<1>(); auto Z_ = _mutable_unchecked1(Z); // Z.mutable_unchecked<(size_t)1>(); diff --git a/mlprodict/onnxrt/ops_cpu/op_svm_common_.hpp b/mlprodict/onnxrt/ops_cpu/op_svm_common_.hpp index ac133403f..055d61cd6 100644 --- a/mlprodict/onnxrt/ops_cpu/op_svm_common_.hpp +++ b/mlprodict/onnxrt/ops_cpu/op_svm_common_.hpp @@ -52,12 +52,12 @@ class RuntimeSVMCommon RuntimeSVMCommon(int omp_N) { omp_N_ = omp_N; } ~RuntimeSVMCommon() { } - void init(py::array_t coefficients, - py::array_t kernel_params, + void init(py::array_t coefficients, + py::array_t kernel_params, const std::string& kernel_type, const std::string& post_transform, - py::array_t rho, - py::array_t support_vectors); + py::array_t rho, + py::array_t support_vectors); NTYPE kernel_dot_gil_free( @@ -75,12 +75,12 @@ class RuntimeSVMCommon template void RuntimeSVMCommon::init( - py::array_t coefficients, - py::array_t kernel_params, + py::array_t coefficients, + py::array_t kernel_params, 
const std::string& kernel_type, const std::string& post_transform, - py::array_t rho, - py::array_t support_vectors + py::array_t rho, + py::array_t support_vectors ) { kernel_type_ = to_KERNEL(kernel_type); array2vector(support_vectors_, support_vectors, NTYPE); @@ -174,12 +174,12 @@ int RuntimeSVMCommon::omp_get_max_threads() { #endif } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } diff --git a/mlprodict/onnxrt/ops_cpu/op_svm_regressor_.cpp b/mlprodict/onnxrt/ops_cpu/op_svm_regressor_.cpp index b284af56a..e290ebfa7 100644 --- a/mlprodict/onnxrt/ops_cpu/op_svm_regressor_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_svm_regressor_.cpp @@ -17,24 +17,25 @@ class RuntimeSVMRegressor : public RuntimeSVMCommon ~RuntimeSVMRegressor(); void init( - py::array_t coefficients, - py::array_t kernel_params, + py::array_t coefficients, + py::array_t kernel_params, const std::string& kernel_type, int64_t n_supports, int64_t one_class, const std::string& post_transform, - py::array_t rho, - py::array_t support_vectors + py::array_t rho, + py::array_t support_vectors ); - py::array_t compute(py::array_t X) const; + py::array_t compute(py::array_t X) const; private: void Initialize(); void compute_gil_free(const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z) const; + const py::array_t& X, + py::array_t& Z) const; }; @@ -50,14 +51,14 @@ RuntimeSVMRegressor::~RuntimeSVMRegressor() { template void RuntimeSVMRegressor::init( - py::array_t coefficients, - py::array_t kernel_params, + py::array_t coefficients, + py::array_t kernel_params, const std::string& kernel_type, int64_t n_supports, int64_t one_class, const std::string& post_transform, - py::array_t rho, - py::array_t support_vectors + py::array_t rho, + py::array_t support_vectors ) { RuntimeSVMCommon::init( coefficients, kernel_params, kernel_type, @@ -84,7 +85,7 @@ void RuntimeSVMRegressor::Initialize() { template -py::array_t RuntimeSVMRegressor::compute(py::array_t X) const { +py::array_t RuntimeSVMRegressor::compute(py::array_t X) const { // const Tensor& X = *context->Input(0); // const TensorShape& x_shape = X.Shape(); std::vector x_dims; @@ -95,7 +96,7 @@ py::array_t RuntimeSVMRegressor::compute(py::array_t X) con int64_t stride = x_dims.size() == 1 ? x_dims[0] : x_dims[1]; int64_t N = x_dims.size() == 1 ? 
1 : x_dims[0]; - py::array_t Z(x_dims[0]); // one target only + py::array_t Z(x_dims[0]); // one target only { py::gil_scoped_release release; compute_gil_free(x_dims, N, stride, X, Z); @@ -125,7 +126,8 @@ py::array_t RuntimeSVMRegressor::compute(py::array_t X) con template void RuntimeSVMRegressor::compute_gil_free( const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z) const { + const py::array_t& X, + py::array_t& Z) const { auto Z_ = _mutable_unchecked1(Z); // Z.mutable_unchecked<(size_t)1>(); const NTYPE* x_data = X.data(0); diff --git a/mlprodict/onnxrt/ops_cpu/op_tfidfvectorizer_.cpp b/mlprodict/onnxrt/ops_cpu/op_tfidfvectorizer_.cpp index cfa177fd6..1ea8789dc 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tfidfvectorizer_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_tfidfvectorizer_.cpp @@ -92,11 +92,11 @@ class RuntimeTfIdfVectorizer { const std::vector& weights); ~RuntimeTfIdfVectorizer() { } - py::array_t Compute(py::array_t X) const; + py::array_t Compute(py::array_t X) const; private: - void ComputeImpl(const py::array_t& X, + void ComputeImpl(const py::array_t& X, ptrdiff_t row_num, size_t row_size, std::vector& frequencies) const; @@ -225,12 +225,12 @@ void RuntimeTfIdfVectorizer::Init( } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } @@ -250,7 +250,7 @@ py::array_t RuntimeTfIdfVectorizer::OutputResult( const auto row_size = output_size_; auto total_dims = flattened_dimension(output_dims); - py::array_t Y(total_dims); + py::array_t Y(total_dims); auto output_data_ = _mutable_unchecked1(Y); float* output_data = (float*)output_data_.data(0); @@ -292,7 +292,8 @@ py::array_t RuntimeTfIdfVectorizer::OutputResult( } void RuntimeTfIdfVectorizer::ComputeImpl( - const py::array_t& X, ptrdiff_t row_num, size_t row_size, + const py::array_t& X, + ptrdiff_t row_num, size_t row_size, std::vector& frequencies) const { const auto elem_size = sizeof(int64_t); @@ -339,7 +340,7 @@ void RuntimeTfIdfVectorizer::ComputeImpl( } } -py::array_t RuntimeTfIdfVectorizer::Compute(py::array_t X) const { +py::array_t RuntimeTfIdfVectorizer::Compute(py::array_t X) const { std::vector input_shape; arrayshape2vector(input_shape, X); const size_t total_items = flattened_dimension(input_shape); diff --git a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_.cpp b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_.cpp index 9a8e93315..d42e40f8c 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_.cpp @@ -66,26 +66,26 @@ class RuntimeTreeEnsembleClassifier ~RuntimeTreeEnsembleClassifier(); void init( - py::array_t base_values, // 0 - py::array_t class_ids, // 1 - py::array_t class_nodeids, // 2 - py::array_t class_treeids, // 3 - py::array_t class_weights, // 4 - py::array_t classlabels_int64s, // 5 + py::array_t base_values, // 0 + py::array_t class_ids, // 1 + py::array_t class_nodeids, // 2 + py::array_t class_treeids, // 3 + py::array_t class_weights, // 4 + py::array_t classlabels_int64s, // 5 const std::vector& classlabels_strings, // 6 - py::array_t nodes_falsenodeids, // 7 - py::array_t nodes_featureids, // 8 - py::array_t nodes_hitrates, // 9 - py::array_t 
nodes_missing_value_tracks_true, // 10 + py::array_t nodes_falsenodeids, // 7 + py::array_t nodes_featureids, // 8 + py::array_t nodes_hitrates, // 9 + py::array_t nodes_missing_value_tracks_true, // 10 const std::vector& nodes_modes, // 11 - py::array_t nodes_nodeids, // 12 - py::array_t nodes_treeids, // 13 - py::array_t nodes_truenodeids, // 14 - py::array_t nodes_values, // 15 + py::array_t nodes_nodeids, // 12 + py::array_t nodes_treeids, // 13 + py::array_t nodes_truenodeids, // 14 + py::array_t nodes_values, // 15 const std::string& post_transform // 16 ); - py::tuple compute(py::array_t X) const; + py::tuple compute(py::array_t X) const; void ProcessTreeNode(std::vector& classes, std::vector& filled, @@ -102,8 +102,9 @@ class RuntimeTreeEnsembleClassifier void Initialize(); void compute_gil_free(const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Y, - py::array_t& Z) const; + const py::array_t& X, + py::array_t& Y, + py::array_t& Z) const; int64_t _set_score_binary(int64_t i, int& write_additional_scores, @@ -145,22 +146,22 @@ int RuntimeTreeEnsembleClassifier::omp_get_max_threads() { template void RuntimeTreeEnsembleClassifier::init( - py::array_t base_values, - py::array_t class_ids, - py::array_t class_nodeids, - py::array_t class_treeids, - py::array_t class_weights, - py::array_t classlabels_int64s, + py::array_t base_values, + py::array_t class_ids, + py::array_t class_nodeids, + py::array_t class_treeids, + py::array_t class_weights, + py::array_t classlabels_int64s, const std::vector& classlabels_strings, - py::array_t nodes_falsenodeids, - py::array_t nodes_featureids, - py::array_t nodes_hitrates, - py::array_t nodes_missing_value_tracks_true, + py::array_t nodes_falsenodeids, + py::array_t nodes_featureids, + py::array_t nodes_hitrates, + py::array_t nodes_missing_value_tracks_true, const std::vector& nodes_modes, - py::array_t nodes_nodeids, - py::array_t nodes_treeids, - py::array_t nodes_truenodeids, - py::array_t nodes_values, + py::array_t nodes_nodeids, + py::array_t nodes_treeids, + py::array_t nodes_truenodeids, + py::array_t nodes_values, const std::string& post_transform ) { array2vector(nodes_treeids_, nodes_treeids, int64_t); @@ -371,7 +372,7 @@ int64_t RuntimeTreeEnsembleClassifier::_set_score_binary( template -py::tuple RuntimeTreeEnsembleClassifier::compute(py::array_t X) const { +py::tuple RuntimeTreeEnsembleClassifier::compute(py::array_t X) const { // const Tensor& X = *context->Input(0); // const TensorShape& x_shape = X.Shape(); std::vector x_dims; @@ -385,8 +386,8 @@ py::tuple RuntimeTreeEnsembleClassifier::compute(py::array_t X) co // Tensor* Y = context->Output(0, TensorShape({N})); // auto* Z = context->Output(1, TensorShape({N, class_count_})); - py::array_t Y(x_dims[0]); - py::array_t Z(x_dims[0] * class_count_); + py::array_t Y(x_dims[0]); + py::array_t Z(x_dims[0] * class_count_); { py::gil_scoped_release release; @@ -396,12 +397,12 @@ py::tuple RuntimeTreeEnsembleClassifier::compute(py::array_t X) co } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } @@ -409,7 +410,9 @@ py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array template void 
RuntimeTreeEnsembleClassifier::compute_gil_free( const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Y, py::array_t& Z) const { + const py::array_t& X, + py::array_t& Y, + py::array_t& Z) const { auto Y_ = Y.mutable_unchecked<1>(); auto Z_ = _mutable_unchecked1(Z); // Z.mutable_unchecked<(size_t)1>(); const NTYPE* x_data = X.data(0); diff --git a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_p_.cpp b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_p_.cpp index 67e726403..053333ffb 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_p_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier_p_.cpp @@ -19,27 +19,27 @@ class RuntimeTreeEnsembleClassifierP : public RuntimeTreeEnsembleCommonP ~RuntimeTreeEnsembleClassifierP(); void init( - py::array_t base_values, // 0 - py::array_t class_ids, // 1 - py::array_t class_nodeids, // 2 - py::array_t class_treeids, // 3 - py::array_t class_weights, // 4 - py::array_t classlabels_int64s, // 5 + py::array_t base_values, // 0 + py::array_t class_ids, // 1 + py::array_t class_nodeids, // 2 + py::array_t class_treeids, // 3 + py::array_t class_weights, // 4 + py::array_t classlabels_int64s, // 5 const std::vector& classlabels_strings, // 6 - py::array_t nodes_falsenodeids, // 7 - py::array_t nodes_featureids, // 8 - py::array_t nodes_hitrates, // 9 - py::array_t nodes_missing_value_tracks_true, // 10 + py::array_t nodes_falsenodeids, // 7 + py::array_t nodes_featureids, // 8 + py::array_t nodes_hitrates, // 9 + py::array_t nodes_missing_value_tracks_true, // 10 const std::vector& nodes_modes, // 11 - py::array_t nodes_nodeids, // 12 - py::array_t nodes_treeids, // 13 - py::array_t nodes_truenodeids, // 14 - py::array_t nodes_values, // 15 + py::array_t nodes_nodeids, // 12 + py::array_t nodes_treeids, // 13 + py::array_t nodes_truenodeids, // 14 + py::array_t nodes_values, // 15 const std::string& post_transform // 16 ); - py::tuple compute_cl(py::array_t X); - py::array_t compute_tree_outputs(py::array_t X); + py::tuple compute_cl(py::array_t X); + py::array_t compute_tree_outputs(py::array_t X); }; @@ -57,22 +57,22 @@ RuntimeTreeEnsembleClassifierP::~RuntimeTreeEnsembleClassifierP() { template void RuntimeTreeEnsembleClassifierP::init( - py::array_t base_values, // 0 - py::array_t class_ids, // 1 - py::array_t class_nodeids, // 2 - py::array_t class_treeids, // 3 - py::array_t class_weights, // 4 - py::array_t classlabels_int64s, // 5 + py::array_t base_values, // 0 + py::array_t class_ids, // 1 + py::array_t class_nodeids, // 2 + py::array_t class_treeids, // 3 + py::array_t class_weights, // 4 + py::array_t classlabels_int64s, // 5 const std::vector& classlabels_strings, // 6 - py::array_t nodes_falsenodeids, // 7 - py::array_t nodes_featureids, // 8 - py::array_t nodes_hitrates, // 9 - py::array_t nodes_missing_value_tracks_true, // 10 + py::array_t nodes_falsenodeids, // 7 + py::array_t nodes_featureids, // 8 + py::array_t nodes_hitrates, // 9 + py::array_t nodes_missing_value_tracks_true, // 10 const std::vector& nodes_modes, // 11 - py::array_t nodes_nodeids, // 12 - py::array_t nodes_treeids, // 13 - py::array_t nodes_truenodeids, // 14 - py::array_t nodes_values, // 15 + py::array_t nodes_nodeids, // 12 + py::array_t nodes_treeids, // 13 + py::array_t nodes_truenodeids, // 14 + py::array_t nodes_values, // 15 const std::string& post_transform // 16 ) { RuntimeTreeEnsembleCommonP::init( @@ -99,7 +99,7 @@ void RuntimeTreeEnsembleClassifierP::init( template -py::tuple 
RuntimeTreeEnsembleClassifierP::compute_cl(py::array_t X) { +py::tuple RuntimeTreeEnsembleClassifierP::compute_cl(py::array_t X) { return this->compute_cl_agg(X, _AggregatorClassifier( this->roots_.size(), this->n_targets_or_classes_, this->post_transform_, &(this->base_values_), @@ -109,7 +109,7 @@ py::tuple RuntimeTreeEnsembleClassifierP::compute_cl(py::array_t X template -py::array_t RuntimeTreeEnsembleClassifierP::compute_tree_outputs(py::array_t X) { +py::array_t RuntimeTreeEnsembleClassifierP::compute_tree_outputs(py::array_t X) { return this->compute_tree_outputs_agg(X, _AggregatorClassifier( this->roots_.size(), this->n_targets_or_classes_, this->post_transform_, &(this->base_values_), diff --git a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_common_p_.hpp b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_common_p_.hpp index a10021ea6..93ff7d651 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_common_p_.hpp +++ b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_common_p_.hpp @@ -62,22 +62,22 @@ class RuntimeTreeEnsembleCommonP { void init( const std::string &aggregate_function, - py::array_t base_values, + py::array_t base_values, int64_t n_targets_or_classes, - py::array_t nodes_falsenodeids, - py::array_t nodes_featureids, - py::array_t nodes_hitrates, - py::array_t nodes_missing_value_tracks_true, + py::array_t nodes_falsenodeids, + py::array_t nodes_featureids, + py::array_t nodes_hitrates, + py::array_t nodes_missing_value_tracks_true, const std::vector& nodes_modes, - py::array_t nodes_nodeids, - py::array_t nodes_treeids, - py::array_t nodes_truenodeids, - py::array_t nodes_values, + py::array_t nodes_nodeids, + py::array_t nodes_treeids, + py::array_t nodes_truenodeids, + py::array_t nodes_values, const std::string& post_transform, - py::array_t target_class_ids, - py::array_t target_class_nodeids, - py::array_t target_class_treeids, - py::array_t target_class_weights); + py::array_t target_class_ids, + py::array_t target_class_nodeids, + py::array_t target_class_treeids, + py::array_t target_class_weights); void init_c( const std::string &aggregate_function, @@ -109,30 +109,35 @@ class RuntimeTreeEnsembleCommonP { int64_t get_sizeof(); template - py::array_t compute_tree_outputs_agg(py::array_t X, const AGG &agg) const; + py::array_t compute_tree_outputs_agg(py::array_t X, const AGG &agg) const; - py::array_t debug_threshold(py::array_t values) const; + py::array_t debug_threshold(py::array_t values) const; // The two following methods uses buffers to avoid // spending time allocating buffers. As a consequence, // These methods are not thread-safe. 
template - py::array_t compute_agg(py::array_t X, const AGG &agg); + py::array_t compute_agg(py::array_t X, const AGG &agg); template - py::tuple compute_cl_agg(py::array_t X, const AGG &agg); + py::tuple compute_cl_agg(py::array_t X, const AGG &agg); private: template void compute_gil_free(const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z, - py::array_t* Y, const AGG &agg); + const py::array_t& X, + py::array_t& Z, + py::array_t* Y, + const AGG &agg); template - void compute_gil_free_array_structure(const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z, - py::array_t* Y, const AGG &agg); + void compute_gil_free_array_structure(const std::vector& x_dims, + int64_t N, int64_t stride, + const py::array_t& X, + py::array_t& Z, + py::array_t* Y, + const AGG &agg); void switch_to_array_structure(); }; @@ -185,22 +190,22 @@ int64_t RuntimeTreeEnsembleCommonP::get_sizeof() { template void RuntimeTreeEnsembleCommonP::init( const std::string &aggregate_function, - py::array_t base_values, + py::array_t base_values, int64_t n_targets_or_classes, - py::array_t nodes_falsenodeids, - py::array_t nodes_featureids, - py::array_t nodes_hitrates, - py::array_t nodes_missing_value_tracks_true, + py::array_t nodes_falsenodeids, + py::array_t nodes_featureids, + py::array_t nodes_hitrates, + py::array_t nodes_missing_value_tracks_true, const std::vector& nodes_modes, - py::array_t nodes_nodeids, - py::array_t nodes_treeids, - py::array_t nodes_truenodeids, - py::array_t nodes_values, + py::array_t nodes_nodeids, + py::array_t nodes_treeids, + py::array_t nodes_truenodeids, + py::array_t nodes_values, const std::string& post_transform, - py::array_t target_class_ids, - py::array_t target_class_nodeids, - py::array_t target_class_treeids, - py::array_t target_class_weights) { + py::array_t target_class_ids, + py::array_t target_class_nodeids, + py::array_t target_class_treeids, + py::array_t target_class_weights) { std::vector cbasevalues; array2vector(cbasevalues, base_values, NTYPE); @@ -485,7 +490,8 @@ std::vector RuntimeTreeEnsembleCommonP::get_nodes_modes() co template template -py::array_t RuntimeTreeEnsembleCommonP::compute_agg(py::array_t X, const AGG &agg) { +py::array_t RuntimeTreeEnsembleCommonP::compute_agg( + py::array_t X, const AGG &agg) { std::vector x_dims; arrayshape2vector(x_dims, X); if (x_dims.size() != 2) @@ -496,7 +502,7 @@ py::array_t RuntimeTreeEnsembleCommonP::compute_agg(py::array_t Z(x_dims[0] * n_targets_or_classes_); + py::array_t Z(x_dims[0] * n_targets_or_classes_); { py::gil_scoped_release release; @@ -511,7 +517,7 @@ py::array_t RuntimeTreeEnsembleCommonP::compute_agg(py::array_t template py::tuple RuntimeTreeEnsembleCommonP::compute_cl_agg( - py::array_t X, const AGG &agg) { + py::array_t X, const AGG &agg) { std::vector x_dims; arrayshape2vector(x_dims, X); if (x_dims.size() != 2) @@ -524,8 +530,8 @@ py::tuple RuntimeTreeEnsembleCommonP::compute_cl_agg( // Tensor* Y = context->Output(0, TensorShape({N})); // auto* Z = context->Output(1, TensorShape({N, class_count_})); - py::array_t Z(x_dims[0] * n_targets_or_classes_); - py::array_t Y(x_dims[0]); + py::array_t Z(x_dims[0] * n_targets_or_classes_); + py::array_t Y(x_dims[0]); { py::gil_scoped_release release; @@ -538,17 +544,17 @@ py::tuple RuntimeTreeEnsembleCommonP::compute_cl_agg( } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { 
return Z.mutable_unchecked<1>(); } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } @@ -556,8 +562,10 @@ py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array template template void RuntimeTreeEnsembleCommonP::compute_gil_free( const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z, - py::array_t* Y, const AGG &agg) { + const py::array_t& X, + py::array_t& Z, + py::array_t* Y, + const AGG &agg) { // expected primary-expression before ')' token auto Z_ = _mutable_unchecked1(Z); // Z.mutable_unchecked<(size_t)1>(); @@ -731,8 +739,10 @@ void RuntimeTreeEnsembleCommonP::compute_gil_free( template template void RuntimeTreeEnsembleCommonP::compute_gil_free_array_structure( const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z, - py::array_t* Y, const AGG &agg) { + const py::array_t& X, + py::array_t& Z, + py::array_t* Y, + const AGG &agg) { // expected primary-expression before ')' token auto Z_ = _mutable_unchecked1(Z); // Z.mutable_unchecked<(size_t)1>(); @@ -1226,7 +1236,7 @@ size_t RuntimeTreeEnsembleCommonP::ProcessTreeNodeLeave( template py::array_t RuntimeTreeEnsembleCommonP::debug_threshold( - py::array_t values) const { + py::array_t values) const { if (array_structure_) throw std::invalid_argument("debug_threshold not implemented if array_structure is true."); std::vector result(values.size() * n_nodes_); @@ -1254,7 +1264,8 @@ py::array_t RuntimeTreeEnsembleCommonP::debug_threshold( template template -py::array_t RuntimeTreeEnsembleCommonP::compute_tree_outputs_agg(py::array_t X, const AGG &agg) const { +py::array_t RuntimeTreeEnsembleCommonP::compute_tree_outputs_agg( + py::array_t X, const AGG &agg) const { if (array_structure_) throw std::invalid_argument("compute_tree_outputs_agg not implemented if array_structure is true."); diff --git a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_.cpp b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_.cpp index bd522f549..26bc9547b 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_.cpp @@ -65,25 +65,25 @@ class RuntimeTreeEnsembleRegressor void init( const std::string &aggregate_function, - py::array_t base_values, + py::array_t base_values, int64_t n_targets, - py::array_t nodes_falsenodeids, - py::array_t nodes_featureids, - py::array_t nodes_hitrates, - py::array_t nodes_missing_value_tracks_true, + py::array_t nodes_falsenodeids, + py::array_t nodes_featureids, + py::array_t nodes_hitrates, + py::array_t nodes_missing_value_tracks_true, const std::vector& nodes_modes, - py::array_t nodes_nodeids, - py::array_t nodes_treeids, - py::array_t nodes_truenodeids, - py::array_t nodes_values, + py::array_t nodes_nodeids, + py::array_t nodes_treeids, + py::array_t nodes_truenodeids, + py::array_t nodes_values, const std::string& post_transform, - py::array_t target_ids, - py::array_t target_nodeids, - py::array_t target_treeids, - py::array_t target_weights + py::array_t target_ids, + py::array_t target_nodeids, + py::array_t target_treeids, + py::array_t target_weights ); - py::array_t compute(py::array_t X) const; + py::array_t 
compute(py::array_t X) const; void ProcessTreeNode(NTYPE* predictions, int64_t treeindex, const NTYPE* x_data, int64_t feature_base, @@ -93,16 +93,17 @@ class RuntimeTreeEnsembleRegressor int omp_get_max_threads(); - py::array_t debug_threshold(py::array_t values) const; + py::array_t debug_threshold(py::array_t values) const; - py::array_t compute_tree_outputs(py::array_t values) const; + py::array_t compute_tree_outputs(py::array_t values) const; private: void Initialize(); void compute_gil_free(const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z) const; + const py::array_t& X, + py::array_t& Z) const; }; @@ -139,22 +140,22 @@ int RuntimeTreeEnsembleRegressor::omp_get_max_threads() { template void RuntimeTreeEnsembleRegressor::init( const std::string &aggregate_function, - py::array_t base_values, + py::array_t base_values, int64_t n_targets, - py::array_t nodes_falsenodeids, - py::array_t nodes_featureids, - py::array_t nodes_hitrates, - py::array_t nodes_missing_value_tracks_true, + py::array_t nodes_falsenodeids, + py::array_t nodes_featureids, + py::array_t nodes_hitrates, + py::array_t nodes_missing_value_tracks_true, const std::vector& nodes_modes, - py::array_t nodes_nodeids, - py::array_t nodes_treeids, - py::array_t nodes_truenodeids, - py::array_t nodes_values, + py::array_t nodes_nodeids, + py::array_t nodes_treeids, + py::array_t nodes_truenodeids, + py::array_t nodes_values, const std::string& post_transform, - py::array_t target_ids, - py::array_t target_nodeids, - py::array_t target_treeids, - py::array_t target_weights + py::array_t target_ids, + py::array_t target_nodeids, + py::array_t target_treeids, + py::array_t target_weights ) { aggregate_function_ = to_AGGREGATE_FUNCTION(aggregate_function); array2vector(base_values_, base_values, NTYPE); @@ -314,7 +315,7 @@ void RuntimeTreeEnsembleRegressor::Initialize() { template -py::array_t RuntimeTreeEnsembleRegressor::compute(py::array_t X) const { +py::array_t RuntimeTreeEnsembleRegressor::compute(py::array_t X) const { // const Tensor& X = *context->Input(0); // const TensorShape& x_shape = X.Shape(); std::vector x_dims; @@ -329,7 +330,7 @@ py::array_t RuntimeTreeEnsembleRegressor::compute(py::array_tOutput(0, TensorShape({N})); // auto* Z = context->Output(1, TensorShape({N, class_count_})); - py::array_t Z(x_dims[0] * n_targets_); + py::array_t Z(x_dims[0] * n_targets_); { py::gil_scoped_release release; @@ -339,12 +340,12 @@ py::array_t RuntimeTreeEnsembleRegressor::compute(py::array_t _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } -py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { +py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array_t& Z) { return Z.mutable_unchecked<1>(); } @@ -352,7 +353,8 @@ py::detail::unchecked_mutable_reference _mutable_unchecked1(py::array template void RuntimeTreeEnsembleRegressor::compute_gil_free( const std::vector& x_dims, int64_t N, int64_t stride, - const py::array_t& X, py::array_t& Z) const { + const py::array_t& X, + py::array_t& Z) const { // expected primary-expression before ')' token auto Z_ = _mutable_unchecked1(Z); // Z.mutable_unchecked<(size_t)1>(); @@ -662,7 +664,7 @@ void RuntimeTreeEnsembleRegressor::ProcessTreeNode( template -py::array_t RuntimeTreeEnsembleRegressor::debug_threshold(py::array_t values) const { +py::array_t RuntimeTreeEnsembleRegressor::debug_threshold(py::array_t 
values) const { std::vector result(values.size() * nodes_values_.size()); const NTYPE* x_data = values.data(0); const NTYPE* end = x_data + values.size(); @@ -687,7 +689,7 @@ py::array_t RuntimeTreeEnsembleRegressor::debug_threshold(py::array_ template -py::array_t RuntimeTreeEnsembleRegressor::compute_tree_outputs(py::array_t X) const { +py::array_t RuntimeTreeEnsembleRegressor::compute_tree_outputs(py::array_t X) const { std::vector x_dims; arrayshape2vector(x_dims, X); diff --git a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_p_.cpp b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_p_.cpp index 845e69b0d..fed1812d3 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_p_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor_p_.cpp @@ -13,25 +13,25 @@ class RuntimeTreeEnsembleRegressorP : public RuntimeTreeEnsembleCommonP { void init( const std::string &aggregate_function, - py::array_t base_values, + py::array_t base_values, int64_t n_targets, - py::array_t nodes_falsenodeids, - py::array_t nodes_featureids, - py::array_t nodes_hitrates, - py::array_t nodes_missing_value_tracks_true, + py::array_t nodes_falsenodeids, + py::array_t nodes_featureids, + py::array_t nodes_hitrates, + py::array_t nodes_missing_value_tracks_true, const std::vector& nodes_modes, - py::array_t nodes_nodeids, - py::array_t nodes_treeids, - py::array_t nodes_truenodeids, - py::array_t nodes_values, + py::array_t nodes_nodeids, + py::array_t nodes_treeids, + py::array_t nodes_truenodeids, + py::array_t nodes_values, const std::string& post_transform, - py::array_t target_ids, - py::array_t target_nodeids, - py::array_t target_treeids, - py::array_t target_weights); + py::array_t target_ids, + py::array_t target_nodeids, + py::array_t target_treeids, + py::array_t target_weights); - py::array_t compute(py::array_t X); - py::array_t compute_tree_outputs(py::array_t X); + py::array_t compute(py::array_t X); + py::array_t compute_tree_outputs(py::array_t X); }; @@ -50,22 +50,22 @@ RuntimeTreeEnsembleRegressorP::~RuntimeTreeEnsembleRegressorP() { template void RuntimeTreeEnsembleRegressorP::init( const std::string &aggregate_function, - py::array_t base_values, + py::array_t base_values, int64_t n_targets, - py::array_t nodes_falsenodeids, - py::array_t nodes_featureids, - py::array_t nodes_hitrates, - py::array_t nodes_missing_value_tracks_true, + py::array_t nodes_falsenodeids, + py::array_t nodes_featureids, + py::array_t nodes_hitrates, + py::array_t nodes_missing_value_tracks_true, const std::vector& nodes_modes, - py::array_t nodes_nodeids, - py::array_t nodes_treeids, - py::array_t nodes_truenodeids, - py::array_t nodes_values, + py::array_t nodes_nodeids, + py::array_t nodes_treeids, + py::array_t nodes_truenodeids, + py::array_t nodes_values, const std::string& post_transform, - py::array_t target_ids, - py::array_t target_nodeids, - py::array_t target_treeids, - py::array_t target_weights) { + py::array_t target_ids, + py::array_t target_nodeids, + py::array_t target_treeids, + py::array_t target_weights) { RuntimeTreeEnsembleCommonP::init( aggregate_function, base_values, n_targets, nodes_falsenodeids, nodes_featureids, nodes_hitrates, @@ -77,7 +77,8 @@ void RuntimeTreeEnsembleRegressorP::init( template -py::array_t RuntimeTreeEnsembleRegressorP::compute(py::array_t X) { +py::array_t RuntimeTreeEnsembleRegressorP::compute( + py::array_t X) { switch(this->aggregate_function_) { case AGGREGATE_FUNCTION::AVERAGE: return this->compute_agg(X, _AggregatorAverage( @@ -101,7 +102,8 @@ 
py::array_t RuntimeTreeEnsembleRegressorP::compute(py::array_t -py::array_t RuntimeTreeEnsembleRegressorP::compute_tree_outputs(py::array_t X) { +py::array_t RuntimeTreeEnsembleRegressorP::compute_tree_outputs( + py::array_t X) { switch(this->aggregate_function_) { case AGGREGATE_FUNCTION::AVERAGE: return this->compute_tree_outputs_agg(X, _AggregatorAverage( diff --git a/mlprodict/tools/graphs.py b/mlprodict/tools/graphs.py index 5a198f884..83bce64c1 100644 --- a/mlprodict/tools/graphs.py +++ b/mlprodict/tools/graphs.py @@ -10,6 +10,16 @@ import onnx +def make_hash_bytes(data, length=20): + """ + Creates a hash of length *length*. + """ + m = hashlib.sha256() + m.update(data) + res = m.hexdigest()[:length] + return res + + class AdjacencyGraphDisplay: """ Structure which contains the necessary information to @@ -146,6 +156,20 @@ def __init__(self, kind): def __repr__(self): return "A(%r)" % self.kind + class B: + "Additional information for a vertex or an edge." + + def __init__(self, name, content, onnx_name): + if not isinstance(content, str): + raise TypeError( # pragma: no cover + "content must be str not %r." % type(content)) + self.name = name + self.content = content + self.onnx_name = onnx_name + + def __repr__(self): + return "B(%r, %r, %r)" % (self.name, self.content, self.onnx_name) + def __init__(self, v0, v1, edges): """ :param v0: first set of vertices (dictionary) @@ -328,16 +352,227 @@ def adjust(c, way): return graph + def order(self): + """ + Order nodes. Depth first. + Returns a sequence of keys of mixed *v1*, *v2*. + """ + # Creates forwards nodes. + forwards = {} + backwards = {} + for k in self.v0: + forwards[k] = [] + backwards[k] = [] + for k in self.v1: + forwards[k] = [] + backwards[k] = [] + modif = True + while modif: + modif = False + for edge in self.edges: + a, b = edge + if b not in forwards[a]: + forwards[a].append(b) + modif = True + if a not in backwards[b]: + backwards[b].append(a) + modif = True + + # roots + roots = [b for b, backs in backwards.items() if len(backs) == 0] + if len(roots) == 0: + raise RuntimeError( # pragma: no cover + "This graph has cycles. Not allowed.") + + # ordering + order = {} + stack = roots + while len(stack) > 0: + node = stack.pop() + order[node] = len(order) + w = forwards[node] + if len(w) == 0: + continue + last = w.pop() + stack.append(last) -def onnx2bigraph(model_onnx, recursive=False): + return order + + def summarize(self): + """ + Creates a text summary of the graph. + """ + order = self.order() + keys = [(o, k) for k, o in order.items()] + keys.sort() + + rows = [] + for _, k in keys: + if k in self.v1: + rows.append(str(self.v1[k])) + return "\n".join(rows) + + @staticmethod + def _onnx2bigraph_basic(model_onnx, recursive=False): + """ + Implements graph type `'basic'` for function + @see fn onnx2bigraph. 
+ """ + + if recursive: + raise NotImplementedError( # pragma: no cover + "Option recursive=True is not implemented yet.") + v0 = {} + v1 = {} + edges = {} + + # inputs + for i, o in enumerate(model_onnx.graph.input): + v0[o.name] = BiGraph.A('Input-%d' % i) + for i, o in enumerate(model_onnx.graph.output): + v0[o.name] = BiGraph.A('Output-%d' % i) + for o in model_onnx.graph.initializer: + v0[o.name] = BiGraph.A('Init') + for n in model_onnx.graph.node: + nname = n.name if len(n.name) > 0 else "id%d" % id(n) + v1[nname] = BiGraph.A(n.op_type) + for i, o in enumerate(n.input): + c = str(i) if i < 10 else "+" + nname = n.name if len(n.name) > 0 else "id%d" % id(n) + edges[o, nname] = BiGraph.A('I%s' % c) + for i, o in enumerate(n.output): + c = str(i) if i < 10 else "+" + if o not in v0: + v0[o] = BiGraph.A('inout') + nname = n.name if len(n.name) > 0 else "id%d" % id(n) + edges[nname, o] = BiGraph.A('O%s' % c) + + return BiGraph(v0, v1, edges) + + @staticmethod + def _onnx2bigraph_simplified(model_onnx, recursive=False): + """ + Implements graph type `'simplified'` for function + @see fn onnx2bigraph. + """ + if recursive: + raise NotImplementedError( # pragma: no cover + "Option recursive=True is not implemented yet.") + v0 = {} + v1 = {} + edges = {} + + # inputs + for o in model_onnx.graph.input: + v0["I%d" % len(v0)] = BiGraph.B( + 'In', make_hash_bytes(o.type.SerializeToString(), 2), o.name) + for o in model_onnx.graph.output: + v0["O%d" % len(v0)] = BiGraph.B( + 'Ou', make_hash_bytes(o.type.SerializeToString(), 2), o.name) + for o in model_onnx.graph.initializer: + v0["C%d" % len(v0)] = BiGraph.B( + 'Cs', make_hash_bytes(o.raw_data, 10), o.name) + + names_v0 = {v.onnx_name: k for k, v in v0.items()} + + for n in model_onnx.graph.node: + key_node = "N%d" % len(v1) + if len(n.attribute) > 0: + ats = [] + for at in n.attribute: + ats.append(at.SerializeToString()) + ct = make_hash_bytes(b"".join(ats), 10) + else: + ct = "" + v1[key_node] = BiGraph.B( + n.op_type, ct, n.name) + for o in n.input: + key_in = names_v0[o] + edges[key_in, key_node] = BiGraph.A('I') + for o in n.output: + if o not in names_v0: + key = "R%d" % len(v0) + v0[key] = BiGraph.B('Re', n.op_type, o) + names_v0[o] = key + edges[key_node, key] = BiGraph.A('O') + + return BiGraph(v0, v1, edges) + + @staticmethod + def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print): + """ + Computes a distance between two ONNX graphs. They must not + be too big otherwise this function might take for ever. + The function relies on package :epkg:`mlstatpy`. + + :param onx1: first graph (ONNX graph or model file name) + :param onx2: second graph (ONNX graph or model file name) + :param verbose: verbosity + :param fLOG: logging function + :return: distance and differences + + .. warning:: + + This is very experimental and very slow. + + .. 
versionadded:: 0.7 + """ + from mlstatpy.graph.graph_distance import GraphDistance + + if isinstance(onx1, str): + onx1 = onnx.load(onx1) + if isinstance(onx2, str): + onx2 = onnx.load(onx2) + + def make_hash(init): + return make_hash_bytes(init.raw_data) + + def build_graph(onx): + edges = [] + labels = {} + for node in onx.graph.node: + if len(node.name) == 0: + name = str(id(node)) + else: + name = node.name + for i in node.input: + edges.append((i, name)) + for p, i in enumerate(node.output): + edges.append((name, i)) + labels[i] = "%s:%d" % (node.op_type, p) + labels[name] = node.op_type + for init in onx.graph.initializer: + labels[init.name] = make_hash(init) + + g = GraphDistance(edges, vertex_label=labels) + return g + + g1 = build_graph(onx1) + g2 = build_graph(onx2) + + dist, gdist = g1.distance_matching_graphs_paths( + g2, verbose=verbose, fLOG=fLOG, use_min=False) + return dist, gdist + + +def onnx2bigraph(model_onnx, recursive=False, graph_type='basic'): """ Converts an ONNX graph into a graph representation, edges, vertices. :param model_onnx: ONNX graph :param recursive: dig into subgraphs too + :param graph_type: kind of graph it creates :return: see @cl BiGraph + About *graph_type*: + + * `'basic'`: basic graph structure, it returns an instance + of type @see cl BiGraph. The structure keeps the original + names. + * `'simplified'`: simplifed graph structure, names are removed + as they could be prevent the algorithm to find any matching. + .. exref:: :title: Displays an ONNX graph as text @@ -369,35 +604,14 @@ def onnx2bigraph(model_onnx, recursive=False): .. versionadded:: 0.7 """ - if recursive: - raise NotImplementedError( # pragma: no cover - "Option recursive=True is not implemented yet.") - v0 = {} - v1 = {} - edges = {} - - # inputs - for i, o in enumerate(model_onnx.graph.input): - v0[o.name] = BiGraph.A('Input-%d' % i) - for i, o in enumerate(model_onnx.graph.output): - v0[o.name] = BiGraph.A('Output-%d' % i) - for o in model_onnx.graph.initializer: - v0[o.name] = BiGraph.A('Init') - for n in model_onnx.graph.node: - nname = n.name if len(n.name) > 0 else "id%d" % id(n) - v1[nname] = BiGraph.A(n.op_type) - for i, o in enumerate(n.input): - c = str(i) if i < 10 else "+" - nname = n.name if len(n.name) > 0 else "id%d" % id(n) - edges[o, nname] = BiGraph.A('I%s' % c) - for i, o in enumerate(n.output): - c = str(i) if i < 10 else "+" - if o not in v0: - v0[o] = BiGraph.A('inout') - nname = n.name if len(n.name) > 0 else "id%d" % id(n) - edges[nname, o] = BiGraph.A('O%s' % c) - - return BiGraph(v0, v1, edges) + if graph_type == 'basic': + return BiGraph._onnx2bigraph_basic( + model_onnx, recursive=recursive) + if graph_type == 'simplified': + return BiGraph._onnx2bigraph_simplified( + model_onnx, recursive=recursive) + raise ValueError( + "Unknown value for graph_type=%r." % graph_type) def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print): @@ -406,50 +620,16 @@ def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print): be too big otherwise this function might take for ever. The function relies on package :epkg:`mlstatpy`. - :param onx1: first graph (ONNX graph or model file name) :param onx2: second graph (ONNX graph or model file name) :param verbose: verbosity :param fLOG: logging function :return: distance and differences + .. warning:: + + This is very experimental and very slow. + .. 
versionadded:: 0.7 """ - from mlstatpy.graph.graph_distance import GraphDistance - - if isinstance(onx1, str): - onx1 = onnx.load(onx1) - if isinstance(onx2, str): - onx2 = onnx.load(onx2) - - def make_hash(init): - m = hashlib.sha256() - m.update(init.raw_data) - return m.hexdigest()[:20] - - def build_graph(onx): - edges = [] - labels = {} - for node in onx.graph.node: - if len(node.name) == 0: - name = str(id(node)) - else: - name = node.name - for i in node.input: - edges.append((i, name)) - for p, i in enumerate(node.output): - edges.append((name, i)) - labels[i] = "%s:%d" % (node.op_type, p) - labels[name] = node.op_type - for init in onx.graph.initializer: - labels[init.name] = make_hash(init) - - g = GraphDistance(edges, vertex_label=labels) - return g - - g1 = build_graph(onx1) - g2 = build_graph(onx2) - - dist, gdist = g1.distance_matching_graphs_paths( - g2, verbose=verbose, fLOG=fLOG, use_min=False) - return dist, gdist + return BiGraph.onnx_graph_distance(onx1, onx2, verbose=verbose, fLOG=fLOG) diff --git a/requirements.txt b/requirements.txt index aca6059db..433c5e2a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,4 +48,4 @@ xgboost # onnx onnx==1.9.0 onnxruntime>=1.8.0 -skl2onnx>=1.9.1 +skl2onnx>=1.9.2 diff --git a/setup.py b/setup.py index e201abb3e..64fbd3aea 100644 --- a/setup.py +++ b/setup.py @@ -383,14 +383,14 @@ def get_extensions(): install_requires=["pybind11", "numpy>=1.17", "onnx>=1.7.0", 'scipy>=1.0.0', 'jinja2', 'cython'], extras_require={ - 'npy': ['scikit-learn>=0.24', 'skl2onnx>=1.9'], - 'onnx_conv': ['scikit-learn>=0.24', 'skl2onnx>=1.9.1', + 'npy': ['scikit-learn>=0.24', 'skl2onnx>=1.9.2'], + 'onnx_conv': ['scikit-learn>=0.24', 'skl2onnx>=1.9.2', 'joblib', 'threadpoolctl', 'mlinsights>=0.3', 'lightgbm', 'xgboost'], - 'onnx_val': ['scikit-learn>=0.24', 'skl2onnx>=1.9.1', + 'onnx_val': ['scikit-learn>=0.24', 'skl2onnx>=1.9.2', 'onnxruntime>=1.6.0', 'joblib', 'threadpoolctl'], 'sklapi': ['scikit-learn>=0.24', 'joblib', 'threadpoolctl'], - 'all': ['scikit-learn>=0.24', 'skl2onnx>=1.9.1', + 'all': ['scikit-learn>=0.24', 'skl2onnx>=1.9.2', 'onnxruntime>=1.6.0', 'scipy' 'joblib', 'pandas', 'threadpoolctl', 'mlinsights>=0.3', 'lightgbm', 'xgboost', 'mlstatpy>=0.3.593'],