In [1]:
import tensorflow as tf
from numpy import *

In [2]:
# Performs an fftshift operation on the last two dimensions of a 4-D input tensor
def fftshift_tf(data):
    """Apply ``fftshift`` to the last two axes of a 4-D tensor.

    Intended to match ``numpy.fft.fftshift(data, axes=(2, 3))``: the
    zero-frequency bin (index 0) is rolled to index ``size // 2`` along each
    of axes 2 and 3.

    Args:
        data: rank-4 tensor; axes 2 and 3 are shifted, axes 0 and 1 untouched.

    Returns:
        Tensor with the same shape and dtype as ``data``.
    """
    dims = tf.shape(data)
    # fftshift rolls forward by floor(size / 2). The former formula
    # int((size - 1) / 2) only agreed with this for odd sizes and was one short
    # for even sizes. Each axis now uses its own length, so non-square inputs
    # are shifted correctly too.
    shift_rows = dims[2] // 2
    shift_cols = dims[3] // 2
    output = tf.manip.roll(data, shift=shift_rows, axis=2)
    output = tf.manip.roll(output, shift=shift_cols, axis=3)
    return output

In [3]:
# Performs an ifftshift operation on the last two dimensions of a 4-D input tensor
def ifftshift_tf(data):
    """Apply ``ifftshift`` to the last two axes of a 4-D tensor.

    Intended to match ``numpy.fft.ifftshift(data, axes=(2, 3))``, i.e. the
    inverse of ``fftshift_tf``: a centred zero-frequency bin is moved back to
    index 0 along each of axes 2 and 3.

    Args:
        data: rank-4 tensor; axes 2 and 3 are shifted, axes 0 and 1 untouched.

    Returns:
        Tensor with the same shape and dtype as ``data``.
    """
    dims = tf.shape(data)
    # ifftshift rolls by -floor(size / 2), which modulo size is
    # ceil(size / 2) == (size + 1) // 2 for both odd and even sizes. Integer
    # floor division avoids the float round-trip, and each axis uses its own
    # length (previously axis 2 reused the axis-3 size).
    shift_rows = (dims[2] + 1) // 2
    shift_cols = (dims[3] + 1) // 2
    output = tf.manip.roll(data, shift=shift_rows, axis=2)
    output = tf.manip.roll(output, shift=shift_cols, axis=3)
    return output

In [4]:
# Generates the phase for a lens based on the focal length variable "f". Other referenced variables are global
def generate_phase():
    """Return the lens phase profile as a complex64 tensor.

    Evaluates, elementwise,
        2*pi / Lambda * (sqrt(x^2 + y^2 + f^2) - f),
    i.e. the wavenumber times the extra distance from a point (x, y) in the
    lens plane to the on-axis point at distance ``f``, compared with ``f``.

    Reads the notebook globals ``x_tensor``, ``y_tensor``, ``Lambda`` and the
    focal-length variable ``f``; the result has the broadcast shape of those.
    """
    wavenumber = tf.constant(2*pi, tf.float32) / Lambda
    radius_squared = tf.square(x_tensor) + tf.square(y_tensor)
    path_excess = tf.sqrt(radius_squared + tf.square(f)) - f
    return tf.cast(wavenumber * path_excess, tf.complex64)

In [5]:
# Generates the Fourier space propagator based on the focal length variable "f". Other referenced variables are global
def generate_propagator():
    """Return the free-space propagator for a distance ``f``, in FFT order.

    Builds exp(1j * k_z * f) from the notebook globals ``k_z`` and ``f``, then
    applies ``ifftshift_tf`` so that a frequency grid with zero frequency at
    the centre (assumed layout of ``k_z`` — confirm) lines up with the
    unshifted output ordering of ``tf.fft2d``.
    """
    distance = tf.cast(f, tf.complex64)
    centred_propagator = tf.exp(1j * k_z * distance)
    return ifftshift_tf(centred_propagator)

In [6]:
# Propagate an input E-field distribution along the optical axis using the defined propagator
def propagate(input_field, propagator):
    """Propagate a complex field by multiplying its 2-D spectrum by a propagator.

    Args:
        input_field: complex tensor; ``tf.fft2d`` acts on its last two axes.
        propagator: complex tensor in unshifted FFT order that broadcasts
            against the spectrum of ``input_field``.

    Returns:
        The propagated complex field, ``ifft2d(fft2d(input_field) * propagator)``.
    """
    spectrum = tf.fft2d(input_field)
    propagated_spectrum = spectrum * propagator
    return tf.ifft2d(propagated_spectrum)

In [7]:
# Pass an image through a 4f system
def simulate_4f_system(input_field, kernel):
    """Pass a field through a 4f system with ``kernel`` in the Fourier plane.

    Every ``propagate`` call uses the same propagator (distance ``f``), and
    both lenses share the phase from ``generate_phase``:

        input -f-> lens 1 -f-> kernel -f-> lens 2 -f-> output

    Args:
        input_field: complex field at the input plane.
        kernel: complex transmission applied at the Fourier (filter) plane;
            must broadcast against the propagated field.

    Returns:
        Complex field at the output plane of the 4f optical convolution.
    """
    # The two lenses are identical, so a single transmission mask serves both.
    lens_transmission = tf.exp(-1j * generate_phase())
    propagator = generate_propagator()

    at_lens1 = propagate(input_field, propagator)
    at_filter_plane = propagate(at_lens1 * lens_transmission, propagator)
    at_lens2 = propagate(at_filter_plane * kernel, propagator)
    return propagate(at_lens2 * lens_transmission, propagator)

In [8]:
# Perform convolutions with all the kernels

def convolve_with_all_kernels(image, kernel, padded_size=227):
    """Optically convolve a batch of images with every first-layer kernel.

    Each kernel is zero-padded to ``padded_size`` x ``padded_size``,
    Fourier-transformed, centred with ``fftshift_tf`` and used as the
    Fourier-plane mask of ``simulate_4f_system``. The output intensity is
    summed over colour channels.

    Args:
        image: tensor with channels last, presumably
            (batch, padded_size, padded_size, channels). It must already be
            complex64 because ``tf.fft2d`` is applied to it downstream —
            confirm against the caller.
        kernel: float tensor of static shape
            (kernel_h, kernel_w, channels, num_kernels), e.g. (11, 11, 3, 96).
        padded_size: spatial size the kernels are zero-padded to; must equal
            the simulation grid size (default 227, the previously hard-coded
            value).

    Returns:
        float32 tensor (batch, padded_size, padded_size, num_kernels), flipped
        vertically and horizontally for display.
    """
    kernel_h, kernel_w, _, num_kernels = kernel.get_shape().as_list()

    # Zero-pad each kernel at the bottom and right up to the simulation grid
    # (same layout the previous zero-concatenation produced, without magic sizes).
    kernel = tf.pad(kernel, [[0, padded_size - kernel_h],
                             [0, padded_size - kernel_w],
                             [0, 0],
                             [0, 0]])

    # Reorder to (num_kernels, channels, H, W) so fft2d acts on the spatial axes.
    kernel = tf.transpose(kernel, perm=[3, 2, 0, 1])
    kernel = tf.cast(kernel, tf.complex64)
    kernel = tf.fft2d(kernel)
    kernel = fftshift_tf(kernel)

    # A leading axis of size 1 lets the kernels broadcast across the image
    # batch inside simulate_4f_system. This replaces tiling by the global
    # `batch_size`, which depended on hidden notebook state and materialised a
    # full copy of the kernel stack per image.
    kernel = tf.expand_dims(kernel, axis=0)

    # (batch, H, W, C) -> (batch, 1, C, H, W), then one copy per kernel.
    image = tf.expand_dims(image, axis=1)
    image = tf.transpose(image, perm=[0, 1, 4, 2, 3])
    image = tf.tile(image, multiples=[1, num_kernels, 1, 1, 1])

    # Output intensity for every kernel and colour channel, summed over channels (axis 2).
    intensity = tf.abs(simulate_4f_system(image, kernel)) ** 2
    output = tf.reduce_sum(intensity, axis=2)

    # (batch, num_kernels, H, W) -> (batch, H, W, num_kernels), flipped for display.
    output = tf.transpose(output, perm=[0, 2, 3, 1])
    output = tf.image.flip_left_right(output)
    output = tf.image.flip_up_down(output)

    return tf.cast(output, tf.float32)

In [10]:
# Define parameters
period = 0.25E-5  # Quasi-period: hyperparameter fixing the pixel pitch (same length units as the wavelengths) and hence the system width
# Focal length, a trainable variable. NOTE: the second positional argument of
# tf.Variable is `trainable`, not the dtype, so the dtype is passed by keyword.
f = tf.Variable(0.3E-2, dtype=tf.float32)
multiplier = 1  # Must be odd so the n x n grid has a single centre pixel
assert multiplier % 2 == 1, "multiplier must be an odd number"
width_pixels = multiplier * 227  # Width of the image in pixels
n_sub = (width_pixels + 1) // 2  # Samples from the centre to the edge, inclusive
A = n_sub * period  # Set the metasurface radius
n = 2 * n_sub - 1  # Number of pixels along one dimension (full image is n x n)

# All grids below share the layout (kernel, colour channel, row, column).
NUM_KERNELS = 96  # First-layer kernels simulated in parallel
Lambdas = [633E-9, 532E-9, 442E-9]  # One wavelength per colour channel
NUM_CHANNELS = len(Lambdas)

# Spatial grid: n samples spaced exactly `period` apart, centred on zero.
# (The former linspace(0, n_sub * period, n_sub) spaced samples by
# n_sub * period / (n_sub - 1), about 0.9% more than `period`, so the spatial
# grid disagreed with the FFT frequency grid defined below.)
xlist = period * arange(-(n_sub - 1), n_sub)
assert len(xlist) == n
xx_2d, yy_2d = meshgrid(xlist, xlist)  # xx_2d[i, j] = xlist[j], yy_2d[i, j] = xlist[i]
xx = tile(xx_2d[newaxis, newaxis], (NUM_KERNELS, NUM_CHANNELS, 1, 1))
yy = tile(yy_2d[newaxis, newaxis], (NUM_KERNELS, NUM_CHANNELS, 1, 1))
x_tensor = tf.constant(xx, tf.float32)
y_tensor = tf.constant(yy, tf.float32)

# Wavelength at every grid point as a numpy array (NUM_KERNELS, NUM_CHANNELS, n, n).
# The spatial size now follows n instead of a hard-coded 227.
channel = ones((NUM_KERNELS, 1, n, n))
Lambda = concatenate([wavelength * channel for wavelength in Lambdas], axis=1)

# Reciprocal-space grid consistent with the spatial sampling: for odd n this
# equals 2 * pi * fftshift(fftfreq(n, d=period)), zero frequency at the centre.
k_xlist = 2 * pi * arange(-(n_sub - 1), n_sub) / (n * period)
k_x, k_y = meshgrid(k_xlist, k_xlist)  # k_x[i, j] = k_xlist[j], k_y[i, j] = k_xlist[i]

# Axial wavenumber k_z = sqrt(k^2 - k_x^2 - k_y^2) for each wavelength. Adding
# 0j makes numpy return imaginary roots (evanescent components) instead of NaN.
k_z_values = zeros((NUM_KERNELS, NUM_CHANNELS, n, n)) + 0j
for i, wavelength in enumerate(Lambdas):
    k = 2 * pi / wavelength
    k_z_values[:, i, :, :] = sqrt(k ** 2 - k_x ** 2 - k_y ** 2 + 0j)

k_z = tf.constant(k_z_values, tf.complex64)

In [11]:
# Write each simulation grid to disk; numpy.save appends ".npy" to each name.
for grid_name, grid in (('xx', xx), ('yy', yy), ('Lambda', Lambda), ('k_z_values', k_z_values)):
    save(grid_name, grid)
del grid_name, grid