\n",
- "0 | 0 | def dft_real(x, fft_length=None, transpose=True): | def dft_real_d3(x, fft_length=None, transpose=True): |
\n",
- "1 | 1 | if len(x.shape) == 1: | if len(x.shape) != 3: |
\n",
- "2 | 2 | x = x.reshape((1, -1)) | raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape) |
\n",
- "3 | | N = 1 | |
\n",
- "4 | | else: | |
\n",
- "5 | 3 | N = x.shape[0] | N = x.shape[1] |
\n",
- "6 | 4 | C = x.shape[-1] if transpose else x.shape[-2] | C = x.shape[-1] if transpose else x.shape[-2] |
\n",
- "7 | 5 | if fft_length is None: | if fft_length is None: |
\n",
- "8 | 6 | fft_length = x.shape[-1] | fft_length = x.shape[-1] |
\n",
- "9 | 7 | size = fft_length // 2 + 1 | size = fft_length // 2 + 1 |
\n",
- "10 | 8 | | |
\n",
- "11 | 9 | cst = dft_real_cst(C, fft_length) | cst = dft_real_cst(C, fft_length) |
\n",
- "12 | 10 | if transpose: | if transpose: |
\n",
- "13 | 11 | x = numpy.transpose(x, (1, 0)) | x = numpy.transpose(x, (0, 2, 1)) |
\n",
- "14 | 12 | a = cst[:, :, :fft_length] | a = cst[:, :, :fft_length] |
\n",
- "15 | 13 | b = x[:fft_length] | b = x[:, :fft_length, :] |
\n",
- " | 14 | | a = numpy.expand_dims(a, 0) |
\n",
- " | 15 | | b = numpy.expand_dims(b, 1) |
\n",
- "16 | 16 | res = numpy.matmul(a, b) | res = numpy.matmul(a, b) |
\n",
- "17 | 17 | res = res[:, :size, :] | res = res[:, :, :size, :] |
\n",
- "18 | 18 | return numpy.transpose(res, (0, 2, 1)) | return numpy.transpose(res, (1, 0, 3, 2)) |
\n",
- "19 | 19 | else: | else: |
\n",
- "20 | 20 | a = cst[:, :, :fft_length] | a = cst[:, :, :fft_length] |
\n",
- "21 | 21 | b = x[:fft_length] | b = x[:, :fft_length, :] |
\n",
- " | 22 | | a = numpy.expand_dims(a, 0) |
\n",
- " | 23 | | b = numpy.expand_dims(b, 1) |
\n",
- "22 | 24 | return numpy.matmul(a, b) | res = numpy.matmul(a, b) |
\n",
- " | 25 | | return numpy.transpose(res, (1, 0, 2, 3)) |
\n",
- "23 | 26 | | |
\n",
- "
"
+ "cell_type": "markdown",
+ "id": "caee1f84",
+ "metadata": {},
+ "source": [
+ "## FFT 2D in ONNX\n",
+ "\n",
+ "We use again the numpy API for ONNX."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "ca641274",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def onnx_rfft_1d(x, fft_length=None, transpose=True):\n",
+ " if fft_length is None:\n",
+ " raise RuntimeError(\"fft_length must be specified.\")\n",
+ " \n",
+ " size = fft_length // 2 + 1\n",
+ " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n",
+ " if transpose:\n",
+ " xt = npnx.transpose(x, (1, 0))\n",
+ " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n",
+ " return npnx.transpose(res, (0, 2, 1))\n",
+ " else:\n",
+ " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n",
+ "\n",
+ "\n",
+ "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n",
+ "def onnx_rfft_2d(x, fft_length=None):\n",
+ " mat = x[:fft_length[0], :fft_length[1]]\n",
+ " \n",
+ " # first FFT\n",
+ " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n",
+ " \n",
+ " # second FFT decomposed on FFT on real part and imaginary part\n",
+ " res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n",
+ " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n",
+ " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n",
+ " res = res2_real + res2_imag2\n",
+ " size = fft_length[1]//2 + 1\n",
+ " return res[:, :fft_length[0], :size]\n",
+ "\n",
+ "\n",
+ "fft2d_cus = fft2d(rnd, rnd.shape)\n",
+ "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n",
+ "almost_equal(fft2d_cus, fft2d_onx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20fcd8a9",
+ "metadata": {},
+ "source": [
+ "The corresponding ONNX graph."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "b1379b06",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "