In [None]:
import os
os.environ['CUPY_ACCELERATORS'] = 'cutensor'
import numpy as np
import cupy as cp
from scipy import fft
import zarr
from matplotlib import pyplot, colors, cm
from PIL import Image
from IPython.display import display
from tqdm.notebook import tqdm, trange

# Handles to CuPy's default device memory pool and pinned (page-locked host) memory pool.
# Later cells call mempool.free_all_blocks() after deleting GPU arrays to release cached blocks.
# pinned_mempool is not referenced by any other cell visible in this notebook.
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

In [None]:
# Scene name; a later cell opens the sampled light field at 'data/' + scene + '_sampled.zarr'.
scene = 'car'

# Focus distances of the observer (presumably in mm, like the plane distances printed below — TODO confirm).
focus_distances = np.array([300.0, 600.0, 2000.0], dtype=np.float32)
# Number of focus distances; used as the leading axis of most per-focus arrays.
n_f = len(focus_distances)

In [None]:
# Retina plane: distance z_r, n_r pixels spanning size_r, hence pixel pitch size_r / n_r.
z_r = 1.0
n_r = 1024
size_r = 2.0
pitch_r = size_r / n_r

# Field of view: fov_factor is the retina size per unit distance; fov is the full angle in degrees.
fov_factor = size_r / z_r
fov = 2 * np.arctan(fov_factor / 2) * 180/np.pi

# Pupil plane: located at z = 0 with aperture size_p.
z_p = 0.0
size_p = 8.0

In [None]:
# Sampled light field: a two-plane (U, V) parameterization. Each plane has a distance z,
# a sample count n, a physical extent size and the resulting pitch = size / n.

# Plane U
z_u = 0.0
n_u = 20
size_u = 10.0
pitch_u = size_u / n_u

# Plane V
z_v = 265.0
n_v = 1080
size_v = 540.0
pitch_v = size_v / n_v

# View chunks: plane V is split into n_chunk chunks of chunk_v samples each.
# n_v must be a multiple of chunk_v, because lattice_v is later reshaped to (n_chunk, chunk_v).
chunk_v = 40
n_chunk = n_v//chunk_v

# Summary of the sampling configuration.
print('[Sampled Light Field]')
print('Plane U:    {:4} × {:4.2f} mm = {:5} mm    z_u = {:5} mm'.format(n_u, pitch_u, size_u, z_u))
print('Plane V:    {:4} × {:4.2f} mm = {:5} mm    z_v = {:5} mm'.format(n_v, pitch_v, size_v, z_v))
print('          ({:2}×{:2})'.format(n_chunk, chunk_v))

# Estimates derived from fov_factor, the plane distances and the pitches
# (presumably the number of resolvable samples across the field of view — TODO confirm).
print('\n[Retina Estimate with FoV = {:4.1f}°]'.format(fov))
print('Focus at U: {:7.2f}'.format(fov_factor * z_u / pitch_u))
print('Focus at V: {:7.2f}'.format(fov_factor * z_v / pitch_v))
print('Focus at ∞: {:7.2f}'.format(fov_factor * abs(z_u - z_v) / max(pitch_u, pitch_v)))

# Total ray count: (n_u * n_v) rays per dimension, squared for the two dimensions.
print('\nNumber of rays: {:,}'.format((n_u * n_v)**2))

In [None]:
# Display light field: a two-plane (A, B) parameterization. Each plane has a distance z,
# an element count n, a physical extent size and the resulting pitch = size / n.

# Plane A
z_a = 8.0
n_a = 12*4 # 48
size_a = 24.0
pitch_a = size_a / n_a

# Plane B
z_b = 136.0
n_b = 140*4 # 560
size_b = 280.0
pitch_b = size_b / n_b

# Summary of the display configuration.
print('[Display Light Field]')
print('Plane A:    {:4} × {:4.2f} mm = {:5} mm    z_a = {:5} mm'.format(n_a, pitch_a, size_a, z_a))
print('Plane B:    {:4} × {:4.2f} mm = {:5} mm    z_b = {:5} mm'.format(n_b, pitch_b, size_b, z_b))

# Same estimates as printed for the sampled light field, evaluated for planes A and B.
print('\n[Retina Estimate with FoV = {:4.1f}°]'.format(fov))
print('Focus at A: {:7.2f}'.format(fov_factor * z_a / pitch_a))
print('Focus at B: {:7.2f}'.format(fov_factor * z_b / pitch_b))
print('Focus at ∞: {:7.2f}'.format(fov_factor * abs(z_a - z_b) / max(pitch_a, pitch_b)))

# Total element count: (n_a * n_b) element pairs per dimension, squared for the two dimensions.
print('\nNumber of elements: {:,}'.format((n_a * n_b)**2))

In [None]:
def bytesize_string(nbytes):
    """Format a byte count with the largest binary unit that keeps the value at or above 1.

    Args:
        nbytes: Number of bytes (e.g. ``ndarray.nbytes``).

    Returns:
        A string such as ``'1.5 KiB'``. Counts below one byte (including 0) are
        reported in ``'B'``; counts beyond the largest unit are reported in ``'YiB'``.
    """
    unit = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']
    # The byte count expressed in every unit: nbytes / 1024**k for the k-th unit.
    size = np.array([1 / 2**(10 * k) for k in range(len(unit))]) * nbytes
    below_one = size < 1.0
    if below_one[0]:
        # Less than one byte: without this guard the index would wrap around to 'YiB'.
        order_of_magnitude = 0
    elif not np.any(below_one):
        # At least one unit of the largest prefix: stay on the last entry explicitly.
        order_of_magnitude = len(unit) - 1
    else:
        # The last unit before the value drops below 1.
        order_of_magnitude = np.argmax(below_one) - 1
    return '{} {}'.format(size[order_of_magnitude], unit[order_of_magnitude])

def array_stats(a):
    """Print a one-line summary of an array: shape, dtype, memory footprint, minimum, maximum and average."""
    summary = (a.shape, a.dtype, bytesize_string(a.nbytes), np.amin(a), np.amax(a), np.average(a))
    print('{} × {} = {} | min: {}, max: {}, avg: {}'.format(*summary))
    
def display_image(values):
    """Show an array as an 8-bit image in the notebook.

    The first two axes are swapped and the resulting first axis is flipped before display
    (presumably so that axis 0 of `values` runs left-to-right and axis 1 bottom-to-top — TODO confirm).
    Values are scaled by 255 and cast to uint8, so the input is expected to lie in [0, 1].
    """
    oriented = np.flip(np.swapaxes(values, 0, 1), axis=0)
    pixels = (oriented * 255.0).astype(np.uint8)
    display(Image.fromarray(pixels))

def spectrum(signal):
    """Return the log-magnitude spectrum, log(1 + |FFT2|), of an image.

    The signal is first quantized to 8 bits (scaled by 255) and converted to a
    single-channel PIL image (mode 'L'); the 2D FFT is shifted so that the zero
    frequency sits at the centre.
    """
    luminance = Image.fromarray((signal * 255.0).astype(np.uint8)).convert('L')
    frequencies = fft.fftshift(fft.fft2(np.array(luminance)))
    return np.log(1.0 + np.abs(frequencies))

def sampling_lattice(n, pitch=1.0):
    """Return `n` sample positions spaced by `pitch` and centred on zero.

    Args:
        n: Number of samples.
        pitch: Distance between consecutive samples.

    Returns:
        A 1D array whose k-th entry is (k - (n - 1)/2) * pitch; built from float32 indices.
    """
    indices = np.arange(n, dtype=np.float32)
    centre = (n - 1)/2
    return (indices - centre) * pitch

def dot(a, b):
    """Dot product along the last axis of `a` and `b`, broadcasting over the leading axes."""
    return np.einsum('...i,...i->...', a, b)

In [None]:
#        from eye: (x_p, x_r, z_p, z_r, z_out/z_f)
def intersect_plane(x_1, x_2, z_1, z_2, z_out):
    """Coordinate at which the ray through (x_1, z_1) and (x_2, z_2) crosses the plane z = z_out.

    This is the linear interpolation (or extrapolation) of x along z between the
    two given points. The x arguments may be arrays and broadcast against each other.
    """
    weight_1 = (z_2 - z_out)/(z_2 - z_1)
    weight_2 = (z_1 - z_out)/(z_1 - z_2)
    return weight_1 * x_1 + weight_2 * x_2

def intersect_retina(x_1, x_2, z_1, z_2, z_r, z_f):
    """Retina coordinate reached by the ray through (x_1, z_1) and (x_2, z_2).

    Args:
        x_1, x_2: Ray coordinates on the planes z_1 and z_2 (arrays broadcast).
        z_1, z_2: Distances of the two planes.
        z_r: Retina distance; scales the result.
        z_f: Focus distance (scalar or array, broadcast against the coordinates).
    """
    weight_1 = (z_2/z_f - 1)/(z_2 - z_1)
    weight_2 = (z_1/z_f - 1)/(z_1 - z_2)
    return z_r * (weight_1 * x_1 + weight_2 * x_2)



def compute_element_size_p(pitch_a, pitch_b, z_a, z_b):
    """Size of a display element (one cell of plane A paired with one cell of plane B) on the plane z = 0.

    The result is the sum of the absolute contributions of both pitches, each
    weighted like the corresponding coordinate in intersect_plane with z_out = 0.
    """
    extent_from_a = np.abs(z_b/(z_b - z_a) * pitch_a)
    extent_from_b = np.abs(z_a/(z_a - z_b) * pitch_b)
    return extent_from_a + extent_from_b

def compute_element_size_r(pitch_a, pitch_b, z_a, z_b, z_f):
    """Size of a display element on the retina for focus distance `z_f`, per unit retina distance.

    The result is the sum of the absolute contributions of both pitches, each
    weighted like the corresponding coordinate in intersect_retina (without the
    z_r factor). `z_f` may be an array, giving one size per focus distance.
    """
    extent_from_a = np.abs((z_b/z_f - 1)/(z_b - z_a) * pitch_a)
    extent_from_b = np.abs((z_a/z_f - 1)/(z_a - z_b) * pitch_b)
    return extent_from_a + extent_from_b



def compute_phi_GPU(x_r, x_a, x_b, pitch_a, pitch_b, size_p, z_a, z_b, z_f):
    """Fraction of the pupil through which a retina point sees both cells of a display element.

    The cell of width pitch_a centred at x_a on plane A and the cell of width pitch_b
    centred at x_b on plane B are projected onto the pupil plane. The two projected
    intervals are intersected with each other and with the pupil [-size_p/2, size_p/2];
    the length of what remains is divided by size_p.

    Args:
        x_r: CuPy array of retina coordinates (the caller passes the retina coordinates
            of the light field samples for one focus distance).
        x_a, x_b: Centre coordinates of the element on planes A and B (the caller passes
            zero-dimensional CuPy arrays).
        pitch_a, pitch_b: Cell widths on planes A and B.
        size_p: Pupil size.
        z_a, z_b: Distances of planes A and B.
        z_f: Focus distance.

    Returns:
        CuPy array with the broadcast shape of the inputs, holding values in [0, 1].

    NOTE(review): the projection x_p = (x - z * x_r) / (1 - z/z_f) matches intersect_retina
    solved for the pupil coordinate with z_r = 1 and the pupil at z = 0 — confirm that the
    function is only used with those values.
    """
    # Ratios for projection of plane A and plane B to the pupil plane.
    ratio_a = cp.reciprocal(1.0 - z_a/z_f, dtype=cp.float32)
    ratio_b = cp.reciprocal(1.0 - z_b/z_f, dtype=cp.float32)
    # Project display samples to the pupil plane.
    x_p_from_a = ratio_a * (x_a - z_a * x_r)
    x_p_from_b = ratio_b * (x_b - z_b * x_r)
    # Projected cell widths (full widths; half of each is subtracted from the centre below).
    proj_size_a = cp.abs(ratio_a * pitch_a)
    proj_size_b = cp.abs(ratio_b * pitch_b)
    # Projection A: the lower bound overwrites the centre buffer, the upper bound is a new array.
    lower_bound_a = cp.subtract(x_p_from_a, proj_size_a/2, out=x_p_from_a)
    upper_bound_a = cp.add(lower_bound_a, proj_size_a)
    # Projection B: same buffer reuse as for projection A.
    lower_bound_b = cp.subtract(x_p_from_b, proj_size_b/2, out=x_p_from_b)
    upper_bound_b = cp.add(lower_bound_b, proj_size_b)
    # Intersect the projections: clamp both bounds of A into the interval of B (in place).
    lower_bound = cp.clip(lower_bound_a, lower_bound_b, upper_bound_b, out=lower_bound_a)
    upper_bound = cp.clip(upper_bound_a, lower_bound_b, upper_bound_b, out=upper_bound_a)
    # Intersect with pupil.
    lower_bound = cp.clip(lower_bound, -size_p/2, size_p/2, out=lower_bound)
    upper_bound = cp.clip(upper_bound, -size_p/2, size_p/2, out=upper_bound)
    # Return the intersection length, normalized by the pupil size.
    return cp.divide(cp.subtract(upper_bound, lower_bound, out=upper_bound), size_p, out=upper_bound)

In [None]:
# Zero-centred sample positions of the retina pixels.
lattice_r = sampling_lattice(n_r, pitch_r)

# Zero-centred positions of the display cells on planes A and B.
lattice_a = sampling_lattice(n_a, pitch_a)
lattice_b = sampling_lattice(n_b, pitch_b)

# Zero-centred positions of the light field samples on planes U and V;
# plane V is stored chunk-wise with shape (n_chunk, chunk_v).
lattice_u = sampling_lattice(n_u, pitch_u)
lattice_v = sampling_lattice(n_v, pitch_v).reshape(n_chunk, chunk_v)

In [None]:
%%time
# Select the display elements (pairs of a plane A cell and a plane B cell) that can reach
# the eye: they must touch the pupil and touch the retina for at least one focus distance.
print('Computing active elements...')

# Display element coordinates on the pupil, retina and sampling planes.
# Shapes: (n_a, n_b) for the planes, (n_f, n_a, n_b) for the retina.
element_coord_p = intersect_plane( lattice_a.reshape(   n_a, 1), lattice_b.reshape(   1, n_b), z_a, z_b, z_p)
element_coord_r = intersect_retina(lattice_a.reshape(1, n_a, 1), lattice_b.reshape(1, 1, n_b), z_a, z_b, z_r, focus_distances.reshape(n_f, 1, 1))
element_coord_u = intersect_plane( lattice_a.reshape(   n_a, 1), lattice_b.reshape(   1, n_b), z_a, z_b, z_u)
element_coord_v = intersect_plane( lattice_a.reshape(   n_a, 1), lattice_b.reshape(   1, n_b), z_a, z_b, z_v)

# Display element projection size on the pupil and retina.
# element_size_p is a scalar; element_size_r has one entry per focus distance.
element_size_p = compute_element_size_p(pitch_a, pitch_b, z_a, z_b)
element_size_r = compute_element_size_r(pitch_a, pitch_b, z_a, z_b, focus_distances)

# Display element incidence on the pupil and retina.
# An element is incident when its centre lies within half the target size plus half its own size.
element_incidence_on_pupil = np.abs(element_coord_p) <= (size_p + element_size_p)/2
element_incidence_on_retina = np.abs(element_coord_r) <= (size_r + element_size_r.reshape(n_f, 1, 1))/2

# Display element incidence on the eye.
element_incidence_on_eye = np.logical_and(element_incidence_on_pupil, np.any(element_incidence_on_retina, axis=0))

# Active display elements: their cell indices on planes A and B, and their count n_ab.
element_active_a, element_active_b = np.nonzero(element_incidence_on_eye)
n_ab = len(element_active_a)

# Active display element coordinates.
# From here on the coordinate arrays are indexed by active element: shape (n_ab,) or (n_f, n_ab).
element_coord_p = element_coord_p[   element_active_a, element_active_b]
element_coord_r = element_coord_r[:, element_active_a, element_active_b]
element_coord_u = element_coord_u[   element_active_a, element_active_b]
element_coord_v = element_coord_v[   element_active_a, element_active_b]
element_coord_a = lattice_a[element_active_a]
element_coord_b = lattice_b[element_active_b]

print('\nActive display elements:', n_ab)

In [None]:
%%time
# Build boolean incidence tables between light field samples, display elements and retina
# pixels, together with the per-row counts used later for normalization.
print('Computing incidence and counts...')

# Light field sample coordinates on planes A and B, shape (n_u, n_chunk, chunk_v).
sample_coord_a = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_a)
sample_coord_b = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_b)

# Light field sample incidence on display elements.
# A sample is incident when it falls inside the half-open cell (-pitch/2, pitch/2] around
# the element centre on both planes. Shape: (n_ab, n_u, n_chunk, chunk_v).
coord_a = sample_coord_a.reshape(1, n_u, n_chunk, chunk_v) - element_coord_a.reshape(n_ab, 1, 1, 1)
coord_b = sample_coord_b.reshape(1, n_u, n_chunk, chunk_v) - element_coord_b.reshape(n_ab, 1, 1, 1)
sample_incidence_per_element = np.logical_and(np.logical_and(coord_a > -pitch_a/2, coord_a <= pitch_a/2), np.logical_and(coord_b > -pitch_b/2, coord_b <= pitch_b/2))

# Light field sample count per display element.
# Stored as uint16: counts must stay below 65536.
sample_count_per_element = np.count_nonzero(sample_incidence_per_element, axis=(1, 2, 3)).astype(np.uint16)

print('\nLight field sample incidence and count per display element:')
array_stats(sample_incidence_per_element)
array_stats(sample_count_per_element)



# Light field sample coordinates on the retina, shape (n_f, n_u, n_chunk, chunk_v).
sample_coord_r = intersect_retina(lattice_u.reshape(1, n_u, 1, 1), lattice_v.reshape(1, 1, n_chunk, chunk_v), z_u, z_v, z_r, focus_distances.reshape(n_f, 1, 1, 1))

# Light field sample incidence on retina pixels.
# Half-open pixel interval (-pitch_r/2, pitch_r/2]. Shape: (n_f, n_r, n_u, n_chunk, chunk_v).
coord_r = sample_coord_r.reshape(n_f, 1, n_u, n_chunk, chunk_v) - lattice_r.reshape(1, n_r, 1, 1, 1)
sample_incidence_per_pixel = np.logical_and(coord_r > -pitch_r/2, coord_r <= pitch_r/2)

# Light field sample count per retina pixel.
# Stored as uint16: counts must stay below 65536.
sample_count_per_pixel = np.count_nonzero(sample_incidence_per_pixel, axis=(2, 3, 4)).astype(np.uint16)

print('\nLight field sample incidence and count per retina pixel:')
array_stats(sample_incidence_per_pixel)
array_stats(sample_count_per_pixel)



# Display element incidence on retina pixels.
# An element is incident when its retina coordinate is closer to the pixel centre than half
# the pixel pitch plus half the element's retina size. Shape: (n_f, n_r, n_ab).
element_incidence_per_pixel = np.abs(lattice_r.reshape(1, n_r, 1) - element_coord_r.reshape(n_f, 1, n_ab)) < (pitch_r + element_size_r.reshape(n_f, 1, 1))/2

# Display element count per retina pixel.
# Stored as uint16: counts must stay below 65536.
element_count_per_pixel = np.count_nonzero(element_incidence_per_pixel, axis=2).astype(np.uint16)

print('\nDisplay element incidence and count per retina pixel:')
array_stats(element_incidence_per_pixel)
array_stats(element_count_per_pixel)

In [None]:
%%time
# Partition the work into groups of consecutive light field chunks, so that every display
# element and every retina pixel is processed with a window of chunks that covers its samples.
print('Computing assignments to chunk groups...')

# Light field chunk incidence on display elements, shape (n_ab, n_chunk).
chunk_incidence_per_element = np.any(sample_incidence_per_element, axis=(1, 3))

# Light field chunk incidence on retina pixels, shape (n_f, n_r, n_chunk).
chunk_incidence_per_pixel = np.any(sample_incidence_per_pixel, axis=(2, 4))

# Maximum amount of light field chunks per display element or retina pixel.
chunk_group_element = np.amax(np.count_nonzero(chunk_incidence_per_element, axis=1))
chunk_group_pixel = np.amax(np.count_nonzero(chunk_incidence_per_pixel, axis=2))
chunk_group = max(chunk_group_element, chunk_group_pixel)

# Number of light field groups: the number of windows of chunk_group consecutive chunks.
n_group = n_chunk - (chunk_group - 1)

print('\nLight field group size:', chunk_group)
print('Number of light field groups:', n_group)



# Display element assignment to light field groups.
# Each element goes to the group that starts at its first incident chunk, clamped to the last
# group. np.argmax returns 0 for elements without any incident chunk, so those land in group 0.
# Shape: (n_group, n_ab).
element_assignment_per_group = np.minimum(np.argmax(chunk_incidence_per_element, axis=1), n_group-1).reshape(1, n_ab) == np.arange(n_group).reshape(n_group, 1)

# Display element count per light field group.
element_count_per_group = np.count_nonzero(element_assignment_per_group, axis=1).astype(np.uint16)

# Display element indices per light field group.
element_indices_per_group = [np.nonzero(group)[0] for group in element_assignment_per_group]

print('\nDisplay element assignment and count per light field group:')
array_stats(element_assignment_per_group)
array_stats(element_count_per_group)



# Retina pixel assignment to light field groups.
# Same rule as for the display elements, evaluated per focus distance. Shape: (n_f, n_group, n_r).
pixel_assignment_per_group = np.minimum(np.argmax(chunk_incidence_per_pixel, axis=2), n_group-1).reshape(n_f, 1, n_r) == np.arange(n_group).reshape(1, n_group, 1)

# Retina pixel count per light field group.
pixel_count_per_group = np.count_nonzero(pixel_assignment_per_group, axis=2).astype(np.uint16)

# Retina pixel indices per light field group, indexed as [focus][group].
pixel_indices_per_group = [[np.nonzero(group)[0] for group in focus] for focus in pixel_assignment_per_group]

print('\nRetina pixel assignment and count per light field group:')
array_stats(pixel_assignment_per_group)
array_stats(pixel_count_per_group)

In [None]:
%%time
# Evaluate compute_phi_GPU for every (focus distance, active display element, light field
# sample) combination and collect the results on the host.
print('Computing φ values...')

# Host buffer for the results, shape (n_f, n_ab, n_u, n_chunk, chunk_v).
phi_values = np.zeros((n_f, n_ab, n_u, n_chunk, chunk_v), dtype=np.float32)

with tqdm(total=n_f * n_ab) as pbar:
    # Load coordinates into the GPU.
    coord_r_GPU = cp.array(sample_coord_r)
    coord_a_GPU = cp.array(element_coord_a)
    coord_b_GPU = cp.array(element_coord_b)

    # For each focus distance...
    for index_f, z_f in enumerate(focus_distances):
        # For each display element...
        for element_index in range(n_ab):
            # Compute the phi value for each light field sample.
            phi_GPU = compute_phi_GPU(coord_r_GPU[index_f], coord_a_GPU[element_index], coord_b_GPU[element_index], pitch_a, pitch_b, size_p, z_a, z_b, z_f)
            # Copy the result back to the host, one element at a time.
            phi_values[index_f, element_index] = phi_GPU.get()
            # Update progress bar.
            pbar.update()

    # Free GPU memory
    # phi_GPU only exists if the loops above ran at least once (n_f > 0 and n_ab > 0).
    del coord_r_GPU
    del coord_a_GPU
    del coord_b_GPU
    del phi_GPU
    mempool.free_all_blocks()

print('\nφ values:')
array_stats(phi_values)

In [None]:
%%time
print('Computing linear maps...')

reference_linear_map = sample_incidence_per_pixel.astype(np.float32) / np.maximum(sample_count_per_pixel, 1).reshape(n_f, n_r, 1, 1, 1)
naive_linear_map = sample_incidence_per_element.astype(np.float32) / np.maximum(sample_count_per_element, 1).reshape(n_ab, 1, 1, 1)

simulation_linear_map = np.zeros((n_f, n_r, n_ab), dtype=np.float32)
projection_linear_map = np.zeros((n_f, n_ab, n_u, n_chunk, chunk_v), dtype=np.float32)
autocorrelation_linear_map = np.zeros((n_f+1, n_ab, n_ab), dtype=np.float32)

with tqdm(total=n_f * n_r) as pbar:
    # For each focus distance...
    for index_f, z_f in enumerate(focus_distances):
        # For each retina pixel...
        for index_r in range(n_r):
            # Get incident display elements and incident light field samples.
            element_indices = np.nonzero(element_incidence_per_pixel[index_f, index_r])[0]
            sample_indices = np.nonzero(sample_incidence_per_pixel[index_f, index_r])
            # Phi values for incident display elements and incident light field samples.
            incident_phi = phi_values[index_f, element_indices.reshape(-1, 1), sample_indices[0], sample_indices[1], sample_indices[2]]
            # Simulation: For each incident display element, take the average phi value over the incident light field samples.
            simulation_linear_map[index_f, index_r, element_indices] = np.average(incident_phi, axis=1)
            # Projection: Phi values divided by sample count per pixel (times the pixel pitch).
            projection_linear_map[index_f, element_indices.reshape(-1, 1), sample_indices[0], sample_indices[1], sample_indices[2]] += incident_phi / np.maximum(sample_count_per_pixel[index_f, index_r], 1)
            # Autocorrelation: For each pair of incident display elements, accumulate the average of the product of sampled phi values (times the pixel pitch).
            autocorrelation_linear_map[index_f, element_indices.reshape(-1, 1), element_indices.reshape(1, -1)] += np.average(incident_phi[:, np.newaxis, :] * incident_phi[np.newaxis, :, :], axis=2)
            # Update progress bar.
            pbar.update()

autocorrelation_linear_map[-1] = np.average(autocorrelation_linear_map[:-1], axis=0)

print('\nReference Linear Map:')
array_stats(reference_linear_map)
array_stats(np.sum(reference_linear_map, axis=(2, 3, 4)))

print('\nNaive Linear Map:')
array_stats(naive_linear_map)
array_stats(np.sum(naive_linear_map, axis=(1, 2, 3)))

print('\nSimulation Linear Map:')
array_stats(simulation_linear_map)
array_stats(np.sum(simulation_linear_map, axis=2))

print('\nProjection Linear Map:')
array_stats(projection_linear_map)
array_stats(np.sum(projection_linear_map, axis=(2, 3, 4)))

print('\nAutocorrelation Linear Map:')
array_stats(autocorrelation_linear_map)
array_stats(np.sum(autocorrelation_linear_map, axis=2))

In [None]:
%%time
# Stream the sampled light field through the GPU, one pair of chunk groups at a time, and
# accumulate the reference retina images, the naive coefficients and the projected coefficients.
print('Processing the light field samples...')

# Outputs: reference images per focus distance, naive element coefficients, and projected
# element coefficients per focus distance (the extra last slot holds the average over all).
reference_images = np.zeros((n_f, n_r, n_r, 3), dtype=np.float32)
element_naive = np.zeros((n_ab, n_ab, 3), dtype=np.float32)
element_projected = np.zeros((n_f+1, n_ab, n_ab, 3), dtype=np.float32)

# Sampled light field on disk, opened read-only. Presumably shaped
# (n_u, n_chunk, chunk_v, n_u, n_chunk, chunk_v, 3) to match the buffer below — TODO confirm.
light_field = zarr.open('data/' + scene + '_sampled.zarr', mode='r')

with tqdm(total=n_group * n_group * n_f) as pbar:
    # Load initial values from sampled light field.
    # sample_values is a sliding window of chunk_group chunks along axis 1. The first
    # chunk_group-1 chunks are placed in slots 1.., so that the shift at the start of the
    # first loop iteration moves them into slots 0...
    sample_values = np.empty((n_u, chunk_group, chunk_v, n_u, n_chunk, chunk_v, 3), dtype=np.float32)
    sample_values[:, 1:, :] = light_field[:, :chunk_group-1, :]

    for idx_group in range(n_group):
        # Chunks covered by this group along the first (x) direction.
        slice_x = slice(idx_group, idx_group + chunk_group)
        
        # Load more values from sampled light field.
        # Shift the window by one chunk and append the next chunk in the last slot.
        sample_values[:, :-1, :] = sample_values[:, 1:, :]
        sample_values[:, -1, :] = light_field[:, idx_group+chunk_group-1, :]

        # Load horizontal linear maps into the GPU for this column of groups.
        # The (n_u, chunk_group, chunk_v) sample axes are flattened into one axis.
        reference_x_GPU = cp.array(reference_linear_map[:, :, :, slice_x, :].reshape(n_f, n_r, n_u * chunk_group * chunk_v))
        naive_x_GPU = cp.array(naive_linear_map[:, :, slice_x, :].reshape(n_ab, n_u * chunk_group * chunk_v))
        projection_x_GPU = cp.array(projection_linear_map[:, :, :, slice_x, :].reshape(n_f, n_ab, n_u * chunk_group * chunk_v))

        for idy_group in range(n_group):
            # Chunks covered by this group along the second (y) direction.
            slice_y = slice(idy_group, idy_group + chunk_group)

            # Load sample values into the GPU for this group.
            # Shape after flattening: (x samples, y samples, 3).
            sample_values_GPU = cp.array(sample_values[:, :, :, :, slice_y, :].reshape(n_u * chunk_group * chunk_v, n_u * chunk_group * chunk_v, 3))

            # Load vertical linear maps into the GPU for this group.
            reference_y_GPU = cp.array(reference_linear_map[:, :, :, slice_y, :].reshape(n_f, n_r, n_u * chunk_group * chunk_v))
            naive_y_GPU = cp.array(naive_linear_map[:, :, slice_y, :].reshape(n_ab, n_u * chunk_group * chunk_v))
            projection_y_GPU = cp.array(projection_linear_map[:, :, :, slice_y, :].reshape(n_f, n_ab, n_u * chunk_group * chunk_v))

            # Display elements assigned to this group.
            idx_element = element_indices_per_group[idx_group]
            idy_element = element_indices_per_group[idy_group]
            # Compute naive coefficients.
            # The map is applied separably: first along y (inner tensordot), then along x.
            element_naive[np.ix_(idx_element, idy_element)] = cp.tensordot(naive_x_GPU[idx_element], cp.tensordot(naive_y_GPU[idy_element], sample_values_GPU, axes=(1, 1)), axes=(1, 1)).get()

            # For each focus distance...
            for index_f, z_f in enumerate(focus_distances):
                # Retina pixels assigned to this group.
                idx_pixel = pixel_indices_per_group[index_f][idx_group]
                idy_pixel = pixel_indices_per_group[index_f][idy_group]
                # Light field samples incident on the assigned retina pixels.
                idx_sample = np.nonzero(np.any(sample_incidence_per_pixel[index_f, idx_pixel], axis=0)[:, slice_x, :].reshape(n_u * chunk_group * chunk_v))[0]
                idy_sample = np.nonzero(np.any(sample_incidence_per_pixel[index_f, idy_pixel], axis=0)[:, slice_y, :].reshape(n_u * chunk_group * chunk_v))[0]
                # Display elements incident on the assigned retina pixels.
                # This rebinds idx_element/idy_element, which held the group assignment above.
                idx_element = np.nonzero(np.any(element_incidence_per_pixel[index_f, idx_pixel], axis=0))[0]
                idy_element = np.nonzero(np.any(element_incidence_per_pixel[index_f, idy_pixel], axis=0))[0]
                # Compute reference image.
                reference_images[index_f][np.ix_(idx_pixel, idy_pixel)] = cp.tensordot(reference_x_GPU[index_f][np.ix_(idx_pixel, idx_sample)], cp.tensordot(reference_y_GPU[index_f][np.ix_(idy_pixel, idy_sample)], sample_values_GPU[np.ix_(idx_sample, idy_sample)], axes=(1, 1)), axes=(1, 1)).get()
                # Compute projection coefficients.
                # Accumulated with +=, because the same element can be incident on pixels of several groups.
                element_projected[index_f][np.ix_(idx_element, idy_element)] += cp.tensordot(projection_x_GPU[index_f][np.ix_(idx_element, idx_sample)], cp.tensordot(projection_y_GPU[index_f][np.ix_(idy_element, idy_sample)], sample_values_GPU[np.ix_(idx_sample, idy_sample)], axes=(1, 1)), axes=(1, 1)).get()
                # Update progress bar.
                pbar.update()

            # Free GPU memory
            del sample_values_GPU
            del reference_y_GPU
            del naive_y_GPU
            del projection_y_GPU
            mempool.free_all_blocks()

        # Free GPU memory
        del reference_x_GPU
        del naive_x_GPU
        del projection_x_GPU
        mempool.free_all_blocks()

# Projected coefficients for "all focus distances": the average of the per-focus results.
element_projected[-1] = np.average(element_projected[:-1], axis=0)

In [None]:
# Naive and projection display element coefficients.
# Stride that keeps the displayed previews at no more than n_r pixels per side.
decimation = n_ab//n_r + 1

# Collect (header, coefficients) pairs in display order.
coefficient_previews = [
    (' \n \nNaive coefficients:', element_naive),
    (' \n \nProjected coefficients for all focus distances', element_projected[-1]),
]
coefficient_previews += [
    (' \n \nProjected coefficients for focus distance at {}'.format(z_f), element_projected[index_f])
    for index_f, z_f in enumerate(focus_distances)
]

# Print the statistics of the full array, then show the decimated preview scaled to its own maximum.
for header, coefficients in coefficient_previews:
    print(header)
    array_stats(coefficients)
    preview = coefficients[::decimation, ::decimation]
    display_image(preview / np.amax(preview))

In [None]:
%%time
print('Computing optimal coefficients...')

rng = cp.random.default_rng()

element_optimal = np.zeros((n_f+1, n_ab, n_ab, 3), dtype=np.float32)

with tqdm(total=(n_f+1) * 10) as pbar:
    # For each focus distance...
    for index_f in range(n_f+1):
        # Load the autocorrelation linear map into the GPU.
        autocorrelation_GPU = cp.array(autocorrelation_linear_map[index_f])

        # Load the projection coefficients into the GPU.
        projected_GPU = cp.array(element_projected[index_f])

        # Initialize display elements with random coefficients.
        element_GPU = 1.0 - rng.random((n_ab, n_ab, 3), dtype=cp.float32)

        # For each multiplicative rule iteration...
        for iteration in range(10):
            # Apply the autocorrelation linear map on current coefficients.
            temp_GPU = cp.tensordot(autocorrelation_GPU, cp.tensordot(autocorrelation_GPU, element_GPU, axes=(1, 1)), axes=(1, 1))
            # Divide the projected coefficients by the autocorrelated coefficients.
            temp_GPU = cp.add(temp_GPU, cp.finfo(cp.float32).eps, out=temp_GPU)
            temp_GPU = cp.divide(projected_GPU, temp_GPU, out=temp_GPU)
            # Update current coefficients with the multiplicative rule.
            element_GPU = cp.multiply(element_GPU, temp_GPU, out=element_GPU)
            # Clip coefficients to interval [0, 1].
            element_GPU = cp.clip(element_GPU, 0.0, 1.0, out=element_GPU)
            # Update progress bar.
            pbar.update()

        # Store coefficients after the last iteration.
        element_optimal[index_f] = element_GPU.get()

        # Free GPU memory
        del autocorrelation_GPU
        del projected_GPU
        del element_GPU
        del temp_GPU
        mempool.free_all_blocks()

In [None]:
# Optimal display element coefficients.
# Stride that keeps the displayed previews at no more than n_r pixels per side.
decimation = n_ab//n_r + 1

# Collect (header, coefficients) pairs in display order.
optimal_previews = [(' \n \nOptimal coefficients for all focus distances', element_optimal[-1])]
optimal_previews += [
    (' \n \nOptimal coefficients for focus distance at {}'.format(z_f), element_optimal[index_f])
    for index_f, z_f in enumerate(focus_distances)
]

# Print the statistics of the full array, then show the decimated preview scaled to its own maximum.
for header, coefficients in optimal_previews:
    print(header)
    array_stats(coefficients)
    preview = coefficients[::decimation, ::decimation]
    display_image(preview / np.amax(preview))

In [None]:
%%time
print('Computing retina images...')

reference_spectrum = np.zeros((n_f, n_r, n_r), dtype=np.float64)
naive_spectrum = np.zeros((n_f, n_r, n_r), dtype=np.float64)
generalised_spectrum = np.zeros((n_f, n_r, n_r), dtype=np.float64)
specialised_spectrum = np.zeros((n_f, n_r, n_r), dtype=np.float64)

with tqdm(total=n_f) as pbar:
    # For each focus distance...
    for index_f, z_f in enumerate(focus_distances):
        # Load the simulation linear map into the GPU.
        simulation_GPU = cp.array(simulation_linear_map[index_f])
        
        print(' \n \nReference with observer focus distance at {}'.format(z_f))
        retina_image = reference_images[index_f]
        array_stats(retina_image)
        retina_image = np.clip(retina_image, 0.0, 1.0)
        display_image(retina_image)

        reference_spectrum[index_f] = spectrum(retina_image)

        print(' \n \nNaive with observer focus distance at {}'.format(z_f))
        element_GPU = cp.array(element_naive)
        retina_image = cp.tensordot(simulation_GPU, cp.tensordot(simulation_GPU, element_GPU, axes=(1, 1)), axes=(1, 1)).get()
        array_stats(retina_image)
        retina_image = np.clip(retina_image, 0.0, 1.0)
        display_image(retina_image)

        naive_spectrum[index_f] = spectrum(retina_image)

        print(' \n \nOptimal for all distances with observer focus distance at {}'.format(z_f))
        element_GPU = cp.array(element_optimal[-1])
        retina_image = cp.tensordot(simulation_GPU, cp.tensordot(simulation_GPU, element_GPU, axes=(1, 1)), axes=(1, 1)).get()
        array_stats(retina_image)
        retina_image = np.clip(retina_image, 0.0, 1.0)
        display_image(retina_image)

        generalised_spectrum[index_f] = spectrum(retina_image)

        for target_index_f, target_z_f in enumerate(focus_distances):
            print(' \n \nOptimal for {} with observer focus distance at {}'.format(target_z_f, z_f))
            element_GPU = cp.array(element_optimal[target_index_f])
            retina_image = cp.tensordot(simulation_GPU, cp.tensordot(simulation_GPU, element_GPU, axes=(1, 1)), axes=(1, 1)).get()
            array_stats(retina_image)
            retina_image = np.clip(retina_image, 0.0, 1.0)
            display_image(retina_image)

            if index_f == target_index_f:
                specialised_spectrum[index_f] = spectrum(retina_image)

        # Update progress bar.
        pbar.update()

        # Free GPU memory
        del simulation_GPU
        del element_GPU
        mempool.free_all_blocks()

In [None]:
# Header template and spectra stack for each method, in display order.
spectrum_panels = (
    (' \n \nReference spectrum for focus distance at {}', reference_spectrum),
    (' \n \nNaive spectrum for focus distance at {}', naive_spectrum),
    (' \n \nGeneralised spectrum for focus distance at {}', generalised_spectrum),
    (' \n \nSpecialised spectrum for focus distance at {}', specialised_spectrum),
)

# For each focus distance, show every spectrum scaled to its own maximum.
for index_f, z_f in enumerate(focus_distances):
    for header, spectra in spectrum_panels:
        print(header.format(z_f))
        array_stats(spectra[index_f])
        display_image(spectra[index_f] / np.amax(spectra[index_f]))

In [None]:
# Absolute spectral difference to the reference, stacked along the last axis as
# (specialised, generalised, naive); shape (n_f, n_r, n_r, 3).
spectral_error = np.abs(np.stack((specialised_spectrum, generalised_spectrum, naive_spectrum), axis=3) - reference_spectrum.reshape(n_f, n_r, n_r, 1))

# Every image is scaled by the same global maximum, so the panels are directly comparable.
spectral_error_peak = np.amax(spectral_error)

# Header template and channel index of each single-method panel, in display order.
error_panels = (
    (' \n \nNaive spectral error for focus distance at {}', 2),
    (' \n \nGeneralised spectral error for focus distance at {}', 1),
    (' \n \nSpecialised spectral error for focus distance at {}', 0),
)

for index_f, z_f in enumerate(focus_distances):
    for header, channel in error_panels:
        print(header.format(z_f))
        array_stats(spectral_error[index_f, :, :, channel])
        display_image(spectral_error[index_f, :, :, channel] / spectral_error_peak)

    # Combined panel: the three errors shown as the colour channels of one image.
    print(' \n \n[R: Specialised, G: Generalised, B: Naive] spectral error for focus distance at {}'.format(z_f))
    array_stats(spectral_error[index_f])
    display_image(spectral_error[index_f] / spectral_error_peak)