# ONNX and FFT

ONNX does not fully support complex numbers yet. It does not have any FFT operators either. What if we need them anyway?

In [1]:
from jyquickhelper import add_notebook_menu
add_notebook_menu()

In [2]:
%load_ext mlprodict

In [3]:
import numpy
numpy.__version__

'1.21.0'

## Python implementation of RFFT

We try to replicate [numpy.fft.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html).

In [4]:
import numpy


def almost_equal(a, b, error=1e-5):
    """
    Checks that two arrays are equal up to an absolute tolerance.

    One or both arrays may be complex. A complex array is first changed
    into a real array with a new first dimension of size 2:
    ``[0, ...]`` holds the real part, ``[1, ...]`` the imaginary part.
    It can then be compared to an array already stored that way.

    :param a: first array
    :param b: second array
    :param error: maximum absolute difference allowed between two
        coefficients
    :raises AssertionError: if the shapes differ, if the maximum absolute
        difference exceeds *error*, or if that difference is NaN
    """
    complex_types = (numpy.complex64, numpy.complex128)
    if a.dtype in complex_types:
        dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32
        new_a = numpy.empty((2,) + a.shape, dtype=dtype)
        new_a[0] = numpy.real(a)
        new_a[1] = numpy.imag(a)
        return almost_equal(new_a, b, error)
    if b.dtype in complex_types:
        # Swapping the arguments reuses the conversion branch above.
        return almost_equal(b, a, error)
    if a.shape != b.shape:
        raise AssertionError("Shape mismatch %r != %r." % (a.shape, b.shape))
    if a.size == 0:
        # Nothing to compare, and max() fails on an empty array.
        return None
    diff = numpy.abs(a.ravel() - b.ravel()).max()
    # Written as `not diff <= error` so that a NaN difference fails
    # the check instead of silently passing (`nan > error` is False).
    if not diff <= error:
        raise AssertionError("Mismatch max diff=%r > %r." % (diff, error))
    return None


def dft_real_cst(N, fft_length):
    """
    Builds the DFT coefficients as a real array.

    The complex matrix ``exp(-2j * pi * k * n / fft_length)`` is computed
    for ``k`` and ``n`` in ``[0, N)`` and split into two real matrices.

    :param N: number of rows and columns of the DFT matrix
    :param fft_length: length of the transform, the denominator of
        the exponent
    :return: float64 array of shape ``(2, N, N)``, ``[0]`` is the real
        part, ``[1]`` is the imaginary part
    """
    cols = numpy.arange(N)
    rows = cols.reshape((N, 1)).astype(numpy.float64)
    twiddles = numpy.exp(-2j * numpy.pi * rows * cols / fft_length)
    return numpy.stack([numpy.real(twiddles), numpy.imag(twiddles)])


def dft_real(x, fft_length=None, transpose=True):
    """
    Computes a discrete Fourier transform of a real matrix using only
    real numbers.

    The result has a new first dimension of size 2: ``[0]`` is the real
    part, ``[1]`` is the imaginary part.

    :param x: real vector or matrix, a vector is handled as a matrix
        with a single row
    :param fft_length: length of the transform, by default the length of
        the axis being transformed (the last axis if *transpose* is True,
        the first axis of the matrix otherwise)
    :param transpose: if True, every row is transformed and only the first
        ``fft_length // 2 + 1`` frequencies are kept (as *rfft* does);
        if False, every column is transformed and no frequency is dropped
    :return: array of shape ``(2, rows, frequencies)`` if *transpose*
        is True, ``(2, frequencies, columns)`` otherwise

    The DFT matrix is square with as many rows as the transformed axis,
    so the number of returned frequencies never exceeds that length.
    """
    if len(x.shape) == 1:
        x = x.reshape((1, -1))
    # Length of the axis the transform runs along.
    C = x.shape[-1] if transpose else x.shape[-2]
    if fft_length is None:
        # The default must follow the transformed axis; it used to be
        # x.shape[-1] even when transpose was False.
        fft_length = C
    size = fft_length // 2 + 1

    cst = dft_real_cst(C, fft_length)
    if transpose:
        # Rows become columns so that a single dot transforms all of them.
        x = numpy.transpose(x, (1, 0))
        res = numpy.dot(cst[:, :, :fft_length], x[:fft_length])[:, :size, :]
        return numpy.transpose(res, (0, 2, 1))
    return numpy.dot(cst[:, :, :fft_length], x[:fft_length])


# Random 5x7 float32 matrix used as input.
rnd = numpy.random.randn(5, 7).astype(numpy.float32)
# Reference result from numpy and result of the custom implementation.
fft_np = numpy.fft.rfft(rnd)
fft_cus = dft_real(rnd)
fft_np

array([[ 0.92935219+0.j        ,  1.1166406 +0.18610885j,
         2.98881347-0.86137828j,  0.57062752-3.17075076j],
       [-0.81071034+0.j        ,  4.04571912+1.34415298j,
        -0.75316593+1.87375117j, -3.73972034+1.19963451j],
       [ 0.49893169+0.j        , -2.38853745+0.91784964j,
        -2.3230939 +2.42467461j,  2.84973582+0.96874118j],
       [-0.85518897+0.j        , -1.07457921+2.14618057j,
         0.67522719-2.17320735j,  1.31480887+2.2782433j ],
       [ 2.80867666+0.j        , -2.79453396-2.22901834j,
         0.492986  +0.10661537j,  2.65317564+0.57651319j]])

Function `almost_equal` verifies both functions return the same results.

In [5]:
almost_equal(fft_np, fft_cus)

Let's do the same with `fft_length < shape[1]`.

In [6]:
# Same comparison with an FFT length (3) smaller than the number of columns (7).
fft_np3 = numpy.fft.rfft(rnd, n=3)
fft_cus3 = dft_real(rnd, fft_length=3)
fft_np3

array([[ 0.58212829+0.j        ,  1.91211772-1.78320393j],
       [-0.3185378 +0.j        , -0.20609781-1.18129868j],
       [-0.81120646+0.j        , -0.28543806+3.05769342j],
       [-1.06384408+0.j        ,  0.74100591+0.43276681j],
       [ 1.77509081+0.j        , -0.13498855+1.82011058j]])

In [7]:
almost_equal(fft_np3, fft_cus3)

## RFFT in ONNX

Let's first assume the number of columns of the input matrix is fixed. The result of function `dft_real_cst` can then be considered as a constant.

In [8]:
from typing import Any
import mlprodict.npy.numpy_onnx_impl as npnx
from mlprodict.npy import onnxnumpy_np
from mlprodict.npy.onnx_numpy_annotation import NDArrayType
# from mlprodict.onnxrt import OnnxInference

@onnxnumpy_np(signature=NDArrayType(("T:all", ), dtypes_out=('T',)))
def onnx_rfft(x, fft_length=None):
    """
    RFFT of every row of *x* written with the numpy API for ONNX.

    The DFT coefficients are computed with numpy by `dft_real_cst`
    and cast to float32 before being used in the graph.

    :param x: input matrix
    :param fft_length: length of the transform, it must be specified
    :return: result with a new first dimension of size 2,
        ``[0]`` is the real part, ``[1]`` is the imaginary part
    :raises RuntimeError: if *fft_length* is None
    """
    if fft_length is None:
        raise RuntimeError("fft_length must be specified.")

    n_freq = fft_length // 2 + 1
    weights = dft_real_cst(fft_length, fft_length).astype(numpy.float32)
    columns_first = npnx.transpose(x, (1, 0))
    product = npnx.dot(weights[:, :, :fft_length], columns_first[:fft_length])
    # Only the first fft_length // 2 + 1 frequencies are kept.
    kept = product[:, :n_freq, :]
    return npnx.transpose(kept, (0, 2, 1))

# Runs the ONNX version with an FFT length equal to the number of columns.
fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])
fft_onx

array([[[ 0.92935216,  1.1166406 ,  2.9888134 ,  0.5706275 ],
        [-0.81071043,  4.045719  , -0.7531659 , -3.7397203 ],
        [ 0.4989317 , -2.3885374 , -2.323094  ,  2.849736  ],
        [-0.85518885, -1.0745792 ,  0.6752271 ,  1.3148088 ],
        [ 2.8086765 , -2.794534  ,  0.49298596,  2.6531756 ]],

       [[ 0.        ,  0.18610872, -0.8613782 , -3.1707506 ],
        [ 0.        ,  1.3441529 ,  1.8737512 ,  1.1996344 ],
        [ 0.        ,  0.9178499 ,  2.4246747 ,  0.96874106],
        [ 0.        ,  2.1461806 , -2.1732073 ,  2.2782433 ],
        [ 0.        , -2.2290184 ,  0.10661539,  0.5765133 ]]],
      dtype=float32)

In [9]:
almost_equal(fft_cus, fft_onx)

The corresponding ONNX graph is the following:

In [10]:
key = list(onnx_rfft.signed_compiled)[0]
%onnxview onnx_rfft.signed_compiled[key].compiled.onnx_

In [11]:
# Same check with an FFT length of 3, compared to the custom numpy version.
fft_onx3 = onnx_rfft(rnd, fft_length=3)
almost_equal(fft_cus3, fft_onx3)

## FFT 2D

Below is the code written with complex numbers.

In [12]:
def _DFT_cst(N, fft_length, trunc=True):
    n = numpy.arange(N)
    k = n.reshape((N, 1)).astype(numpy.float64)
    M = numpy.exp(-2j * numpy.pi * k * n / fft_length)
    return M[:fft_length // 2 + 1] if trunc else M

def DFT(x, fft_length=None, axis=1):
    """
    Computes a one-dimensional DFT of matrix *x* with complex numbers.

    :param x: input matrix
    :param fft_length: length of the transform, by default the length
        of the transformed axis
    :param axis: if 1, every row is transformed and only the first
        ``fft_length // 2 + 1`` frequencies are kept; any other value
        transforms every column and keeps all frequencies
    :return: complex matrix
    """
    along_rows = axis == 1
    # The DFT matrix is applied on the left, so rows are moved to columns.
    data = x.T if along_rows else x
    n_points = data.shape[0]
    if fft_length is None:
        fft_length = n_points
    weights = _DFT_cst(n_points, fft_length, trunc=along_rows)
    out = numpy.dot(weights, data)
    return out.T if along_rows else out

def fft2d_(mat, fft_length):
    """
    Computes a two-dimensional RFFT with complex numbers as two
    successive one-dimensional DFT, first along axis 1, then along axis 0.

    :param mat: input matrix
    :param fft_length: pair ``(length along axis 0, length along axis 1)``,
        the matrix is cropped to these lengths before the computation
    :return: complex matrix with at most ``fft_length[0]`` rows and
        ``fft_length[1] // 2 + 1`` columns
    """
    n_rows, n_cols = fft_length[0], fft_length[1]
    cropped = mat[:n_rows, :n_cols].copy()
    along_rows = DFT(cropped, n_cols, axis=1)
    along_cols = DFT(along_rows, n_rows, axis=0)
    return along_cols[:n_rows, :n_cols // 2 + 1]


# New random 5x7 float32 matrix for the 2D experiments.
rnd = numpy.random.randn(5, 7).astype(numpy.float32)
# Custom complex implementation and reference result from numpy.
fft2d_np_ = fft2d_(rnd, rnd.shape)
fft2d_np = numpy.fft.rfft2(rnd)
fft2d_np_

array([[-2.036582  +0.j        , -0.85992725+6.47780438j,
        -3.99332006-3.11192536j, -1.32368431-2.48821071j],
       [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,
        -3.75306979+0.66651699j, -1.56716114+4.75028368j],
       [ 2.76767016+1.25297955j,  5.07926144-2.23393831j,
         2.41908275-8.55451105j, -8.84556476-1.29356088j],
       [ 2.76767016-1.25297955j,  2.41782872-4.44962381j,
        -3.6501426 +4.13120322j,  4.30875103+0.96179243j],
       [ 4.37345155+7.03173815j,  1.77135529+4.4385736j ,
         2.40878105+5.40109054j, -1.65462983+0.2149866j ]])

In [13]:
almost_equal(fft2d_np_, fft2d_np)

It implies the computation of two FFT 1D along both axes. However, as ONNX does not support complex, it needs to be rewritten with only real numbers. The algorithm can be summarized into this formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. We still assume *x* is real, it then becomes (FFT is a linear operator, so $FFT(ix)=i FFT(x)$):

* $y = FFT(x, axis=1)$
* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$
* $z = z_r + i z_i$

*z* is the desired output. The following implementation is probably not the most efficient one. It avoids in-place computation as ONNX does not like that.

In [14]:
def fft2d(mat, fft_length):
    """
    Computes a two-dimensional RFFT of a real matrix using only
    real numbers.

    :param mat: real input matrix
    :param fft_length: pair ``(length along axis 0, length along axis 1)``,
        the matrix is cropped to these lengths before the computation
    :return: array with a first dimension of size 2,
        ``[0]`` is the real part, ``[1]`` is the imaginary part
    """
    n_rows, n_cols = fft_length[0], fft_length[1]
    cropped = mat[:n_rows, :n_cols].copy()

    # First FFT, along the rows.
    rows_fft = dft_real(cropped, fft_length=n_cols, transpose=True)

    # Second FFT, along the columns, applied separately to the real
    # part and to the imaginary part of the first result.
    from_real = dft_real(rows_fft[0], fft_length=n_rows, transpose=False)
    from_imag = dft_real(rows_fft[1], fft_length=n_rows, transpose=False)
    # Multiplying (a + ib) by i gives (-b + ia): the two parts are
    # swapped and the new real part changes sign.
    rotated = numpy.vstack([-from_imag[1:2], from_imag[:1]])
    total = from_real + rotated
    return total[:, :n_rows, :n_cols // 2 + 1]


# Checks the real-only implementation against numpy on the full matrix.
fft2d_np = numpy.fft.rfft2(rnd)
fft2d_cus = fft2d(rnd, rnd.shape)
almost_equal(fft2d_np, fft2d_cus)

In [15]:
fft2d_np

array([[-2.036582  +0.j        , -0.85992725+6.47780438j,
        -3.99332006-3.11192536j, -1.32368431-2.48821071j],
       [ 4.37345155-7.03173815j, -3.14890126+1.59632335j,
        -3.75306979+0.66651699j, -1.56716114+4.75028368j],
       [ 2.76767016+1.25297955j,  5.07926144-2.23393831j,
         2.41908275-8.55451105j, -8.84556476-1.29356088j],
       [ 2.76767016-1.25297955j,  2.41782872-4.44962381j,
        -3.6501426 +4.13120322j,  4.30875103+0.96179243j],
       [ 4.37345155+7.03173815j,  1.77135529+4.4385736j ,
         2.40878105+5.40109054j, -1.65462983+0.2149866j ]])

In [16]:
fft2d_cus

array([[[-2.036582  , -0.85992725, -3.99332006, -1.32368431],
        [ 4.37345155, -3.14890126, -3.75306979, -1.56716114],
        [ 2.76767016,  5.07926144,  2.41908275, -8.84556476],
        [ 2.76767016,  2.41782872, -3.6501426 ,  4.30875103],
        [ 4.37345155,  1.77135529,  2.40878105, -1.65462983]],

       [[ 0.        ,  6.47780438, -3.11192536, -2.48821071],
        [-7.03173815,  1.59632335,  0.66651699,  4.75028368],
        [ 1.25297955, -2.23393831, -8.55451105, -1.29356088],
        [-1.25297955, -4.44962381,  4.13120322,  0.96179243],
        [ 7.03173815,  4.4385736 ,  5.40109054,  0.2149866 ]]])

And with a different `fft_length`.

In [17]:
# Same check with FFT lengths (4, 6) smaller than the matrix shape (5, 7).
fft2d_np = numpy.fft.rfft2(rnd, (4, 6))
fft2d_cus = fft2d(rnd, (4, 6))
almost_equal(fft2d_np[:4, :], fft2d_cus)

## FFT 2D in ONNX

We use again the numpy API for ONNX.

In [18]:
def onnx_rfft_1d(x, fft_length=None, transpose=True):
    """
    One-dimensional DFT written with the numpy API for ONNX.

    The DFT coefficients are computed with numpy by `dft_real_cst`
    and cast to float32.

    :param x: input matrix
    :param fft_length: length of the transform, it must be specified
    :param transpose: if True, every row is transformed and only the first
        ``fft_length // 2 + 1`` frequencies are kept; if False, every
        column is transformed and no frequency is dropped
    :return: result with a new first dimension of size 2,
        ``[0]`` is the real part, ``[1]`` is the imaginary part
    :raises RuntimeError: if *fft_length* is None
    """
    if fft_length is None:
        raise RuntimeError("fft_length must be specified.")

    weights = dft_real_cst(fft_length, fft_length).astype(numpy.float32)
    if not transpose:
        return npnx.dot(weights[:, :, :fft_length], x[:fft_length])

    n_freq = fft_length // 2 + 1
    flipped = npnx.transpose(x, (1, 0))
    product = npnx.dot(weights[:, :, :fft_length], flipped[:fft_length])
    return npnx.transpose(product[:, :n_freq, :], (0, 2, 1))


@onnxnumpy_np(signature=NDArrayType(("T:all", ), dtypes_out=('T',)))
def onnx_rfft_2d(x, fft_length=None):
    """
    Two-dimensional RFFT written with the numpy API for ONNX,
    using only real numbers.

    :param x: input matrix
    :param fft_length: pair ``(length along axis 0, length along axis 1)``,
        it must be specified; the matrix is cropped to these lengths
        before the computation
    :return: result with a first dimension of size 2,
        ``[0]`` is the real part, ``[1]`` is the imaginary part
    :raises RuntimeError: if *fft_length* is None
    """
    # Same explicit failure as onnx_rfft and onnx_rfft_1d instead of
    # a TypeError raised by fft_length[0].
    if fft_length is None:
        raise RuntimeError("fft_length must be specified.")

    mat = x[:fft_length[0], :fft_length[1]]

    # first FFT
    res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)

    # second FFT decomposed on FFT on real part and imaginary part
    res2_real = onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)
    res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False)
    # Multiplying (a + ib) by i gives (-b + ia): the two parts are
    # swapped and the new real part changes sign.
    res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])
    res = res2_real + res2_imag2
    size = fft_length[1]//2 + 1
    return res[:, :fft_length[0], :size]


# Compares the numpy real-only implementation with the ONNX one.
fft2d_cus = fft2d(rnd, rnd.shape)
fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)
almost_equal(fft2d_cus, fft2d_onx)

The corresponding ONNX graph.

In [19]:
key = list(onnx_rfft_2d.signed_compiled)[0]
%onnxview onnx_rfft_2d.signed_compiled[key].compiled.onnx_

With a different `fft_length`.

In [20]:
# Same comparison with FFT lengths (4, 5) smaller than the matrix shape (5, 7).
fft2d_cus = fft2d(rnd, (4, 5))
fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))
almost_equal(fft2d_cus, fft2d_onx)

This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise.