# ONNX FFTs

Implementation of a couple of variations of the FFT (see [FFT](https://www.tensorflow.org/xla/operation_semantics#fft)) in ONNX.

In [1]:
from jyquickhelper import add_notebook_menu
# Presumably adds a navigation menu to the notebook (jyquickhelper helper);
# display only -- verify.
add_notebook_menu()

In [2]:
# IPython line magic (not plain Python): loads the mlprodict extension.
%load_ext mlprodict

## Signature

We try to follow the signature of the XLA function [FFT](https://www.tensorflow.org/xla/operation_semantics#fft) or of [torch.fft.fftn](https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn).

In [18]:
import numpy
from numpy.testing import assert_almost_equal

def numpy_fft(x, fft_type, fft_length):
    """
    Implements FFT with :mod:`numpy.fft`.

    :param x: input
    :param fft_type: string (see below)
    :param fft_length: length on each axis; the transform applies to the
        last ``len(fft_length)`` axes of *x* (inputs are zero-padded or
        truncated to these lengths)
    :return: result
    :raises NotImplementedError: if *fft_type* is not one of the values below

    * `'FFT'`: complex-to-complex FFT. Shape is unchanged when
      fft_length matches the input shape.
    * `'IFFT'`: Inverse complex-to-complex FFT. Shape is unchanged when
      fft_length matches the input shape.
    * `'RFFT'`: Forward real-to-complex FFT.
      Shape of the innermost axis is reduced to fft_length[-1] // 2 + 1 if fft_length[-1]
      is a non-zero value, omitting the reversed conjugate part of
      the transformed signal beyond the Nyquist frequency.
    * `'IRFFT'`: Inverse real-to-complex FFT (ie takes complex, returns real).
      Shape of the innermost axis is expanded to fft_length[-1] if fft_length[-1]
      is a non-zero value, inferring the part of the transformed signal beyond the Nyquist
      frequency from the reverse conjugate of the 1 to fft_length[-1] // 2 + 1 entries.
    """
    # numpy's n-dimensional transforms share the ``s`` argument semantics
    # described above, so each type maps directly onto one of them.
    functions = {
        'FFT': numpy.fft.fftn,
        'IFFT': numpy.fft.ifftn,
        'RFFT': numpy.fft.rfftn,
        'IRFFT': numpy.fft.irfftn,
    }
    if fft_type not in functions:
        raise NotImplementedError("Not implemented for fft_type=%r." % fft_type)
    return functions[fft_type](x, fft_length)
    

def test_fct(fct1, fct2, fft_type='FFT', decimal=5, seed=None):
    """
    Compares two FFT implementations on random float32 inputs.

    Every combination of 1 to 4 dimensions, two input shapes and two
    transform lengths is tried (16 cases).

    :param fct1: first implementation, called as ``fct1(x, fft_type, fft_length)``
    :param fct2: second implementation, same signature
    :param fft_type: transform type given to both implementations
    :param decimal: precision given to ``assert_almost_equal``
    :param seed: if not None, inputs come from a private
        ``numpy.random.RandomState(seed)`` so that a failure can be
        reproduced; if None, numpy's global random state is used
    :return: number of compared cases
    :raises AssertionError: on the first mismatch, the message gives the
        input shape, the fft type and the lengths
    """
    # numpy.random and RandomState both expose randn: the default keeps
    # drawing from the global state exactly as before.
    rnd = numpy.random if seed is None else numpy.random.RandomState(seed)
    dims = [[4, 4, 4, 4],
            [4, 5, 6, 7]]
    lengths = [[2, 2, 2, 2],
               [2, 3, 4, 5]]
    n_test = 0
    for ndim in range(1, 5):
        for dim in dims:
            for length in lengths:
                di = dim[:ndim]
                le = length[:ndim]
                mat = rnd.randn(*di).astype(numpy.float32)
                v1 = fct1(mat, fft_type, le)
                v2 = fct2(mat, fft_type, le)
                try:
                    assert_almost_equal(v1, v2, decimal=decimal)
                except AssertionError as e:
                    raise AssertionError(
                        "Failure mat.shape=%r, fft_type=%r, fft_length=%r" % (
                            mat.shape, fft_type, le)) from e
                n_test += 1
    return n_test


# This helper takes arguments; the ``__test__`` flag stops pytest from
# collecting it as a test because of its ``test_`` prefix.
test_fct.__test__ = False
                    
# Sanity check: the numpy reference compared with itself (16 cases).
test_fct(fct1=numpy_fft, fct2=numpy_fft)

16

In [19]:
import torch

def torch_fft(x, fft_type, fft_length):
    """
    Implements FFT with :mod:`torch.fft`, mirroring ``numpy_fft``.

    :param x: numpy array
    :param fft_type: ``'FFT'``, ``'IFFT'``, ``'RFFT'`` or ``'IRFFT'``
    :param fft_length: length on each axis; the transform applies to the
        last ``len(fft_length)`` axes of *x*
    :return: numpy array
    :raises NotImplementedError: for any other *fft_type* (previously the
        function silently returned None)
    """
    functions = {
        'FFT': torch.fft.fftn,
        'IFFT': torch.fft.ifftn,
        'RFFT': torch.fft.rfftn,
        'IRFFT': torch.fft.irfftn,
    }
    if fft_type not in functions:
        raise NotImplementedError("Not implemented for fft_type=%r." % fft_type)
    xt = torch.tensor(x)
    return functions[fft_type](xt, fft_length).cpu().detach().numpy()
    
# Compare torch's implementation against the numpy reference (16 cases).
test_fct(fct1=numpy_fft, fct2=torch_fft)

16

## Numpy implementation

In [98]:
import numpy

def _DFT_cst(N, fft_length, dtype=numpy.float32):
    n = numpy.arange(N).astype(dtype).reshape((-1, 1))
    k = numpy.arange(fft_length).reshape((1, -1)).astype(dtype)
    M = numpy.exp(-2j * numpy.pi * n * k / fft_length)
    return M


def custom_fft(x, fft_type, fft_length):
    if len(x.shape) != len(fft_length):
        raise ValueError("Length mismatch x.shape=%r, fft_length=%r." % (
            x.shape, fft_length))
    if fft_type == 'FFT':
        perm = numpy.arange(len(x.shape)).tolist()
        res = x.copy()
        for i in range(len(shape)-1, -1, -1):
            cst = _DFT_cst(x.shape[i], fft_length[i])
            perm[i], perm[0] = perm[0], perm[i]
            rest = res.transpose(perm)
            print("-", i, res.shape, cst.shape, perm, '--', cst.T.shape, rest.shape)
            res = numpy.matmul(res, cst).transpose(perm)
            perm[i], perm[0] = perm[0], perm[i]
        return res
    raise ValueError("Unexpected value for fft_type=%r." % fft_type)

    
# 1-D case: 4 complex samples transformed with a length-5 FFT, so the input
# is zero-padded. Building `rnd` still advances the global random state; the
# comparison itself is disabled (uncomment one of the two lines below).
img = 1j
shape = (4, )
fft_length = [5,]
rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * img
# custom_fft(rnd, 'FFT', fft_length), numpy_fft(rnd, 'FFT', fft_length)
# assert_almost_equal(custom_fft(rnd, 'FFT', fft_length), numpy_fft(rnd, 'FFT', fft_length), decimal=5)

print("-------------")
# 2-D case: a (4, 3) complex input transformed with lengths [3, 2], so both
# axes are truncated. The cell displays the custom result next to the numpy
# reference for visual comparison.
img = 1j
shape = (4, 3)
fft_length = [3, 2]
rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * img
custom_fft(rnd, 'FFT', fft_length), numpy_fft(rnd, 'FFT', fft_length)

-------------
- 1 (4, 3) (3, 2) [1, 0] -- (2, 3) (3, 4)
- 0 (2, 4) (4, 3) [0, 1] -- (3, 4) (2, 4)


(array([[ 6.27777313+1.95827881j,  4.37018612+1.98945042j,
         -1.25988636-0.38273674j],
        [-0.81455859+0.29279864j, -3.58824843+3.90278785j,
         -0.96516085-2.86931853j]]),
 array([[ 4.28430848+2.75459542j, -0.92570433+0.20936179j],
        [ 5.57509299+0.74344773j, -0.50102344+1.77703045j],
        [-0.60901483-0.62531727j,  1.56802951-3.99165265j]]))