<a href="https://colab.research.google.com/github/sascha-kirch/ML_Notebooks/blob/main/2D_FFTs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 2D FFT
Notebook to investigate the behavior of 2D FFTs and of filters applied in the frequency domain.

NumPy documentation: https://numpy.org/doc/stable/reference/routines.fft.html


In [None]:
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
def HorizontalSine(shape, wavelength):
  """Return a sine grating whose intensity varies along the horizontal (x) axis.

  Args:
    shape: (height, width) of the output, in NumPy (rows, columns) order.
    wavelength: period of the sine in pixels.

  Returns:
    Float array of shape ``(height, width)`` with ``sin(2*pi*x/wavelength)``,
    where ``x`` is the column index.
  """
  # meshgrid(cols, rows) yields (rows, cols)-shaped grids, so the result has
  # exactly the requested shape (also for non-square shapes).
  X, _ = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
  return np.sin(2 * np.pi * X / wavelength)

def VerticalSine(shape, wavelength):
  """Return a sine grating whose intensity varies along the vertical (y) axis.

  Args:
    shape: (height, width) of the output, in NumPy (rows, columns) order.
    wavelength: period of the sine in pixels.

  Returns:
    Float array of shape ``(height, width)`` with ``sin(2*pi*y/wavelength)``,
    where ``y`` is the row index.
  """
  # meshgrid(cols, rows) yields (rows, cols)-shaped grids, so the result has
  # exactly the requested shape (also for non-square shapes).
  _, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
  return np.sin(2 * np.pi * Y / wavelength)

def DiagonalSine(shape, wavelength, angle = np.pi / 4):
  """Return a sine grating oriented at ``angle`` radians from the x axis.

  Args:
    shape: (height, width) of the output, in NumPy (rows, columns) order.
    wavelength: period of the sine in pixels, measured along ``angle``.
    angle: direction of variation; 0 is horizontal, pi/2 is vertical.

  Returns:
    Float array of shape ``(height, width)`` with
    ``sin(2*pi*(x*cos(angle) + y*sin(angle))/wavelength)``, where ``x`` is the
    column index and ``y`` the row index.
  """
  # meshgrid(cols, rows) yields (rows, cols)-shaped grids, so the result has
  # exactly the requested shape (also for non-square shapes).
  X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
  return np.sin(2 * np.pi * (X * np.cos(angle) + Y * np.sin(angle)) / wavelength)

def Gauss(shape, A=1):
  """Return a 2D Gaussian bump centred in the array.

  Args:
    shape: (height, width) of the output, in NumPy (rows, columns) order.
    A: peak amplitude (value at the centre pixel).

  Returns:
    Float array of shape ``(height, width)``. The mean is the centre pixel
    ``(int(height/2), int(width/2))`` and each sigma is a tenth of the
    corresponding extent, truncated to an int (at least 1).
  """
  height, width = shape[0], shape[1]
  # meshgrid(cols, rows) yields (rows, cols)-shaped grids matching `shape`.
  X, Y = np.meshgrid(np.arange(width), np.arange(height))
  x_mean = int(width / 2)
  y_mean = int(height / 2)
  # Floor sigma at 1: for extents below 10, int(extent/10) is 0, and that
  # would divide by zero (NaN at the centre).
  x_sigma = max(int(width / 10), 1)
  y_sigma = max(int(height / 10), 1)
  return A * np.exp(-((X - x_mean)**2 / (2 * x_sigma**2) + (Y - y_mean)**2 / (2 * y_sigma**2)))

def Circle(shape, r=10): #Lowpass
  """Return a boolean disk mask that is True within distance ``r`` of the centre.

  Multiplied with a zero-centred (fftshift-ed) spectrum, it keeps only the
  low frequencies.

  Args:
    shape: (height, width) of the mask, in NumPy (rows, columns) order.
    r: radius in pixels; pixels at distance exactly ``r`` are included.

  Returns:
    Boolean array of shape ``(height, width)``.
  """
  # meshgrid(cols, rows) yields (rows, cols)-shaped grids, so the mask matches
  # `shape` and can multiply a spectrum of that shape directly.
  X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
  x_center = int(shape[1] / 2)
  y_center = int(shape[0] / 2)
  distance = np.sqrt((X - x_center)**2 + (Y - y_center)**2)
  return distance <= r

def NegativeCircle(shape, r=10): #Highpass
  """Return a boolean mask that is True farther than ``r`` from the centre.

  This is the complement of a centred disk. Multiplied with a zero-centred
  (fftshift-ed) spectrum, it removes the low frequencies.

  Args:
    shape: (height, width) of the mask, in NumPy (rows, columns) order.
    r: radius in pixels; pixels at distance exactly ``r`` are excluded.

  Returns:
    Boolean array of shape ``(height, width)``.
  """
  # meshgrid(cols, rows) yields (rows, cols)-shaped grids, so the mask matches
  # `shape` and can multiply a spectrum of that shape directly.
  X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
  x_center = int(shape[1] / 2)
  y_center = int(shape[0] / 2)
  distance = np.sqrt((X - x_center)**2 + (Y - y_center)**2)
  return distance > r

def Rect(shape, height = 30, width = 10):
  """Return an array of zeros with a centred rectangle of ones.

  Args:
    shape: shape of the output; ``shape[0]`` is the number of rows and
      ``shape[1]`` the number of columns (NumPy order).
    height: rectangle extent in rows (floats are allowed; bounds are truncated
      with ``int``).
    width: rectangle extent in columns.

  Returns:
    Float array of shape ``shape``: 1 inside the rectangle, 0 elsewhere. A
    rectangle larger than the array is clipped to the array.
  """
  image = np.zeros(shape)
  # Rows are centred on shape[0], columns on shape[1].
  y_center = shape[0] / 2
  x_center = shape[1] / 2
  # Clamp the start at 0: a negative start index would wrap around to the end
  # of the axis and fill the wrong region.
  y_min = max(int(y_center - height / 2), 0)
  y_max = int(y_center + height / 2)
  x_min = max(int(x_center - width / 2), 0)
  x_max = int(x_center + width / 2)
  image[y_min:y_max, x_min:x_max] = 1
  return image

def VerticalRect(shape):
  """Centred rectangle mask that is tall and narrow.

  It spans a quarter of ``shape[0]`` as height and an eighth of ``shape[1]``
  as width.
  """
  rect_height = shape[0] / 4
  rect_width = shape[1] / 8
  return Rect(shape, height=rect_height, width=rect_width)

def HorizontalRect(shape):
  """Centred rectangle mask that is short and wide.

  It spans an eighth of ``shape[0]`` as height and a quarter of ``shape[1]``
  as width.
  """
  rect_height = shape[0] / 8
  rect_width = shape[1] / 4
  return Rect(shape, height=rect_height, width=rect_width)

# Quick visual check of the Gaussian test pattern (odd-sized, 1025 x 1025).
plt.imshow(Gauss(shape=(1025, 1025)), cmap='gray')
plt.colorbar()

In [None]:
def PlotFFT(image, plot3D = True, realFFT = False, xlim = None, ylim=None):
  """Plot an image, its 2D DFT (log power and phase) and the iDFT round trip.

  Args:
    image: 2D array; the DFT is taken over the last two axes (height, width).
    plot3D: if True, add a second row with 3D surface plots of all four panels.
    realFFT: if True, use ``np.fft.rfft2``, which keeps only the non-negative
      frequencies along the width axis. Otherwise the full ``np.fft.fft2`` is used.
    xlim, ylim: optional axis limits, in frequency units, for the amplitude
      and phase panels (e.g. to zoom in on low frequencies).
  """
  height, width = np.shape(image)[-2:]
  # Row frequency of each row of the fftshift-ed spectrum (zero at height // 2).
  freq_y = np.arange(height) - height // 2

  # Spectrum calculation
  if realFFT:
    # rfft2 keeps only non-negative width frequencies, so the zero-frequency
    # shift is applied over the height axis only.
    spectrum = np.fft.fftshift(np.fft.rfft2(image, axes=(-2, -1)), axes=[-2])
    # Pass the original size: without `s`, irfft2 assumes an even width and
    # returns one column too few for odd-width images.
    inverse_DFT = np.fft.irfft2(
      np.fft.ifftshift(spectrum, axes=[-2]), s=(height, width), axes=(-2, -1))
    freq_x = np.arange(np.shape(spectrum)[-1])
    # axis labels for the spectrum to align with the shifted signal
    extent = [0, np.shape(spectrum)[-1], -int(height / 2), int(height / 2)]
  else:
    spectrum = np.fft.fftshift(np.fft.fft2(image, axes=(-2, -1)), axes=(-2, -1))
    inverse_DFT = np.fft.ifft2(
      np.fft.ifftshift(spectrum, axes=(-2, -1)), axes=(-2, -1)).real
    freq_x = np.arange(width) - width // 2
    extent = [-int(width / 2), int(width / 2), -int(height / 2), int(height / 2)]
  # NOTE(review): with imshow's default origin='upper', row 0 (the most
  # negative row frequency after fftshift) is drawn at the top, where this
  # extent places +height/2. Confirm the intended sign of the vertical axis.

  # Log power spectrum; zero-magnitude bins become -inf (warning silenced).
  with np.errstate(divide='ignore'):
    amplitude_spectrum = np.log(np.abs(spectrum)**2)
  phase_spectrum = np.angle(spectrum)

  rows = 2 if plot3D else 1

  # Plotting
  fig = plt.figure(figsize=[30, 6 * rows])
  fig.suptitle("2D DFT & iDFT", fontsize=20)

  # (title, data, colormap, imshow extent, apply xlim/ylim)
  panels = [
    ('Image', image, 'gray', None, False),
    ('2D DFT Amplitude', amplitude_spectrum, 'jet', extent, True),
    ('2D DFT Phase', phase_spectrum, 'jet', extent, True),
    ('2D iDFT', inverse_DFT, 'gray', None, False),
  ]
  for index, (title, data, cmap, panel_extent, zoomable) in enumerate(panels, start=1):
    ax = fig.add_subplot(rows, 4, index)
    ax.set_title(title)
    plt.imshow(data, cmap=cmap, extent=panel_extent)
    if zoomable and xlim is not None:
      plt.xlim(xlim)
    if zoomable and ylim is not None:
      plt.ylim(ylim)
    plt.colorbar()

  if plot3D:
    # meshgrid(cols, rows) yields (rows, cols)-shaped grids, so each grid
    # matches the array it is plotted against (image or spectrum).
    X, Y = np.meshgrid(np.arange(width), np.arange(height))
    X_freq, Y_freq = np.meshgrid(freq_x, freq_y)
    surfaces = [
      ('Image', X, Y, image, 'gray'),
      ('2D DFT Amplitude', X_freq, Y_freq, amplitude_spectrum, 'jet'),
      ('2D DFT Phase', X_freq, Y_freq, phase_spectrum, 'jet'),
      ('2D iDFT', X, Y, inverse_DFT, 'gray'),
    ]
    for index, (title, xs, ys, zs, cmap) in enumerate(surfaces, start=5):
      ax = fig.add_subplot(2, 4, index, projection='3d')
      ax.set_title(title)
      ax.plot_surface(xs, ys, zs, cmap=cmap)

  plt.show()

In [None]:
# Full complex DFT of a 1024 x 1024 Gaussian, including the 3D surface row.
PlotFFT(
  Gauss(shape=(1024, 1024)),
  plot3D=True,
  realFFT=False,
  xlim=None,
  ylim=None,
)

In [None]:
# Download the sample image (emojos.png) that is loaded below.
!wget https://raw.githubusercontent.com/sascha-kirch/sascha-kirch.github.io/main/docs/assets/images/emojos.png

In [None]:
imageFilePath = "/content/emojos.png"
# Load as a single-channel grayscale image (IMREAD_GRAYSCALE == 0).
image = cv2.imread(imageFilePath, cv2.IMREAD_GRAYSCALE)
if image is None:
  # cv2.imread does not raise on a missing/unreadable file; it returns None,
  # which would otherwise fail confusingly at the shape unpacking below.
  raise FileNotFoundError(f"Could not read image: {imageFilePath}")
H, W = np.shape(image)
side_length = min(H, W)
# Crop the top-left square so the spectra computed below are square.
image = image[0:side_length, 0:side_length]

In [None]:
# DFT of the loaded grayscale, square-cropped image (explicit default settings).
PlotFFT(image, plot3D=True, realFFT=False)

In [None]:
def PlotFFTWithFilter(image, filter, vmin=0, vmax=255, xlim=None, ylim=None):
  """Apply a frequency-domain mask to an image's spectrum and plot the result.

  The top row shows the image, its log power and phase spectra, and the iDFT
  round trip. The bottom row shows the filter, the filtered log power and
  phase spectra, and the iDFT of the filtered spectrum.

  Args:
    image: 2D array (height x width).
    filter: array multiplied element-wise with the zero-centred
      (fftshift-ed) spectrum, e.g. a Rect/Circle/NegativeCircle mask built
      from the image shape. The name shadows the builtin and is kept for
      compatibility with existing callers.
    vmin, vmax: display range of the two iDFT panels. The defaults suit 8-bit
      images; pass e.g. vmin=-1, vmax=1 for synthetic sine images.
    xlim, ylim: optional pixel-index limits for the amplitude and phase panels,
      e.g. xlim=[448, 576], ylim=[576, 448] to zoom into the centre of a
      1024 x 1024 spectrum.
  """
  # Spectrum calculation (zero frequency shifted to the centre)
  spectrum = np.fft.fftshift(np.fft.fft2(image, axes=(-2, -1)), axes=(-2, -1))
  inverse_DFT = np.fft.ifft2(
    np.fft.ifftshift(spectrum, axes=(-2, -1)), axes=(-2, -1)).real

  filtered_spectrum = spectrum * filter
  filtered_inverse_DFT = np.fft.ifft2(
    np.fft.ifftshift(filtered_spectrum, axes=(-2, -1)), axes=(-2, -1)).real

  # Log power spectra. Bins zeroed by the filter become -inf; the
  # divide-by-zero warning is silenced rather than shifting values by an offset.
  with np.errstate(divide='ignore'):
    amplitude_spectrum = np.log(np.abs(spectrum)**2)
    filtered_amplitude_spectrum = np.log(np.abs(filtered_spectrum)**2)
  phase_spectrum = np.angle(spectrum)
  filtered_phase_spectrum = np.angle(filtered_spectrum)

  display_range = {'vmin': vmin, 'vmax': vmax}
  # (title, data, colormap, extra imshow kwargs, apply xlim/ylim)
  panels = [
    ('Image', image, 'gray', {}, False),
    ('2D DFT Amplitude', amplitude_spectrum, 'jet', {}, True),
    ('2D DFT Phase', phase_spectrum, 'jet', {}, True),
    ('2D iDFT', inverse_DFT, 'gray', display_range, False),
    ('Filter F-Domain', filter, 'gray', {}, False),
    ('Filtered Amplitude Spectrum', filtered_amplitude_spectrum, 'jet', {}, True),
    ('Filtered Phase Spectrum', filtered_phase_spectrum, 'jet', {}, True),
    ('Filtered iDFT', filtered_inverse_DFT, 'gray', display_range, False),
  ]

  # Plotting
  fig = plt.figure(figsize=[30, 12])
  fig.suptitle("2D DFT & iDFT", fontsize=20)
  for index, (title, data, cmap, imshow_kwargs, zoomable) in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 4, index)
    ax.set_title(title)
    plt.imshow(data, cmap=cmap, **imshow_kwargs)
    if zoomable and xlim is not None:
      plt.xlim(xlim)
    if zoomable and ylim is not None:
      plt.ylim(ylim)
    plt.colorbar()

  plt.show()

# Alternative input: a synthetic diagonal grating instead of the loaded image.
# image = DiagonalSine(shape=(1025, 1025), wavelength=128)

# Centred rectangular mask (100 rows high, 200 columns wide) sized like the
# image. It keeps the low frequencies of the fftshift-ed spectrum.
filter = Rect(np.shape(image), width=200, height=100)
PlotFFTWithFilter(image, filter)