# ONNX FFTs

Implementation of a couple of variations of FFT (see [FFT](https://www.tensorflow.org/xla/operation_semantics#fft) in ONNX.

In [1]:
from jyquickhelper import add_notebook_menu
add_notebook_menu()

In [2]:
%load_ext mlprodict

## Signature

We try to use function [FFT](https://www.tensorflow.org/xla/operation_semantics#fft) or [torch.fft.fftn](https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn).

In [12]:
import numpy
from numpy.testing import assert_almost_equal

def numpy_fft(x, fft_type, fft_length, axes):
    """
    Implements FFT

    :param x: input
    :param fft_type: string (see below)
    :param fft_length: length on each axis of axes
    :param axes: axes
    :return: result
    
    * `'FFT`': complex-to-complex FFT. Shape is unchanged.
    * `'IFFT`': Inverse complex-to-complex FFT. Shape is unchanged.
    * `'RFFT`': Forward real-to-complex FFT.
      Shape of the innermost axis is reduced to fft_length[-1] // 2 + 1 if fft_length[-1]
      is a non-zero value, omitting the reversed conjugate part of 
      the transformed signal beyond the Nyquist frequency.
    * `'IRFFT`': Inverse real-to-complex FFT (ie takes complex, returns real).
      Shape of the innermost axis is expanded to fft_length[-1] if fft_length[-1] 
      is a non-zero value, inferring the part of the transformed signal beyond the Nyquist
      frequency from the reverse conjugate of the 1 to fft_length[-1] // 2 + 1 entries.
    """
    if fft_type == 'FFT':
        return numpy.fft.fftn(x, fft_length, axes=axes)
    raise NotImplementedError("Not implemented for fft_type=%r." % fft_type)
    

def test_fct(fct1, fct2, fft_type='FFT', decimal=5):
    dims = [[4,4,4,4],
            [4,5,6,7]]
    lengths_axes = [([2, 2, 2, 2], None),
                    ([2, 3, 4, 5], None),
                    ([2], [3]),
                    ([3], [2])]
    n_test = 0
    for ndim in range(1, 5):
        for dim in dims:
            for length, axes in lengths_axes:
                if axes is None:
                    axes = range(ndim)
                axes = [min(ndim - 1, a) for a in axes]
                di = dim[:ndim]
                le = length[:ndim]
                mat = numpy.random.randn(*di).astype(numpy.float32)
                try:
                    v1 = fct1(mat, fft_type, le, axes=axes)
                except Exception as e:
                    raise AssertionError(
                        "Unable to run %r mat.shape=%r ndim=%r fft_type=%r le=%r "
                        "axes=%r exc=%r" %(
                            fct1, mat.shape, ndim, fft_type, le, axes, e))
                v2 = fct2(mat, fft_type, le, axes=axes)
                try:
                    assert_almost_equal(v1, v2, decimal=decimal)
                except AssertionError as e:
                    raise AssertionError(
                        "Failure mat.shape=%r, fft_type=%r, fft_length=%r" % (
                            mat.shape, fft_type, le)) from e
                n_test += 1
    return n_test
                    
test_fct(numpy_fft, numpy_fft)

32

In [14]:
import torch

def torch_fft(x, fft_type, fft_length, axes):
    xt = torch.tensor(x)
    if fft_type == 'FFT':
        return torch.fft.fftn(xt, fft_length, axes).cpu().detach().numpy()
    
test_fct(numpy_fft, torch_fft)

32

## Numpy implementation

In [21]:
import numpy

def _DFT_cst(N, fft_length, dtype=numpy.float32):
    n = numpy.arange(N).astype(dtype).reshape((-1, 1))
    k = numpy.arange(fft_length).reshape((1, -1)).astype(dtype)
    M = numpy.exp(-2j * numpy.pi * n * k / fft_length)
    return M


def custom_fft(x, fft_type, fft_length, axes):
    if len(axes) != len(fft_length):
        raise ValueError("Length mismatch axes=%r, fft_length=%r." % (
           axes, fft_length))
    if fft_type == 'FFT':
        perm = numpy.arange(len(x.shape)).tolist()        
        res = x
        for i in range(len(fft_length)-1, -1, -1):
            length = fft_length[i]
            axis = axes[i]
            cst = _DFT_cst(x.shape[axis], length)
            if perm[i] == perm[-1]:
                print("-1-", i, res.shape, cst.shape, perm)
                res = numpy.matmul(res, cst).transpose(perm)
            else:
                perm[i], perm[-1] = perm[-1], perm[i]            
                rest = res.transpose(perm)
                print("-2-", i, res.shape, cst.shape, perm, '--', cst.T.shape, rest.shape)
                res = numpy.matmul(res, cst).transpose(perm)
                perm[i], perm[0] = perm[0], perm[i]
        return res
    raise ValueError("Unexpected value for fft_type=%r." % fft_type)

    
img = 1j
shape = (4, )
fft_length = [5,]
axes = [0]
rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * img
# custom_fft(rnd, 'FFT', fft_length, axes), numpy_fft(rnd, 'FFT', fft_length, axes)
# assert_almost_equal(custom_fft(rnd, 'FFT', fft_length, axes),
#                     numpy_fft(rnd, 'FFT', fft_length, axes), decimal=5)

print("-------------")
img = 1j
shape = (4, 3)
fft_length = [3, 2]
axes = [0, 1]
rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * img
custom_fft(rnd, 'FFT', fft_length, axes), numpy_fft(rnd, 'FFT', fft_length, axes)

-------------
-1- 1 (4, 3) (3, 2) [0, 1]
-2- 0 (4, 2) (4, 3) [1, 0] -- (3, 4) (2, 4)


ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 4 is different from 2)