In [11]:
import matplotlib.pyplot as plt
import numpy as np
import math

# `display` is part of the public IPython.display API. Importing it from
# IPython.core.display has been deprecated since IPython 7.14 and is no longer
# possible on newer IPython releases, so use the public module for both names.
from IPython.display import display, HTML

# Inject a CSS rule so the notebook's `.container` element spans the full
# browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [12]:
# From https://towardsdatascience.com/fast-fourier-transform-937926e591cb
# Example implementation in 1D

def dft(x):
    """Compute the discrete Fourier transform by direct matrix multiplication.

    Builds the N x N matrix ``M[k, n] = exp(-2j * pi * k * n / N)`` and returns
    ``M . x``, i.e. ``X[k] = sum_n x[n] * exp(-2j * pi * k * n / N)`` (forward
    transform, no normalisation). Time and memory are O(N^2) because the full
    matrix is materialised.

    Parameters
    ----------
    x : array_like
        Real or complex samples. ``N`` is the length of the first axis.

    Returns
    -------
    numpy.ndarray
        Complex array holding the transform of ``x`` along its first axis.
    """
    # Cast to complex (not float): a float cast would discard the imaginary
    # part of complex-valued input. Real input gives the same result as before.
    x = np.asarray(x, dtype=complex)
    N = x.shape[0]
    n = np.arange(N)
    # Column vector of output frequencies; broadcasting k * n yields the
    # (N, N) table of exponents.
    k = n.reshape((N, 1))
    M = np.exp(-2j * np.pi * k * n / N)

    return np.dot(M, x)


def fft(x):
    """Compute the 1-D FFT of a real signal with a recursive radix-2 split.

    The input is split into its even- and odd-indexed samples, each half is
    transformed recursively, and the two half-size spectra are combined with
    the twiddle factors ``exp(-2j * pi * k / N)``. Sub-problems of length 1 or
    2 are handed to ``dft``.

    Parameters
    ----------
    x : array_like
        Real-valued samples (the input is converted with ``dtype=float``).
        The length must be a power of two: 1, 2, 4, 8, ...

    Returns
    -------
    numpy.ndarray
        Complex array of length ``N`` containing the transform of ``x``.

    Raises
    ------
    ValueError
        If the length of ``x`` is zero or not a power of two.
    """
    x = np.asarray(x, dtype=float)
    N = x.shape[0]
    # Exact integer test: a power of two has a single set bit, so N & (N - 1)
    # is zero. A plain evenness check would wrongly reject N == 1 and would
    # only catch sizes such as 6 or 12 further down the recursion.
    if N < 1 or N & (N - 1):
        raise ValueError("must be a power of 2")
    if N <= 2:
        # at the bottom do a simple DFT
        return dft(x)

    half = N // 2
    X_even = fft(x[::2])
    X_odd = fft(x[1::2])
    # Only the first N/2 twiddle factors are needed: the factor for output
    # k + N/2 is the negative of the factor for output k.
    twiddled_odd = np.exp(-2j * np.pi * np.arange(half) / N) * X_odd
    return np.concatenate([X_even + twiddled_odd,
                           X_even - twiddled_odd])

    
def fft_v(x):
    """Compute the 1-D FFT iteratively using whole-array (vectorised) steps.

    All length-``N_min`` sub-transforms are computed at once with a single
    matrix product, then pairs of sub-transforms are merged repeatedly with
    the radix-2 butterfly until one transform of length ``N`` remains.

    Parameters
    ----------
    x : array_like
        Real or complex samples, expected to be one-dimensional (the data is
        reshaped as a flat vector). The length must be a power of two:
        1, 2, 4, 8, ...

    Returns
    -------
    numpy.ndarray
        Flat complex array of length ``N`` containing the transform of ``x``.

    Raises
    ------
    ValueError
        If the length of ``x`` is zero or not a power of two.
    """
    # Cast to complex (not float) so complex-valued input keeps its imaginary
    # part; real input gives the same result as a float cast would.
    x = np.asarray(x, dtype=complex)
    N = x.shape[0]
    # Exact integer power-of-two test (single set bit). Unlike a floating
    # point log2 check, it is well defined for N == 0.
    if N < 1 or N & (N - 1):
        raise ValueError("must be a power of 2")

    N_min = min(N, 2)

    n = np.arange(N_min)
    k = n[:, None]
    # DFT of the lowest level: after the reshape each column holds the samples
    # of one sub-sequence, and M transforms every column in one product.
    M = np.exp(-2j * np.pi * n * k / N_min)
    X = np.dot(M, x.reshape((N_min, -1)))

    # Each pass doubles the transform length (rows) and halves the number of
    # sub-transforms (columns).
    while X.shape[0] < N:
        half_cols = X.shape[1] // 2
        X_even = X[:, :half_cols]
        X_odd = X[:, half_cols:]
        terms = np.exp(-1j * np.pi * np.arange(X.shape[0])
                       / X.shape[0])[:, None]
        X = np.vstack([X_even + terms * X_odd,
                       X_even - terms * X_odd])
    return X.ravel()


x = np.random.random(1024)
np.allclose(fft(x), np.fft.fft(x))
np.allclose(dft(x), np.fft.fft(x))
np.allclose(fft_v(x), np.fft.fft(x))

%timeit fft(x)
%timeit fft_v(x)
%timeit np.fft.fft(x)

14.1 ms ± 277 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
321 µs ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
4.96 µs ± 87.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [13]:
# Step-by-step version of the direct DFT, printing the shape of every
# intermediate array. Uses the signal `x` created in an earlier cell.
N = x.shape[0]
n = np.arange(N)              # sample indices, shape (N,)
k = n[:, np.newaxis]          # frequency indices as a column, shape (N, 1)
M = np.exp(-2j * np.pi * k * n / N)   # broadcasting gives the (N, N) DFT matrix
print(k.shape, n.shape, M.shape, sep="\n")
dft_out = M @ x               # matrix-vector product, shape (N,)
print(dft_out.shape)

(1024, 1)
(1024,)
(1024, 1024)
