In [71]:
%matplotlib widget
%load_ext autoreload
%autoreload 2
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import animation
from IPython.display import HTML
from ipywidgets import AppLayout, IntSlider, FloatSlider, interact
import pdb
from utils.plotting import animated_polar

# Notebook-wide look: ggplot styling overlaid with dark_background, dashed grid,
# thicker lines.
plt.style.use(['ggplot', 'dark_background'])
# Colors of the active style's property cycle; later cells pick colors by index.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
plt.rcParams.update({'grid.linestyle': '--', 'lines.linewidth': 3})

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Visualizing the Fourier Transform

## Introduction
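The Fourier transform decomposes a signal into a sum of sinusoids (sinewaves) of different frequencies. For each frequency it tells us how strongly that sinusoid is present (its amplitude) and how it is offset in time (its phase). This notebook builds that intuition visually.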

## Understanding periodic signals

First, let's generate a signal as a linear combination of sinusoids of different frequencies.



In [6]:
def sinewave(frequency, phi=0, alpha=1, length=100):
    """Return `length` samples of a complex sinusoid.

    The phase changes by -2*pi*frequency/length per sample, so the wave makes
    `frequency` cycles per `length` samples and, for positive `frequency`,
    turns clockwise in the complex plane.

    Args:
        frequency: cycles per `length` samples.
        phi: initial phase offset in radians.
        alpha: amplitude multiplier applied to every sample.
        length: number of samples.

    Returns:
        Complex ndarray of shape (length,) with entries
        alpha * exp(1j * (omega * n + phi)), omega = -2*pi*frequency/length.
    """
    angular_step = (-2 * np.pi * frequency) / length
    sample_idx = np.arange(length)
    phase = angular_step * sample_idx + phi
    return alpha * np.exp(1j * phase)

# Composite test signal: four complex sinusoids with different frequencies,
# phase offsets and amplitudes, summed sample by sample into a complex64 buffer.
freqs = [10, 14, 5, 2]
phis = [0, np.pi/4, np.pi*1.5, np.pi/8]
alphas = [1.2, 2, 1, 3]
length = 1000

sig = np.zeros(length, dtype=np.complex64)
for freq, phase, amplitude in zip(freqs, phis, alphas):
    sig += sinewave(freq, alpha=amplitude, phi=phase, length=length)

In [72]:
# Real and imaginary parts of the composite signal against the sample index.
# Labels and a legend make clear which curve is which.
fig, ax = plt.subplots()
ax.plot(sig.real, label='Real part')
ax.plot(sig.imag, label='Imaginary part')
ax.set(title='Composite signal', xlabel='Sample index', ylabel='Amplitude')
ax.legend();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fb9c8d09438>]

In [22]:
# Animate the composite signal. animated_polar comes from utils.plotting (not
# shown here); judging by its name it traces the signal on polar axes.
animated_polar([sig], figsize=(7, 7))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Phase shifts

Now, let's generate a second signal by shifting the phase of each of the components by 90 degrees ($\pi/2$ radians).

In [23]:
# Same components as `sig`, but every phase is advanced by pi/2 (90 degrees).
shifted_phis = [phase + np.pi / 2 for phase in phis]

shifted_sig = np.zeros(length, dtype=np.complex64)
for idx, freq in enumerate(freqs):
    component = sinewave(freq, alpha=alphas[idx], phi=shifted_phis[idx], length=length)
    shifted_sig += component

We can plot these signals side by side:

In [27]:
# Animate the original and phase-shifted signals next to each other (one panel
# per signal, presumably; animated_polar is defined in utils.plotting).
animated_polar([sig, shifted_sig], figsize=(14, 7))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

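## The nth roots of unity

The $n$th roots of unity are the $n$ complex numbers $e^{2\pi i j/n}$, $j = 0, \dots, n-1$, evenly spaced around the unit circle. Raising each of them to a power $k$ gives $e^{2\pi i jk/n}$: $n$ samples of a sinusoid that winds $k$ times around the circle. These are the building blocks of the DFT.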

In [44]:
def nth_roots_of_unity(n, exponent=1):
    """Return the n-th roots of unity, each raised to `exponent`, as a list.

    Element j is exp(2*pi*1j*j*exponent / n) for j = 0 .. n-1, i.e. n samples
    of a complex sinusoid that winds `exponent` times around the unit circle.
    `n` is converted with int() first.
    """
    count = int(n)
    return [np.exp((2 * np.pi * 1j * idx * exponent) / count) for idx in range(count)]

## The Fourier Transform

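**Fourier transform.** For a continuous signal $x(t)$,
$$\hat{x}(f) = \int_{-\infty}^{\infty} x(t)\, e^{-2\pi i f t}\, dt .$$
The magnitude and angle of $\hat{x}(f)$ give the amplitude and phase of the frequency-$f$ component of $x$.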

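**Discrete Fourier transform (DFT).** For $N$ samples $x_0, \dots, x_{N-1}$, the DFT is usually written
$$X_k = \sum_{n=0}^{N-1} x_n\, e^{-2\pi i k n / N}, \qquad k = 0, \dots, N-1,$$
or in matrix form $X = W x$ with $W_{kn} = e^{-2\pi i k n / N}$. Every row and column of $W$ consists of $N$th roots of unity.

Because `sinewave` generates its components with a *negative* angular frequency ($e^{-2\pi i f n / N}$), the `dft_matrix` below uses the opposite sign, $W_{kn} = e^{+2\pi i k n / N}$. With that choice, a component of frequency $f$ appears at index $k = f$.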

In [45]:
def dft_matrix(n):
    """Build the n x n DFT matrix used throughout this notebook (complex64).

    Column k holds `nth_roots_of_unity(n, k)`, so entry [j, k] is
    exp(2*pi*1j*j*k / n): a sampled complex sinusoid winding k times around
    the unit circle. The matrix is mathematically symmetric in j and k.
    """
    result = np.zeros((n, n), dtype=np.complex64)
    # Iterating over the transpose yields writable views of result's columns.
    for power, column in enumerate(result.T):
        column[:] = nth_roots_of_unity(n, power)
    return result

In [46]:
# Visualize a 32 x 32 DFT matrix as images. Column k is a sampled sinusoid that
# winds k times around the unit circle. Titles say which panel is which, since
# the axes themselves are hidden.
dft32 = dft_matrix(32)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
axes[0].imshow(dft32.real)
axes[1].imshow(dft32.imag)
for ax, title in zip(axes, ['Real part', 'Imaginary part']):
    ax.set_title(title)
    ax.grid(False)
    ax.axis('off')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [75]:
# The first few DFT basis waves (columns k = 0 .. n_basis-1 of the 32-point
# matrix) plotted against the sample index.
n_basis = 5
dft32 = dft_matrix(32)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
real_lines = axes[0].plot(dft32.real[:, :n_basis])
axes[1].plot(dft32.imag[:, :n_basis])
axes[0].set_title('Real part')
axes[1].set_title('Imaginary part')
for ax in axes:
    ax.set_xlabel('Sample index')
# Both axes start from the same color cycle, so one legend covers both panels.
axes[0].legend(real_lines, ['k = {}'.format(k) for k in range(n_basis)]);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [89]:
# Build a DFT matrix that matches the signal length (1000 samples here).
# Deriving it from len(sig) keeps this cell correct even if `sig` is rebuilt
# with another length.
n_samples = len(sig)
dft = dft_matrix(n_samples)

# `sig * dft` multiplies every row of dft element-wise by the signal. Row k is
# therefore the signal weighted by the k-th basis wave, and summing a row gives
# the k-th Fourier coefficient.
fourier_projection = sig * dft
fourier_transform = np.sum(fourier_projection, axis=1)

shifted_fourier_projection = shifted_sig * dft
shifted_fourier_transform = np.sum(shifted_fourier_projection, axis=1)

# The signal is a sum of len(freqs) sinusoids, so keep that many of the
# largest-magnitude coefficients.
n_top = len(freqs)
top_bases = np.argsort(np.abs(fourier_transform))[::-1][:n_top]
top_curves = [fourier_projection[i] for i in top_bases]

shifted_top_bases = np.argsort(np.abs(shifted_fourier_transform))[::-1][:n_top]
shifted_top_curves = [shifted_fourier_projection[i] for i in shifted_top_bases]

# tolist() prints plain ints (a list of np.int64 renders as np.int64(...) on NumPy 2).
print('Top Frequencies: {}'.format(top_bases.tolist()))

Top Frequencies: [2, 14, 10, 5]


In [53]:
# Animate the projections onto the strongest Fourier bases of `sig`. Each curve
# is one row of fourier_projection, whose mean is that coefficient divided by
# N; plot_mean presumably overlays that mean (TODO confirm in utils.plotting).
animated_polar(top_curves, interval=30, plot_mean=True, figsize=(20, 7))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [81]:
# Same animation for the phase-shifted signal, for comparison with the
# previous cell's output.
animated_polar(shifted_top_curves, interval=30, plot_mean=True, figsize=(20, 7))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

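## Equivariance

Shifting the phase of every component of a signal by $\theta$ and then taking the Fourier transform gives the same result as taking the transform first and then multiplying every coefficient by $e^{i\theta}$. This follows from linearity: the shifted signal is just $e^{i\theta}$ times the original.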

In [90]:
# "Fourier -> Shift": apply a phase shift to the already-transformed signal.
# `shifted_sig` was built by advancing every phase by pi/2, so this phasor must
# use the same pi/2 offset; with pi the two sides differed by a quarter turn.
shift = np.exp(1j * np.pi / 2)
fourier_projection_then_shift = fourier_projection * shift
fourier_then_shift = fourier_transform * shift

# Equivariance check: Shift -> Fourier equals Fourier -> Shift (up to complex64 rounding).
assert np.allclose(shifted_fourier_projection, fourier_projection_then_shift, atol=1e-3)

# NOTE: the columns of dft are orthogonal, so conj(dft) @ (sig * dft) equals
# N * diag(sig). This therefore yields N * diag(shift * sig) up to rounding,
# not the shifted signal itself; that would be
# np.conjugate(dft) @ fourier_then_shift / N.
inverse_shifted = np.conjugate(dft) @ fourier_projection_then_shift

In [91]:
# Side-by-side check of equivariance for Fourier basis 5: the projection of the
# phase-shifted signal vs. the original projection with the phase shift applied
# afterwards. The two panels should look the same.
animated_polar([shifted_fourier_projection[5], fourier_projection_then_shift[5]], interval=30, plot_mean=True, figsize=(15, 7), ax_titles=['Shift -> Fourier', 'Fourier -> Shift'])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [177]:
# Slider that drives the interactive phase-shift demo (0 to 2*pi in pi/8 steps).
phi_slider = FloatSlider(orientation='horizontal',
                         description='Phase Shift:',
                         value=0,
                         min=0,
                         max=2*np.pi,
                         step=np.pi/8)

# Three panels: the signal and its Fourier coefficients on polar axes, plus a
# narrower bar chart of the normalized power spectrum.
fig = plt.figure(figsize=(15, 7))
spec = gridspec.GridSpec(ncols=3, nrows=1, figure=fig, width_ratios=[1, 1, 0.5])
ax1 = fig.add_subplot(spec[0, 0], polar=True)
ax2 = fig.add_subplot(spec[0, 1], polar=True)
ax3 = fig.add_subplot(spec[0, 2])

# Hide the interactive (%matplotlib widget) figure header.
fig.canvas.header_visible = False

ax1.title.set_text('Signal')
ax2.title.set_text('Fourier Transform')
ax3.title.set_text('Power Spectrum')

# NOTE: these rebind the notebook-wide freqs/phis/alphas/length/sig (length
# drops from 1000 to 500). Earlier cells re-run after this one see the new values.
freqs = [10, 14, 5, 2]
phis = [0, np.pi/4, np.pi*1.5, np.pi/8]
alphas = [1.2, 2, 1, 3]
length = 500
sig = np.zeros((length), dtype=np.complex64)
for i, f in enumerate(freqs):
    s = sinewave(f, alpha=alphas[i], phi=phis[i], length=length)
    sig += s

# Size the DFT from `length` instead of repeating the literal, so the matrix
# always matches the signal.
dft = dft_matrix(length)

# Number of low-frequency bins shown in the power-spectrum panel.
n_spectrum_bins = 20

fourier_transform = sig @ dft
# |X_k|^2 for the first bins, normalized so the tallest bar has height 1.
power_spectrum = abs(fourier_transform * np.conjugate(fourier_transform))[:n_spectrum_bins]
power_spectrum /= power_spectrum.max()

# Signal trace plus a violet indicator at angle 0 (shift() redraws it at phi).
ax1.plot(np.angle(sig), abs(sig), color=colors[5])
ax1.plot([0, 0], [0, max(abs(sig))], color='darkviolet', linewidth=5)
# One spoke per Fourier coefficient: angle = phase, radius = magnitude.
for coeff in fourier_transform:
    ax2.plot([0, np.angle(coeff)], [0, abs(coeff)])
ax2.plot([0, 0], [0, max(abs(fourier_transform))], color='darkviolet', linewidth=5)
# Separate bar() calls so each bin takes the next color from the style cycle;
# iterate over the bins that actually exist (fewer than n_spectrum_bins if the
# signal is short).
for k in range(len(power_spectrum)):
    ax3.bar(k, power_spectrum[k])
for side in ('top', 'left', 'right'):
    ax3.spines[side].set_visible(False)

ax3.grid(False)


def shift_signal(phi):
    """Rebuild the composite signal with every component's phase advanced by `phi`.

    Reads the module-level `freqs`, `phis`, `alphas` and `length`, and
    accumulates the sinusoids into a complex64 array of `length` samples.
    """
    result = np.zeros((length), dtype=np.complex64)
    for idx, freq in enumerate(freqs):
        component_phase = phis[idx] + phi
        result += sinewave(freq, alpha=alphas[idx], phi=component_phase, length=length)
    return result

def shift_fourier(fourier, phi):
    """Rotate every Fourier coefficient by `phi` radians (multiply by e^{i*phi})."""
    return fourier * np.exp(1j * phi)


def shift(phi):
    """Redraw the signal and Fourier panels for a phase shift of `phi` radians.

    Used as the `interact` callback for `phi_slider`. It rebuilds the signal
    with every component's phase advanced by `phi` (via `shift_signal`) and
    computes its Fourier coefficients with the module-level `dft`. It then
    redraws `ax1` and `ax2`, each with a violet indicator pointing at angle `phi`.

    Args:
        phi: phase shift in radians.
    """
    # Clear the previous frame. Assigning `ax.lines = []` raises AttributeError
    # on Matplotlib >= 3.5 (Axes.lines is a read-only ArtistList there).
    # Artist.remove() works on old and new versions alike.
    if len(ax1.lines) > 1:
        for ax in (ax1, ax2):
            for line in list(ax.lines):
                line.remove()

    new_sig = shift_signal(phi)
    new_fourier = new_sig @ dft

    ax1.plot(np.angle(new_sig), np.abs(new_sig), color=colors[5])
    # The indicator length is the largest sample modulus. The previous
    # `abs(max(new_sig))` picked the sample that compares greatest, and NumPy
    # orders complex values by real part first, not by magnitude.
    ax1.plot([0, phi], [0, np.abs(new_sig).max()], color='darkviolet', linewidth=5)

    # One spoke per Fourier coefficient: angle = phase, radius = magnitude.
    for coeff in new_fourier:
        ax2.plot([0, np.angle(coeff)], [0, np.abs(coeff)])

    # A phase shift leaves magnitudes unchanged, so this matches the initial
    # panel's indicator length (up to rounding).
    ax2.plot([0, phi], [0, np.abs(new_fourier).max()], color='darkviolet', linewidth=5)

    fig.canvas.draw()
    fig.canvas.flush_events()
    
# Bind the slider to shift(): ipywidgets' interact calls shift(phi=<slider value>)
# whenever the slider changes. The trailing ';' suppresses the returned widget's repr.
interact(shift, phi=phi_slider);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=0.0, description='Phase Shift:', max=6.283185307179586, step=0.3926990…

## Invariance