# Spectral Cross Correlation

### Import modules

In [None]:
# Load necessary modules for code
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.patches import Circle
from timeit import default_timer as timer

### Spectral Cross-Correlation with FFT

In [None]:
# A basic function with fast Fourier transforms to compute spectral cross-correlation

# This originally used the standard np.fft.fft and ifft, but the results weren't nearly as good
# as using the real FFT in multiple dimensions. The standard fft algorithms don't properly
# account for multi-dimensional arrays, and the shift into the spectral domain normally
# introduces imaginary values, which can become problematic when trying to understand things
# physically, so it's important to just use the real components.

def spec_cc(p,t):
    """Return the circular cross-correlation of two real arrays, computed via FFT.

    The result is the inverse real FFT of ``conj(FFT(p)) * FFT(t)``, so
    ``scores[k]`` is the sum over ``n`` of ``p[n] * t[n + k]`` with indices
    wrapping around. Callers that want a linear (non-wrapping) correlation
    must zero-pad both inputs beforehand.

    Parameters
    ----------
    p : array_like
        Real-valued pattern array.
    t : array_like
        Real-valued template array with the same shape as ``p``.

    Returns
    -------
    numpy.ndarray
        Real score array with the same shape as ``t``.

    Raises
    ------
    ValueError
        If ``p`` and ``t`` do not have the same shape.
    """
    p = np.asarray(p)
    t = np.asarray(t)
    # The spectra are multiplied element-wise, so the shapes must match
    # exactly; otherwise NumPy could silently broadcast to a wrong answer.
    if p.shape != t.shape:
        raise ValueError(
            "p and t must have the same shape, got %s and %s" % (p.shape, t.shape)
        )
    # Transform over every axis, naming the axes explicitly so the forward
    # and inverse transforms are guaranteed to agree.
    axes = tuple(range(t.ndim))
    fft_t = np.fft.rfftn(t, axes=axes)
    fft_p = np.fft.rfftn(p, axes=axes)
    # Multiply the template spectrum by the complex conjugate of the pattern spectrum
    product = np.conj(fft_p) * fft_t
    # The real FFT only stores half of the last axis, so the original length
    # cannot be recovered from the spectrum alone; without `s` an odd-length
    # last axis would come back one sample short.
    scores = np.fft.irfftn(product, s=t.shape, axes=axes)
    # Return this scores matrix
    return scores

### Load Images

In [None]:
# Define a function to take an image filename, convert it to greyscale
# and output as a 2D array (matrix)
def im_load(file):
    """Load an image and return it as a greyscale array.

    Parameters
    ----------
    file
        Filename (or anything else ``PIL.Image.open`` accepts) of the image.

    Returns
    -------
    numpy.ndarray
        The image converted to PIL mode ``'L'`` (greyscale) as an array.
    """
    # Open inside a context manager so the underlying file handle is
    # released as soon as the pixel data has been copied into the array.
    with Image.open(file) as image:
        # convert('L') is the PIL conversion to greyscale
        im = np.array(image.convert('L'))
    # Return image array
    return im

### 2D zero-padding

In [None]:
# Define a function to normalize and pad the input arrays in one go.

# Rather than taking the array energy it is quicker and computationally
# easier to just mean shift the arrays (shift values inside array towards
# the mean value of the array) as the normalization. This is done and then
# the normalized array is padded straight after

def mean_padz(p,t):
    """Mean-shift two 2D arrays and zero-pad them onto a common canvas.

    Each array has its own mean subtracted, so values close to the mean end
    up near zero and contribute less. Both results are then placed in the
    top-left corner of zero-filled arrays of identical shape. Each canvas
    dimension is the smallest power of two that is at least twice the larger
    of the two inputs along that dimension.

    Parameters
    ----------
    p, t : numpy.ndarray
        Pattern and template arrays, indexed as (row, column).

    Returns
    -------
    tuple of numpy.ndarray
        ``(p_pad, t_pad)``, the normalized and padded pattern and template.
    """
    # Normalize first: subtract each array's own mean.
    shifted = [image - np.mean(image) for image in (p, t)]

    # Largest extent of the two inputs along each axis.
    tallest = max(p.shape[0], t.shape[0])
    widest = max(p.shape[1], t.shape[1])

    # Round twice the largest extent up to the next power of two.
    width = 2**int(np.ceil(np.log2(2*widest)))
    height = 2**int(np.ceil(np.log2(2*tallest)))

    # Drop each mean-shifted array into the top-left corner of its own
    # zero canvas. The padding only equalises the sizes; it does not
    # surround the image with a border on every side.
    padded = []
    for original, centred in zip((p, t), shifted):
        canvas = np.zeros((height, width))
        canvas[:original.shape[0], :original.shape[1]] = centred
        padded.append(canvas)

    p_pad, t_pad = padded
    return p_pad, t_pad

### 2D Best Lag

In [None]:
# This is the same as with the spatial version: take a score array and the pattern

def b_lag(scores,pattern):
    """Locate the best lag in a 2D score array, offset to the pattern's centre.

    Parameters
    ----------
    scores : numpy.ndarray
        2D array of correlation scores.
    pattern : numpy.ndarray
        The pattern array; only its first two dimensions are used.

    Returns
    -------
    tuple
        ``(xmid, ymid)``: the column and row of the highest score, each
        shifted by half the pattern's width / height so that a marker
        drawn there sits roughly in the middle of the pattern.
    """
    # Position of the highest score, as (row, column)
    peak = np.unravel_index(np.argmax(scores, axis=None), scores.shape)
    row, col = peak
    # Half the pattern's extent along each axis
    half_height = pattern.shape[0]/2
    half_width = pattern.shape[1]/2
    return col + half_width, row + half_height

## The final function

In [None]:
# Combine written functions to find the rocket man in the
# Where's Wally puzzle

def rocket_man(p_filename,t_filename):
    """Find the pattern image inside the template image and mark it on a plot.

    Both images are loaded in greyscale, mean-shifted and zero-padded, and
    cross-correlated in the spectral domain. The location of the highest
    score (offset to the centre of the pattern) is printed, drawn as a black
    circle over the coloured template, and returned.

    Parameters
    ----------
    p_filename
        Filename of the pattern image (the thing to find).
    t_filename
        Filename of the template image (the puzzle to search in).

    Returns
    -------
    tuple
        ``(x_lag, y_lag)``, the best-lag position as returned by ``b_lag``.
    """
    # Load images in greyscale as arrays
    pattern_gs = im_load(p_filename)
    template_gs = im_load(t_filename)
    # Load the coloured template for the final visual check; the context
    # manager releases the file handle once the pixels have been copied.
    with Image.open(t_filename) as colour_image:
        template_col = np.array(colour_image)
    # Normalize and zero-pad the arrays
    padded_pattern, padded_template = mean_padz(pattern_gs, template_gs)
    # Compute spectral cross-correlation with FFTs
    score_map = spec_cc(padded_pattern, padded_template)
    # Find location of highest score
    x_lag, y_lag = b_lag(score_map, pattern_gs)
    print("The rocket man is at position:",(x_lag,y_lag))

    #----Plot----#

    # Create a circle at this point. It is named `marker` so that it does
    # not shadow this function's own name.
    marker = Circle((x_lag, y_lag), radius=20, color='black')

    _fig, ax = plt.subplots(1, figsize=(38,12))
    ax.imshow(template_col)
    ax.add_patch(marker)
    # plt.show() takes no positional arguments; passing the figure raises
    # a TypeError on current Matplotlib versions.
    plt.show()
    # Return best lag
    return x_lag, y_lag
    

## Run the function

In [None]:
# Time a full run of the search: load the images, correlate them and plot.
start = timer()
rocket_man(
    "/Users/justi/Desktop/wallypuzzle_rocket_man.png",
    "/Users/justi/Desktop/wallypuzzle_png.png",
)
end = timer()
# Elapsed wall-clock time in seconds
print("Runtime:", end - start, "s")