# Fourier transforms in 2 dimensions

These interactive demos illustrate the relationship between a 2D signal in real space and its 2D Fourier transform.

In [1]:
# Interactive widgets.
import ipywidgets as widgets

# Custom abTEM code to enable interactivity.
from abtem.visualize.interactive.artists import ImageArtist
from abtem.visualize.interactive.artists import LinesArtist
from abtem.visualize.interactive.artists import ScatterArtist
from abtem.visualize.interactive.utils import throttle
from abtem.visualize.interactive.canvas import Canvas
from abtem.visualize.utils import domain_coloring

In [2]:
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation as animation
import matplotlib.cm as cm
from matplotlib.colors import rgb_to_hsv
from matplotlib.colors import hsv_to_rgb

In [231]:
# Define custom functions

def color_complex(
    exit_wave, 
    amp_range=(0,1.0), 
    amp_power=1, 
    color_phase=True,
    ):
    """
    Generate an RGB image from a 2D complex array (e.g. an exit wave).

    The amplitude is mapped linearly from ``amp_range`` onto [0, 1] (values
    outside the range are clipped) and then optionally raised to
    ``amp_power``.

    When ``color_phase`` is True, the phase sets the hue and the scaled
    amplitude sets the value (brightness) of an HSV image. Two hand-tuned
    corrections are then applied: blue hues are desaturated and brightened,
    and green hues are dimmed. Otherwise a grayscale image of the scaled
    amplitude is returned.

    Inputs:
        exit_wave   - 2D complex number array
        amp_range   - (min, max) amplitude mapped to black and full brightness
        amp_power   - power-law exponent applied to the scaled amplitude
        color_phase - colour by phase (True) or show amplitude only (False)
    Outputs:
        im_rgb      - (rows, cols, 3) float RGB image with values in [0, 1]
    """
    pha = np.angle(exit_wave)
    amp = np.abs(exit_wave)
    # Map amp_range linearly onto [0, 1], clipping anything outside it.
    amp = np.clip((amp - amp_range[0]) / (amp_range[1] - amp_range[0]), 0, 1)
    if amp_power != 1:
        amp = amp ** amp_power

    if not color_phase:
        # Grayscale: replicate the scaled amplitude into all three channels.
        return np.tile(amp[..., None], [1, 1, 3])

    # HSV image: hue from phase (shifted so a phase of -pi maps to hue 0),
    # full saturation, value from amplitude.
    im_hsv = np.ones((exit_wave.shape[0], exit_wave.shape[1], 3))
    hue = np.mod(pha/(2*np.pi) + 0.5, 1)
    im_hsv[:,:,0] = hue
    im_hsv[:,:,2] = amp

    # Blue band (hue ~ 2/3): reduce saturation and boost brightness.
    c_scale = np.sin(np.maximum(1 - 6*np.abs(hue - 0.667),0)*(np.pi/2))**2
    im_hsv[:,:,1] = im_hsv[:,:,1]*(1 - 0.333*c_scale)
    im_hsv[:,:,2] = im_hsv[:,:,2]*(1 + 0.333*c_scale)

    # Green band (hue ~ 1/3, including lime-green / yellow): dim slightly.
    c_scale = np.sin(np.maximum(1 - 4*np.abs(hue - 0.333),0)*(np.pi/2))**2
    im_hsv[:,:,2] = im_hsv[:,:,2]*(1 - 0.167*c_scale)

    # The blue boost can push the value above 1, so clip before conversion.
    im_hsv[:,:,2] = np.clip(im_hsv[:,:,2], 0, 1)
    return hsv_to_rgb(im_hsv)



def draw_plane_waves(
    image_size = (64,64), 
    freq = np.array([[0,0]], dtype='int'),
    vals = np.array([0],dtype='complex'),
    ):
    """
    Build a Fourier-space image with N delta peaks and its real-space image.

    Each row of ``freq`` gives the (x, y) pixel index of a spatial frequency.
    Indices are taken modulo ``image_size``, so negative values are allowed.
    The matching entry of ``vals`` is the complex coefficient placed there.

    Coefficients that land on the same pixel are summed rather than
    overwritten. This happens for k = 0, or for +k and -k at the Nyquist
    frequency. Summing makes the real-space image the true superposition of
    the requested plane waves.

    Inputs:
        image_size - (rows, cols) size of both images
        freq       - (N, 2) integer indices of spatial frequencies in pixels (modulo)
        vals       - (N,) complex coefficients of the frequencies in freq
    Outputs:
        Psi - fftshifted Fourier-space image (zero frequency at the centre)
        psi - fftshifted real-space image: the sum of the plane waves divided by N
    """
    freq = np.asarray(freq, dtype='int').reshape(-1, 2)
    vals = np.asarray(vals, dtype='complex').reshape(-1)

    Psi = np.zeros(image_size, dtype='complex')
    # np.add.at accumulates repeated indices; plain fancy-index assignment
    # would keep only the last coefficient written to a pixel.
    np.add.at(
        Psi,
        (np.mod(freq[:, 0], image_size[0]), np.mod(freq[:, 1], image_size[1])),
        vals,
    )

    # Multiply by the pixel count so a unit coefficient gives a unit-amplitude wave.
    psi = np.fft.ifft2(Psi) * np.prod(image_size)
    Psi = np.fft.fftshift(Psi)
    # Divide by the number of waves so |psi| <= 1 whenever every |val| <= 1.
    psi = np.fft.fftshift(psi) / max(freq.shape[0], 1)

    return Psi, psi
    
    

def atomic_lattice(
    image_size = (256,256), 
    u_lat = np.array([8,0], dtype='float'),
    v_lat = np.array([0,8], dtype='float'),
    atom_power = np.array(2,dtype='float'),
    sigma_envelop=1024,
    ):
    """
    Generate a synthetic 2D lattice image and its compressed Fourier transform.

    Fourier peaks of height 0.125 are placed at +/-qu and +/-qv, where
    qu = (image_size / 2) * u_lat / |u_lat|**2 (component-wise, rounded to
    whole pixels), and likewise for qv. The real-space image is
    |0.5 + fringes|**atom_power, which sharpens the cosine fringes into peaks.
    It is then multiplied by a Gaussian envelope centred at pixel
    (rows/2, cols/2).

    NOTE(review): the Fourier peaks lie along u_lat and v_lat. For
    non-orthogonal lattice vectors this is not the dual (reciprocal) basis, so
    the real-space repeat vectors are not exactly u_lat and v_lat. Confirm
    this is intended for the demo.

    Inputs:
        image_size    - (rows, cols) size of the images
        u_lat, v_lat  - lattice vectors in pixels (arrays of length 2)
        atom_power    - exponent applied to the fringe pattern (peak sharpness)
        sigma_envelop - standard deviation in pixels of the Gaussian envelope
    Outputs:
        Psi - fftshifted Fourier-space image, amplitude square-root compressed
        psi - real-space lattice image (real, non-negative, not fftshifted)
    """
    image_size = tuple(image_size)
    Psi = np.zeros(image_size, dtype='complex')

    # Reciprocal peak positions in pixels. Each component is scaled by the
    # size of its own axis so the peaks stay parallel to u_lat / v_lat on
    # non-square images. For square images this matches the original scaling.
    half_size = np.asarray(image_size, dtype='float') / 2
    qu = np.round(half_size * u_lat / (np.linalg.norm(u_lat)**2)).astype('int')
    qv = np.round(half_size * v_lat / (np.linalg.norm(v_lat)**2)).astype('int')

    b = 0.125
    for q in (qu, -qu, qv, -qv):
        Psi[np.mod(q[0], image_size[0]), np.mod(q[1], image_size[1])] = b

    # Each +/-q pair contributes 0.25*cos(2*pi*q.r) on top of the 0.5 offset.
    # Raising to atom_power turns the fringe maxima into sharp "atoms".
    psi = np.abs(np.real(0.5 + np.fft.ifft2(Psi) * np.prod(image_size)))**atom_power

    # Gaussian envelope centred on the image, which broadens the Fourier peaks.
    x = np.arange(image_size[0]) - image_size[0]/2
    y = np.arange(image_size[1]) - image_size[1]/2
    psi = psi * np.exp((x[:,None]**2 + y[None,:]**2)/(-2*sigma_envelop**2))

    # Forward FFT back to Fourier space; the factor 64 appears to be a display gain.
    Psi = 64 * np.fft.fft2(np.fft.ifftshift(psi)) / np.prod(image_size)
    Psi = np.fft.fftshift(Psi)
    # Square-root amplitude compression (phase unchanged) for display.
    Psi = (np.abs(Psi)**0.5) * np.exp(1j*np.angle(Psi))

    return Psi, psi


    

def atomic_lattice_polarization(
    image_size = (256,256), 
    u_lat = np.array([8,0], dtype='float'),
    v_lat = np.array([0,8], dtype='float'),
    atom_power = np.array(8,dtype='float'),
    sigma_envelop=64,
    polarization=np.array([0*np.pi,0*np.pi]),
    ):
    """
    Generate a two-sublattice image with a phase-shifted sublattice, plus its FFT.

    Both sublattices use Fourier peaks at +/-qu and +/-qv, where
    qu = (image_size / 2) * u_lat / |u_lat|**2 (component-wise, rounded to
    whole pixels), and likewise for qv.

    The primary sublattice has real peaks of height 0.125. The secondary
    sublattice has peaks of height -0.10 whose phases come from
    ``polarization``: the pair is rotated by 45 degrees,
    p0 = (px + py)/sqrt(2) and p1 = (px - py)/sqrt(2). These phases shift its
    fringes along the u and v directions.

    Each sublattice is turned into |0.5 + fringes|**atom_power. The two are
    summed and multiplied by a Gaussian envelope centred at pixel
    (rows/2, cols/2).

    Inputs:
        image_size    - (rows, cols) size of the images
        u_lat, v_lat  - lattice vectors in pixels (arrays of length 2)
        atom_power    - exponent applied to each fringe pattern (peak sharpness)
        sigma_envelop - standard deviation in pixels of the Gaussian envelope
        polarization  - (px, py) phase shifts in radians for the secondary sublattice
    Outputs:
        Psi - fftshifted Fourier-space image, amplitude square-root compressed
              and phase offset by -pi/2
        psi - real-space lattice image (real, non-negative, not fftshifted)
    """
    image_size = tuple(image_size)
    Psi = np.zeros(image_size, dtype='complex')
    Psi2 = np.zeros(image_size, dtype='complex')

    # Reciprocal peak positions in pixels. Each component is scaled by the
    # size of its own axis so the peaks stay parallel to u_lat / v_lat on
    # non-square images. For square images this matches the original scaling.
    half_size = np.asarray(image_size, dtype='float') / 2
    qu = np.round(half_size * u_lat / (np.linalg.norm(u_lat)**2)).astype('int')
    qv = np.round(half_size * v_lat / (np.linalg.norm(v_lat)**2)).astype('int')

    def _index(q):
        """Wrap a (possibly negative) frequency vector onto array indices."""
        return np.mod(q[0], image_size[0]), np.mod(q[1], image_size[1])

    # Primary sublattice: real peaks at +/-qu and +/-qv.
    b = 0.125
    for q in (qu, -qu, qv, -qv):
        Psi[_index(q)] = b

    # Secondary sublattice: conjugate-phase peak pairs, which keep its
    # fringes real-valued.
    b = -0.10
    p0 = (polarization[0] + polarization[1]) / np.sqrt(2)
    p1 = (polarization[0] - polarization[1]) / np.sqrt(2)
    for q, p in ((qu, p0), (qv, p1)):
        Psi2[_index(q)] = b * np.exp(1j*p)
        Psi2[_index(-q)] = b * np.exp(-1j*p)

    # Sharpen each sublattice's fringes separately, then superpose them.
    psi = np.abs(np.real(0.5 + np.fft.ifft2(Psi) * np.prod(image_size)))**atom_power \
        + np.abs(np.real(0.5 + np.fft.ifft2(Psi2) * np.prod(image_size)))**atom_power

    # Gaussian envelope centred on the image, which broadens the Fourier peaks.
    x = np.arange(image_size[0]) - image_size[0]/2
    y = np.arange(image_size[1]) - image_size[1]/2
    psi = psi * np.exp((x[:,None]**2 + y[None,:]**2)/(-2*sigma_envelop**2))

    # Forward FFT back to Fourier space; the factor 128 appears to be a display gain.
    Psi = 128 * np.fft.fft2(np.fft.ifftshift(psi)) / np.prod(image_size)
    Psi = np.fft.fftshift(Psi)
    # Square-root amplitude compression with a fixed -pi/2 phase offset for display.
    Psi = (np.abs(Psi)**0.5) * np.exp(1j*(np.angle(Psi) - np.pi/2))

    return Psi, psi
    



# Complex plane wave / single spatial frequency

In [117]:
# Canvases are layout areas where things are plotted.
canvas1 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas2 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas1.y_scale.reverse = True
canvas2.y_scale.reverse = True

# Artists do the plotting: real space in canvas1, Fourier space in canvas2.
artist1 = ImageArtist(autoadjust_colorscale = False, rgb=True)
artist2 = ImageArtist(autoadjust_colorscale = False, rgb=True)
canvas1.artists = {'image': artist1}
canvas2.artists = {'image': artist2}

slider1 = widgets.IntSlider(description='kx', min=-32, max=32, value=1, step=1)
slider2 = widgets.IntSlider(description='ky', min=-32, max=32, value=3, step=1)
slider3 = widgets.FloatSlider(description='Phase', min=0, max=2*np.pi, value=0.0, step=np.pi/256)
check1 = widgets.Checkbox(value=True,description='Show Phase')

def slider_change(*args, kx_w=slider1, ky_w=slider2, phase_w=slider3,
                  phase_cb=check1, real_art=artist1, fourier_art=artist2):
    """Redraw this demo (a single complex plane wave) from its own widgets.

    The widgets and artists are bound as keyword defaults when this cell
    runs. Later cells rebind the global names slider1, artist1, ...; without
    this binding, moving these sliders would read the later widgets and
    overwrite the later demo's images.
    """
    kx = kx_w.value
    ky = ky_w.value
    phase = phase_w.value

    Psi, psi = draw_plane_waves(
        freq = np.array([[kx,ky]], dtype='int'),
        vals = np.array([np.exp(1j*phase)],dtype='complex'),
    )

    real_art.image = color_complex(psi.T, color_phase=phase_cb.value, amp_range=[0,1])
    fourier_art.image = color_complex(Psi.T, color_phase=phase_cb.value, amp_range=[0,1])

# Redraw whenever any control changes, and give all labels the same width.
for control in (slider1, slider2, slider3, check1):
    control.observe(slider_change, 'value')
    control.style = {'description_width': '50px'}

# Draw once, then display the canvases and controls side by side.
slider_change()
widgets.HBox([canvas1.widget, canvas2.widget, widgets.VBox([slider1, slider2, slider3, check1])])

HBox(children=(VBox(children=(HBox(children=(HBox(layout=Layout(width='0px')), HTML(value="<p style='font-size…

# Real-valued plane wave / pair of spatial frequencies

In [118]:
canvas1 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas2 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas1.y_scale.reverse = True
canvas2.y_scale.reverse = True

# Artists do the plotting: real space in canvas1, Fourier space in canvas2.
artist1 = ImageArtist(autoadjust_colorscale = False, rgb=True)
artist2 = ImageArtist(autoadjust_colorscale = False, rgb=True)
canvas1.artists = {'image': artist1}
canvas2.artists = {'image': artist2}

slider1 = widgets.IntSlider(description='kx', min=-32, max=32, value=1, step=1)
slider2 = widgets.IntSlider(description='ky', min=-32, max=32, value=3, step=1)
slider3 = widgets.FloatSlider(description='Phase', min=0, max=2*np.pi, value=np.pi*2/3, step=np.pi/256)
check1 = widgets.Checkbox(value=True,description='Show Phase')

def slider_change(*args, kx_w=slider1, ky_w=slider2, phase_w=slider3,
                  phase_cb=check1, real_art=artist1, fourier_art=artist2):
    """Redraw this demo from its own widgets.

    The demo is a real-valued plane wave: a +k / -k pair with conjugate
    phases.

    The widgets and artists are bound as keyword defaults when this cell
    runs. Later cells rebind the global names slider1, artist1, ...; without
    this binding, moving these sliders would read the later widgets and
    overwrite the later demo's images.
    """
    kx = kx_w.value
    ky = ky_w.value
    phase = phase_w.value

    Psi, psi = draw_plane_waves(
        freq = np.array([
            [kx,ky],
            [-kx,-ky],
        ], dtype='int'),
        vals = np.array([
            np.exp(1j*phase),
            np.exp(-1j*phase),
        ],dtype='complex'),
    )

    real_art.image = color_complex(psi.T, color_phase=phase_cb.value, amp_range=[0,1])
    fourier_art.image = color_complex(Psi.T, color_phase=phase_cb.value, amp_range=[0,1])

# Redraw whenever any control changes, and give all labels the same width.
for control in (slider1, slider2, slider3, check1):
    control.observe(slider_change, 'value')
    control.style = {'description_width': '50px'}

# Draw once, then display the canvases and controls side by side.
slider_change()
widgets.HBox([canvas1.widget, canvas2.widget, widgets.VBox([slider1, slider2, slider3, check1])])

HBox(children=(VBox(children=(HBox(children=(HBox(layout=Layout(width='0px')), HTML(value="<p style='font-size…

# 2D crystal lattice

In [151]:
canvas1 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas2 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas1.y_scale.reverse = True
canvas2.y_scale.reverse = True

# Artists do the plotting: real space in canvas1, Fourier space in canvas2.
artist1 = ImageArtist(autoadjust_colorscale = False, rgb=True)
artist2 = ImageArtist(autoadjust_colorscale = False, rgb=True)
canvas1.artists = {'image': artist1}
canvas2.artists = {'image': artist2}

slider1 = widgets.IntSlider(description='u length', min=4, max=16, value=6, step=1)
slider2 =  widgets.FloatSlider(description='u angle', min=0, max=180, value=30, step=5)
slider3 = widgets.IntSlider(description='v length', min=4, max=16, value=6, step=1)
slider4 =  widgets.FloatSlider(description='v angle', min=0, max=180, value=120, step=5)

slider5 =  widgets.FloatLogSlider(description='env. size', min=1, max=4, value=64, step=0.1)
slider6 =  widgets.FloatSlider(description='sharpness', min=2, max=16, value=4, step=0.1)

check1 = widgets.Checkbox(value=False,description='Show Phase')

def slider_change(*args, u_len_w=slider1, u_ang_w=slider2, v_len_w=slider3,
                  v_ang_w=slider4, env_w=slider5, sharp_w=slider6,
                  phase_cb=check1, real_art=artist1, fourier_art=artist2):
    """Redraw the crystal-lattice demo from its own widgets.

    The widgets and artists are bound as keyword defaults when this cell
    runs. Later cells rebind the global names slider1, artist1, ...; without
    this binding, moving these sliders would read the later widgets and
    overwrite the later demo's images.
    """
    u_l = u_len_w.value
    u_t = u_ang_w.value
    v_l = v_len_w.value
    v_t = v_ang_w.value
    sigma_envelop = env_w.value
    atom_power = sharp_w.value

    # Lattice vectors from (length in pixels, angle in degrees).
    u_lat = np.array([u_l*np.cos(u_t*np.pi/180), u_l*np.sin(u_t*np.pi/180)])
    v_lat = np.array([v_l*np.cos(v_t*np.pi/180), v_l*np.sin(v_t*np.pi/180)])

    Psi, psi = atomic_lattice(
        u_lat = u_lat,
        v_lat = v_lat,
        atom_power=atom_power,
        sigma_envelop=sigma_envelop,
    )

    real_art.image = color_complex(psi.T, color_phase=phase_cb.value, amp_range=[0,1])
    fourier_art.image = color_complex(Psi.T, color_phase=phase_cb.value, amp_range=[0,1])

# Redraw whenever any control changes, and give all labels the same width.
for control in (slider1, slider2, slider3, slider4, slider5, slider6, check1):
    control.observe(slider_change, 'value')
    control.style = {'description_width': '80px'}

# Draw once, then display the canvases and controls side by side.
slider_change()
widgets.HBox([canvas1.widget, canvas2.widget, widgets.VBox([slider1, slider2, slider3, slider4, slider5, slider6, check1])])

HBox(children=(VBox(children=(HBox(children=(HBox(layout=Layout(width='0px')), HTML(value="<p style='font-size…

# 2D superlattice with polarization

In [232]:
canvas1 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas2 = Canvas(height=400, width=400, fig_margin={'top': 0, 'bottom': 0, 'left': 0, 'right': 0})
canvas1.y_scale.reverse = True
canvas2.y_scale.reverse = True

# Artists do the plotting: real space in canvas1, Fourier space in canvas2.
artist1 = ImageArtist(autoadjust_colorscale = False, rgb=True)
artist2 = ImageArtist(autoadjust_colorscale = False, rgb=True)
canvas1.artists = {'image': artist1}
canvas2.artists = {'image': artist2}

slider1 = widgets.IntSlider(description='u length', min=4, max=16, value=8, step=1)
slider2 =  widgets.FloatSlider(description='u angle', min=0, max=180, value=30, step=5)
slider3 = widgets.IntSlider(description='v length', min=4, max=16, value=8, step=1)
slider4 =  widgets.FloatSlider(description='v angle', min=0, max=180, value=120, step=5)

slider5 =  widgets.FloatSlider(description='x polar.', min=0, max=1, value=0.0, step=0.01)
slider6 =  widgets.FloatSlider(description='y polar.', min=0, max=1, value=0.2, step=0.01)

check1 = widgets.Checkbox(value=False,description='Show Phase')

def slider_change(*args, u_len_w=slider1, u_ang_w=slider2, v_len_w=slider3,
                  v_ang_w=slider4, px_w=slider5, py_w=slider6,
                  phase_cb=check1, real_art=artist1, fourier_art=artist2):
    """Redraw the polarized-superlattice demo from its own widgets.

    The widgets and artists are bound as keyword defaults when this cell
    runs. Other cells rebind the global names slider1, artist1, ...; without
    this binding, moving these sliders would read those widgets and
    overwrite another demo's images.
    """
    u_l = u_len_w.value
    u_t = u_ang_w.value
    v_l = v_len_w.value
    v_t = v_ang_w.value
    # Slider values in [0, 1] are scaled by 4; atomic_lattice_polarization
    # uses them as phase shifts in radians.
    px = px_w.value * 4
    py = py_w.value * 4

    # Lattice vectors from (length in pixels, angle in degrees).
    u_lat = np.array([u_l*np.cos(u_t*np.pi/180), u_l*np.sin(u_t*np.pi/180)])
    v_lat = np.array([v_l*np.cos(v_t*np.pi/180), v_l*np.sin(v_t*np.pi/180)])

    Psi, psi = atomic_lattice_polarization(
        u_lat = u_lat,
        v_lat = v_lat,
        polarization=np.array([px,py]),
    )

    real_art.image = color_complex(psi.T, color_phase=phase_cb.value, amp_range=[0,1])
    fourier_art.image = color_complex(Psi.T, color_phase=phase_cb.value, amp_range=[0,1])

# Redraw whenever any control changes, and give all labels the same width.
for control in (slider1, slider2, slider3, slider4, slider5, slider6, check1):
    control.observe(slider_change, 'value')
    control.style = {'description_width': '80px'}

# Draw once, then display the canvases and controls side by side.
slider_change()
widgets.HBox([canvas1.widget, canvas2.widget, widgets.VBox([slider1, slider2, slider3, slider4, slider5, slider6, check1])])

HBox(children=(VBox(children=(HBox(children=(HBox(layout=Layout(width='0px')), HTML(value="<p style='font-size…