# ONNX and FFT

ONNX does not fully support complex numbers yet, and it does not have any FFT operators either. What if we need them anyway?

In [1]:
from jyquickhelper import add_notebook_menu
add_notebook_menu()

In [2]:
%load_ext mlprodict

## Python implementation of RFFT

We try to replicate [numpy.fft.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html).

In [3]:
import numpy


def almost_equal(a, b, error=1e-5):
    """
    Checks that two arrays are equal up to an absolute tolerance.

    One or both arrays may be complex. A complex array is changed into
    a real array with a new first dimension: ``[0, ...]`` holds the real
    part, ``[1, ...]`` holds the imaginary part. When only one array is
    complex, the other one is expected to already follow that layout.

    :param a: first array, real or complex
    :param b: second array, real or complex
    :param error: maximum allowed absolute difference between two
        coefficients
    :raises AssertionError: if the shapes differ or if the maximum
        absolute difference is greater than *error*
    """
    complex_types = (numpy.complex64, numpy.complex128)

    def split(z, dtype):
        # Stacks real and imaginary parts along a new first axis.
        parts = numpy.empty((2,) + z.shape, dtype=dtype)
        parts[0] = numpy.real(z)
        parts[1] = numpy.imag(z)
        return parts

    a_is_complex = a.dtype in complex_types
    b_is_complex = b.dtype in complex_types
    if a_is_complex and b_is_complex:
        # Both sides are split the same way, no argument swapping needed.
        a, b = split(a, numpy.float64), split(b, numpy.float64)
    elif a_is_complex:
        a = split(a, b.dtype)
    elif b_is_complex:
        b = split(b, a.dtype)

    if a.shape != b.shape:
        raise AssertionError("Shape mismatch %r != %r." % (a.shape, b.shape))
    if a.size == 0:
        # Nothing to compare, and max() fails on an empty array.
        return
    diff = numpy.abs(a.ravel() - b.ravel()).max()
    if diff > error:
        raise AssertionError("Mismatch max diff=%r > %r." % (diff, error))


def dft_real_cst(N, fft_length):
    """
    Builds the square DFT matrix as a real tensor.

    The complex matrix has coefficients
    ``exp(-2j * pi * k * n / fft_length)`` for ``k, n`` in ``range(N)``.

    :param N: number of rows and columns of the matrix
    :param fft_length: denominator used in the exponent
    :return: float64 array of shape ``(2, N, N)``, index 0 on the first
        axis is the real part, index 1 is the imaginary part
    """
    cols = numpy.arange(N)
    rows = cols.reshape((N, 1)).astype(numpy.float64)
    dft = numpy.exp(-2j * numpy.pi * rows * cols / fft_length)
    return numpy.stack([dft.real, dft.imag])


def dft_real(x, fft_length=None):
    """
    Computes the discrete Fourier transform of real signals with a
    matrix product, keeping only the first ``fft_length // 2 + 1``
    coefficients.

    :param x: 1D array (one signal) or 2D array (one signal per row)
    :param fft_length: length of the transform, the signal is cropped
        if it is longer and implicitly padded with zeros if it is
        shorter; defaults to the signal length
    :return: real array of shape ``(2, batch, fft_length // 2 + 1)``,
        index 0 on the first axis is the real part, index 1 is the
        imaginary part; *batch* is 1 for a 1D input
    """
    # Puts the signal axis first: (n_samples, batch).
    if len(x.shape) == 1:
        x = x.reshape((-1, 1))
    else:
        x = x.T
    n_samples = x.shape[0]
    if fft_length is None:
        fft_length = n_samples
    size = fft_length // 2 + 1
    # Only the first `used` samples contribute: a longer signal is
    # cropped, the zeros padding a shorter one add nothing to the sum.
    used = min(n_samples, fft_length)
    # dft_real_cst returns a square matrix: it must be large enough to
    # provide `size` rows even when fft_length > n_samples.
    cst = dft_real_cst(max(n_samples, fft_length), fft_length)
    res = numpy.dot(cst[:, :size, :used], x[:used])
    return numpy.transpose(res, (0, 2, 1))


# Random float32 matrix with 3 rows and 4 columns.
# NOTE(review): no random seed is set, the values change at every run
# and will not match the output displayed below.
rnd = numpy.random.randn(3, 4).astype(numpy.float32)
# Reference result from numpy and result of the custom implementation.
fft_np = numpy.fft.rfft(rnd)
fft_cus = dft_real(rnd)
fft_np

array([[-1.95968306e+00+0.j        ,  2.80051970e+00+1.72947609j,
         3.50391865e-03+0.j        ],
       [ 4.89132166e-01+0.j        ,  1.36114281e+00+0.86927813j,
         5.75416410e+00+0.j        ],
       [-2.04808140e+00+0.j        ,  2.23222792e-01+0.40871835j,
        -1.47002804e+00+0.j        ]])

Function `almost_equal` verifies both functions return the same results.

In [4]:
almost_equal(fft_np, fft_cus)

Let's do the same with `fft_length < shape[1]`.

In [5]:
# Same comparison with a transform length (3) smaller than the number
# of columns of rnd (4).
fft_np3 = numpy.fft.rfft(rnd, n=3)
fft_cus3 = dft_real(rnd, fft_length=3)
fft_np3

array([[-2.33362436+0.j        ,  2.53363478-0.46225825j],
       [ 1.37075108+0.j        ,  2.67671767+2.27864249j],
       [-2.10792723+0.j        , -0.09791033-0.55621888j]])

In [6]:
almost_equal(fft_np3, fft_cus3)

## What about ONNX

Let's first assume the number of columns of the input matrix is fixed. The result of function `dft_real_cst` can then be considered as a constant.

In [7]:
from typing import Any
import mlprodict.npy.numpy_onnx_impl as npnx
from mlprodict.npy import onnxnumpy_np
from mlprodict.npy.onnx_numpy_annotation import NDArrayType
# from mlprodict.onnxrt import OnnxInference

@onnxnumpy_np(signature=NDArrayType(("T:all", ), dtypes_out=('T',)))
def onnx_rfft(x, fft_length=None):
    """
    Counterpart of ``dft_real`` written with the functions of
    ``mlprodict.npy.numpy_onnx_impl`` (imported as ``npnx``).

    :param x: input matrix, it is transposed with permutation ``(1, 0)``
        so it is presumably 2D with one signal per row -- TODO confirm
    :param fft_length: length of the transform, mandatory even though
        it defaults to None
    :return: result of the final transpose with permutation
        ``(0, 2, 1)``; following ``dft_real_cst``, index 0 on the first
        axis is the real part and index 1 the imaginary part
    :raises RuntimeError: if *fft_length* is None
    """
    if fft_length is None:
        raise RuntimeError("fft_length must be specified.")
    
    # Number of coefficients kept on the frequency axis.
    size = fft_length // 2 + 1
    # The DFT matrix only depends on fft_length: it is computed with
    # numpy before being given to npnx.dot.
    # NOTE(review): it is always cast to float32 while the signature
    # declares "T:all" -- confirm inputs of another type are supported.
    cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)
    xt = npnx.transpose(x, (1, 0))
    # cst is already (2, fft_length, fft_length): the slice on its last
    # axis keeps everything. xt is cropped to its first fft_length rows.
    res = npnx.dot(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]
    return npnx.transpose(res, (0, 2, 1))

# Runs the ONNX implementation with fft_length equal to the number of
# columns of rnd.
fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])
fft_onx

RuntimeError: module compiled against API version 0xe but this version of numpy is 0xd

SystemError: <built-in function __import__> returned a result with an error set

array([[[-1.9596831e+00,  2.8005197e+00,  3.5039186e-03],
        [ 4.8913223e-01,  1.3611429e+00,  5.7541642e+00],
        [-2.0480814e+00,  2.2322279e-01, -1.4700280e+00]],

       [[ 0.0000000e+00,  1.7294761e+00, -4.3412485e-16],
        [ 0.0000000e+00,  8.6927813e-01,  7.5392428e-16],
        [ 0.0000000e+00,  4.0871835e-01, -2.2202143e-16]]], dtype=float32)

In [8]:
almost_equal(fft_cus, fft_onx)

The corresponding ONNX graph is the following:

In [9]:
key = list(onnx_rfft.signed_compiled)[0]
%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_

In [10]:
# Same check with fft_length=3, compared to the result of dft_real
# computed earlier with the same length.
fft_onx3 = onnx_rfft(rnd, fft_length=3)
almost_equal(fft_cus3, fft_onx3)

## FFT 2D

In [11]:
def _DFT_cst(N, fft_length, trunc=True):
    n = numpy.arange(N)
    k = n.reshape((N, 1)).astype(numpy.float64)
    M = numpy.exp(-2j * numpy.pi * k * n / fft_length)
    return M[:fft_length // 2 + 1] if trunc else M

def DFT(x, fft_length=None, axis=1):
    """
    Computes a discrete Fourier transform along one axis of a matrix
    with a matrix product.

    With ``axis=1`` the transform is applied to every row and only the
    first ``fft_length // 2 + 1`` coefficients are kept. With any other
    value the transform is applied to every column and all
    ``fft_length`` coefficients are kept.

    :param x: 2D array
    :param fft_length: length of the transform, the data is cropped
        along the transformed axis if it is longer and implicitly
        padded with zeros if it is shorter; defaults to the length of
        the transformed axis
    :param axis: 1 to transform the rows, otherwise the columns are
        transformed
    :return: complex array
    """
    along_rows = axis == 1
    if along_rows:
        # Puts the transformed axis first.
        x = x.T
    n_samples = x.shape[0]
    if fft_length is None:
        fft_length = n_samples
    # Only the first `used` samples contribute: longer data is cropped,
    # the zeros padding shorter data add nothing to the sum.
    used = min(n_samples, fft_length)
    # _DFT_cst builds a square matrix before truncation: it must be
    # large enough to provide every requested output row.
    cst = _DFT_cst(max(n_samples, fft_length), fft_length, trunc=along_rows)
    cst = cst[:fft_length, :used]
    res = numpy.dot(cst, x[:used])
    return res.T if along_rows else res

def fft2d(mat, fft_length):
    """
    Computes a 2D discrete Fourier transform of a real matrix by
    calling ``DFT`` with ``axis=1`` and then with ``axis=0``.

    :param mat: 2D array
    :param fft_length: sequence of two integers, transform length along
        axis 0 and along axis 1; the matrix is cropped to these
        dimensions first
    :return: array restricted to ``fft_length[0]`` rows and
        ``fft_length[1] // 2 + 1`` columns
    """
    n_rows = fft_length[0]
    n_cols = fft_length[1]
    cropped = mat[:n_rows, :n_cols].copy()
    step_axis1 = DFT(cropped, n_cols, axis=1)
    step_axis0 = DFT(step_axis1, n_rows, axis=0)
    return step_axis0[:n_rows, :n_cols // 2 + 1]


# Reference 2D result from numpy and result of the custom
# implementation on the full matrix.
fft2d_np = numpy.fft.rfft2(rnd)
fft2d_cus = fft2d(rnd, rnd.shape)

In [12]:
fft2d_np

array([[-3.51863229+0.j        ,  4.38488531+3.00747257j,
         4.28763998+0.j        ],
       [-1.18020844-2.1972914j ,  2.40719338+0.10501021j,
        -2.13856411-6.25633392j],
       [-1.18020844+2.1972914j ,  1.60948043+2.0759455j ,
        -2.13856411+6.25633392j]])

In [13]:
fft2d_cus

array([[-3.51863229+0.00000000e+00j,  4.38488531+3.00747257e+00j,
         4.28763998+9.77780322e-17j],
       [-1.18020844-2.19729140e+00j,  2.40719338+1.05010207e-01j,
        -2.13856411-6.25633392e+00j],
       [-1.18020844+2.19729140e+00j,  1.60948043+2.07594550e+00j,
        -2.13856411+6.25633392e+00j]])