In [None]:
import os
os.environ['CUPY_ACCELERATORS'] = 'cutensor'
import numpy as np
import cupy as cp
from scipy import fft
import zarr
from matplotlib import pyplot, colors, cm
from PIL import Image
from IPython.display import display
from tqdm.notebook import tqdm, trange

# Handles to CuPy's default device memory pool and pinned (host) memory pool.
# The device pool is used by later cells to release cached blocks via free_all_blocks().
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

In [None]:
# Candidate focus-distance triplets. Every triplet starts at 300.0; the middle and far values vary.
range_1, range_2, range_3, range_4, range_5, range_6 = (
    np.array([300.0, middle, far], dtype=np.float32)
    for middle, far in [
        (450.0,  900.0),
        (500.0, 1500.0),
        (525.0, 2100.0),
        (550.0, 3300.0),
        (575.0, 6900.0),
        (600.0, np.inf),
    ]
)

# Available scenes and, row by row in the same order, the focus distances used for each of them.
scenes = ['car', 'chess', 'dragon', 'sponza']
scene_distances = np.stack([
    range_3,         # car
    range_1 * 0.75,  # chess
    range_1,         # dragon
    range_5 * 1.75,  # sponza
]).astype(np.float32)

# User-tunable parameters.
scene_selection = 4   # 1-based index into `scenes`
res_mult = 4          # multiplier applied to the display plane resolutions
n_zeta_f = 100        # number of focus samples used for the continuous-focus maps

# Values derived from the selection.
scene = scenes[scene_selection-1]
focus_distances = scene_distances[scene_selection-1]
n_f = len(focus_distances)

print('Focus distances:', focus_distances)

In [None]:
# Retina plane: depth z_r, number of pixels n_r, total extent size_r and the
# resulting pixel pitch (extent divided by pixel count).
z_r = 1.0
n_r = 1024
size_r = 2.0
pitch_r = size_r / n_r

# Field of view: fov_factor is the retina extent per unit of retina depth,
# and fov is the corresponding full opening angle converted to degrees.
fov_factor = size_r / z_r
fov = 2 * np.arctan(fov_factor / 2) * 180/np.pi

# Pupil plane: depth z_p and aperture extent size_p.
z_p = 0.0
size_p = 8.0

In [None]:
# Plane U: depth, sample count, extent and sample pitch of the first
# parameterization plane of the sampled light field.
z_u = 0.0
n_u = 20
size_u = 10.0
pitch_u = size_u / n_u

# Plane V: depth, sample count, extent and sample pitch of the second
# parameterization plane of the sampled light field.
z_v = 265.0
n_v = 1080
size_v = 540.0
pitch_v = size_v / n_v

# View chunks: the n_v samples of plane V are handled as n_chunk chunks of chunk_v samples each.
# The integer division drops any remainder, so n_v is expected to be a multiple of chunk_v.
chunk_v = 40
n_chunk = n_v//chunk_v

# Summary of the sampled light field geometry (the prints label lengths as mm).
print('[Sampled Light Field]')
print('Plane U:    {:4} × {:4.2f} mm = {:5} mm    z_u = {:5} mm'.format(n_u, pitch_u, size_u, z_u))
print('Plane V:    {:4} × {:4.2f} mm = {:5} mm    z_v = {:5} mm'.format(n_v, pitch_v, size_v, z_v))
print('          ({:2}×{:2})'.format(n_chunk, chunk_v))

# NOTE(review): these three figures look like retina resolution estimates when focusing at
# plane U, at plane V and at infinity -- confirm the intended interpretation.
print('\n[Retina Estimate with FoV = {:4.1f}°]'.format(fov))
print('Focus at U: {:7.2f}'.format(fov_factor * z_u / pitch_u))
print('Focus at V: {:7.2f}'.format(fov_factor * z_v / pitch_v))
print('Focus at ∞: {:7.2f}'.format(fov_factor * abs(z_u - z_v) / max(pitch_u, pitch_v)))

# Every (u, v) pair is one ray per axis, hence the square for the two image axes.
print('\nNumber of rays: {:,}'.format((n_u * n_v)**2))

In [None]:
# Plane A: depth, element count (scaled by res_mult), extent and element pitch
# of the first display plane.
z_a = 8.0
n_a = 12 * res_mult # 48
size_a = 24.0
pitch_a = size_a / n_a

# Plane B: depth, element count (scaled by res_mult), extent and element pitch
# of the second display plane.
z_b = 136.0
n_b = 140 * res_mult # 560
size_b = 280.0
pitch_b = size_b / n_b

# Summary of the display geometry (the prints label lengths as mm).
print('[Display Light Field]')
print('Plane A:    {:4} × {:4.2f} mm = {:5} mm    z_a = {:5} mm'.format(n_a, pitch_a, size_a, z_a))
print('Plane B:    {:4} × {:4.2f} mm = {:5} mm    z_b = {:5} mm'.format(n_b, pitch_b, size_b, z_b))

# NOTE(review): these three figures look like retina resolution estimates when focusing at
# plane A, at plane B and at infinity -- confirm the intended interpretation.
print('\n[Retina Estimate with FoV = {:4.1f}°]'.format(fov))
print('Focus at A: {:7.2f}'.format(fov_factor * z_a / pitch_a))
print('Focus at B: {:7.2f}'.format(fov_factor * z_b / pitch_b))
print('Focus at ∞: {:7.2f}'.format(fov_factor * abs(z_a - z_b) / max(pitch_a, pitch_b)))

# Every (a, b) pair is one display element per axis, hence the square for the two image axes.
print('\nNumber of elements: {:,}'.format((n_a * n_b)**2))

In [None]:
def bytesize_string(nbytes):
    """Format a byte count with binary (IEC) units.

    The count is expressed in the largest unit for which the value is at least 1,
    e.g. ``1536`` becomes ``'1.5 KiB'``.

    Parameters
    ----------
    nbytes : int or float
        Number of bytes.

    Returns
    -------
    str
        ``'<value> <unit>'``. Counts below one byte (including zero) are reported in
        ``B``; counts too large for the biggest unit are reported in ``YiB``.
    """
    unit = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']
    # The same count expressed in every unit (powers of two, so the scaling is exact).
    size = np.array([1 / 2**(10 * k) for k in range(len(unit))]) * float(nbytes)
    below_one = size < 1.0
    if not below_one.any():
        # Even the largest unit yields a value >= 1: use the largest unit explicitly.
        order_of_magnitude = len(unit) - 1
    else:
        # Last unit whose value is still >= 1. Clamp at 'B' so that counts below one
        # byte (e.g. 0) do not wrap around to index -1, which would pick 'YiB'.
        order_of_magnitude = max(int(np.argmax(below_one)) - 1, 0)
    return '{} {}'.format(size[order_of_magnitude], unit[order_of_magnitude])

def array_stats(a):
    """Print a one-line summary of array ``a``: shape, dtype, memory footprint, minimum, maximum and average."""
    footprint = bytesize_string(a.nbytes)
    lowest, highest, mean = np.amin(a), np.amax(a), np.average(a)
    template = '{} × {} = {} | min: {}, max: {}, avg: {}'
    print(template.format(a.shape, a.dtype, footprint, lowest, highest, mean))
    
def display_image(values, color_map=None):
    """Show an array as an inline image in the notebook.

    Parameters
    ----------
    values : array
        Image values, expected in the range [0, 1]; they are scaled by 255 and cast to
        ``uint8``. The first axis is drawn horizontally and the second axis vertically
        with increasing index going up (axes 0 and 1 are swapped, then rows are flipped).
    color_map : str or Colormap, optional
        When given, ``values`` is first mapped through this Matplotlib colormap.
    """
    if color_map is not None:
        # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
        # pyplot.get_cmap performs the same lookup and exists in old and new releases.
        values = pyplot.get_cmap(color_map)(values)
    pixels = (np.flip(np.swapaxes(values, 0, 1), axis=0) * 255.0).astype(np.uint8)
    display(Image.fromarray(pixels))

def save_image(name, values, color_map=None):
    """Save an array as ``figures/<name>.png``.

    Parameters
    ----------
    name : str
        File name without directory or extension.
    values : array
        Image values, expected in the range [0, 1]; they are scaled by 255 and cast to
        ``uint8``. The first axis is drawn horizontally and the second axis vertically
        with increasing index going up (axes 0 and 1 are swapped, then rows are flipped).
    color_map : str or Colormap, optional
        When given, ``values`` is first mapped through this Matplotlib colormap.
    """
    if color_map is not None:
        # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
        # pyplot.get_cmap performs the same lookup and exists in old and new releases.
        values = pyplot.get_cmap(color_map)(values)
    pixels = (np.flip(np.swapaxes(values, 0, 1), axis=0) * 255.0).astype(np.uint8)
    # Make sure the output directory exists so that saving works on a fresh checkout.
    os.makedirs('figures', exist_ok=True)
    Image.fromarray(pixels).save('figures/{}.png'.format(name))

def compute_spectrum(values):
    """Return the centred 2D Fourier spectrum of an image.

    ``values`` is quantized to 8 bits (scaled by 255), converted to a single luminance
    channel with PIL, transformed with a 2D FFT and shifted so that the zero frequency
    sits in the middle of the result.
    """
    quantized = (values * 255.0).astype(np.uint8)
    luminance = np.array(Image.fromarray(quantized).convert('L'))
    return fft.fftshift(fft.fft2(luminance))

def spectral_log(spectrum):
    """Return ``log(1 + |spectrum|)``, a log-compressed magnitude that maps zero to zero."""
    magnitude = np.abs(spectrum)
    return np.log(1.0 + magnitude)

def sampling_lattice(n, pitch=1.0):
    """Return ``n`` sample positions spaced by ``pitch`` and centred on zero (float32 indices)."""
    indices = np.arange(n, dtype=np.float32)
    center = (n - 1)/2
    return (indices - center) * pitch

def dot(a, b):
    """Dot product along the last axis of ``a`` and ``b``, broadcasting over the leading axes."""
    contraction = '...i,...i'
    return np.einsum(contraction, a, b)

In [None]:
def intersect_plane(x_1, x_2, z_1, z_2, z_out):
    """Coordinate at depth ``z_out`` of the ray passing through ``x_1`` at ``z_1`` and ``x_2`` at ``z_2``.

    The result is the linear interpolation (or extrapolation) of the two points.
    """
    weight_1 = (z_2 - z_out)/(z_2 - z_1)
    weight_2 = (z_1 - z_out)/(z_1 - z_2)
    return weight_1 * x_1 + weight_2 * x_2

def intersect_retina(x_1, x_2, z_1, z_2, z_r, zeta_f):
    """Retina coordinate of the ray passing through ``x_1`` at ``z_1`` and ``x_2`` at ``z_2``.

    ``zeta_f`` is the reciprocal of the focus distance and ``z_r`` the retina depth,
    which scales the result.
    """
    weight_1 = (z_2 * zeta_f - 1)/(z_2 - z_1)
    weight_2 = (z_1 * zeta_f - 1)/(z_1 - z_2)
    return z_r * (weight_1 * x_1 + weight_2 * x_2)

def intersect_plane_from_eye(x_p, x_r, z_out, zeta_f):
    """Coordinate at depth ``z_out`` of the ray defined by pupil coordinate ``x_p`` and retina coordinate ``x_r``.

    ``zeta_f`` is the reciprocal of the focus distance.
    """
    pupil_term = (1 - z_out * zeta_f) * x_p
    retina_term = z_out * x_r
    return pupil_term + retina_term



def compute_element_size_p(pitch_a, pitch_b, z_a, z_b):
    """Extent on the pupil plane (depth 0) of a display element formed by one cell of plane A and one of plane B.

    It is the sum of the absolute contributions of the two cell pitches.
    """
    footprint_a = np.abs(z_b/(z_b - z_a) * pitch_a)
    footprint_b = np.abs(z_a/(z_a - z_b) * pitch_b)
    return footprint_a + footprint_b

def compute_element_size_r(pitch_a, pitch_b, z_a, z_b, zeta_f):
    """Extent of a display element (one cell of plane A, one of plane B) as projected by ``intersect_retina`` with unit retina depth.

    ``zeta_f`` is the reciprocal of the focus distance. The result is the sum of the
    absolute contributions of the two cell pitches.
    """
    footprint_a = np.abs((z_b * zeta_f - 1)/(z_b - z_a) * pitch_a)
    footprint_b = np.abs((z_a * zeta_f - 1)/(z_a - z_b) * pitch_b)
    return footprint_a + footprint_b



def compute_phi_GPU(x_r, x_a, x_b, pitch_a, pitch_b, size_p, z_a, z_b, zeta_f):
    """Fraction of the pupil through which a retina point sees a display element (computed with CuPy).

    For every combination of retina coordinate and display element, the cells of
    plane A and plane B are projected onto the pupil plane; the returned value is the
    length of the overlap of both projections with the pupil aperture, divided by
    the pupil size.

    Parameters
    ----------
    x_r : retina coordinates. Callers in this notebook pass a 1-D CuPy array of
        light field sample coordinates.
    x_a, x_b : cell centre coordinates on planes A and B. Callers in this notebook
        pass CuPy arrays with a trailing axis of length 1, so the result broadcasts
        to (elements, samples).
    pitch_a, pitch_b : cell sizes on planes A and B.
    size_p : pupil aperture size; the aperture is taken as [-size_p/2, size_p/2].
    z_a, z_b : depths of planes A and B.
    zeta_f : reciprocal of the focus distance.

    Returns
    -------
    CuPy array of overlap lengths normalized by ``size_p``.

    Notes
    -----
    Intermediate buffers are reused through ``out=`` to limit GPU allocations, so
    several names below alias the same array; the statement order matters.
    """
    # Ratios for projection of plane A and plane B to the pupil plane.
    ratio_a = cp.reciprocal(1.0 - z_a * zeta_f, dtype=cp.float32)
    ratio_b = cp.reciprocal(1.0 - z_b * zeta_f, dtype=cp.float32)
    # Project display samples to the pupil plane.
    x_p_from_a = ratio_a * (x_a - z_a * x_r)
    x_p_from_b = ratio_b * (x_b - z_b * x_r)
    # Projection sizes (full widths; half of each is subtracted below to get the lower bounds).
    proj_size_a = cp.abs(ratio_a * pitch_a)
    proj_size_b = cp.abs(ratio_b * pitch_b)
    # Projection A: lower bound overwrites the centre buffer, upper bound is a new array.
    lower_bound_a = cp.subtract(x_p_from_a, proj_size_a/2, out=x_p_from_a)
    upper_bound_a = cp.add(lower_bound_a, proj_size_a)
    # Projection B: lower bound overwrites the centre buffer, upper bound is a new array.
    lower_bound_b = cp.subtract(x_p_from_b, proj_size_b/2, out=x_p_from_b)
    upper_bound_b = cp.add(lower_bound_b, proj_size_b)
    # Intersect the projections: clamping both bounds of A into B's interval yields a
    # zero-length interval when the projections do not overlap.
    lower_bound = cp.clip(lower_bound_a, lower_bound_b, upper_bound_b, out=lower_bound_a)
    upper_bound = cp.clip(upper_bound_a, lower_bound_b, upper_bound_b, out=upper_bound_a)
    # Intersect with pupil.
    lower_bound = cp.clip(lower_bound, -size_p/2, size_p/2, out=lower_bound)
    upper_bound = cp.clip(upper_bound, -size_p/2, size_p/2, out=upper_bound)
    # Return the intersection length, normalized by the pupil size.
    return cp.divide(cp.subtract(upper_bound, lower_bound, out=upper_bound), size_p, out=upper_bound)

In [None]:
# Centred sample positions of the retina pixels.
lattice_r = sampling_lattice(n_r, pitch_r)

# Centred cell positions on display planes A and B.
lattice_a = sampling_lattice(n_a, pitch_a)
lattice_b = sampling_lattice(n_b, pitch_b)

# Centred sample positions on light field planes U and V.
# Plane V is stored as (n_chunk, chunk_v) so that it can be processed chunk by chunk.
lattice_u = sampling_lattice(n_u, pitch_u)
lattice_v = sampling_lattice(n_v, pitch_v).reshape(n_chunk, chunk_v)

In [None]:
%%time
print('Computing active elements...')

# Extent of one display element (a cell of A paired with a cell of B) on the pupil plane.
element_size_p = compute_element_size_p(pitch_a, pitch_b, z_a, z_b)

# Pupil-plane coordinate of every (A, B) pair, arranged as an (n_a, n_b) grid.
element_coord_p = intersect_plane(lattice_a[:, np.newaxis], lattice_b[np.newaxis, :], z_a, z_b, z_p)

# A pair is kept when its centre lies within half a pupil plus half an element of the pupil centre.
element_incidence_on_pupil = np.less_equal(np.abs(element_coord_p), (size_p + element_size_p)/2)

# Indices (into lattice_a and lattice_b) of the pairs that were kept, and how many there are.
element_active_a, element_active_b = np.nonzero(element_incidence_on_pupil)
n_ab = element_active_a.size

# Plane A and plane B coordinates of the active pairs.
element_coord_a = np.take(lattice_a, element_active_a)
element_coord_b = np.take(lattice_b, element_active_b)

print('\nActive display elements:', n_ab)

In [None]:
%%time
print('Computing linear interpolation and naive...')

# Display element coordinates on the pupil and on planes U and V.
# From here on element_coord_p holds one value per active element (length n_ab).
element_coord_p = intersect_plane(element_coord_a, element_coord_b, z_a, z_b, z_p)
element_coord_u = intersect_plane(element_coord_a, element_coord_b, z_a, z_b, z_u)
element_coord_v = intersect_plane(element_coord_a, element_coord_b, z_a, z_b, z_v)

# Light field sample coordinates on the pupil and on planes A and B.
# Each result broadcasts to shape (n_u, n_chunk, chunk_v).
sample_coord_p = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_p)
sample_coord_a = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_a)
sample_coord_b = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_b)



# Display element values from light field samples linear interpolation.
# Triangle (tent) weights on both planes: 1 at zero distance, falling to 0 at one pitch.
# The product has shape (n_ab, n_u, n_chunk, chunk_v).
dist_u = np.abs(element_coord_u.reshape(n_ab, 1, 1, 1) - lattice_u.reshape(1, n_u,       1,       1))
dist_v = np.abs(element_coord_v.reshape(n_ab, 1, 1, 1) - lattice_v.reshape(1,   1, n_chunk, chunk_v))
weight_u = np.maximum(0.0, 1.0 - dist_u / pitch_u)
weight_v = np.maximum(0.0, 1.0 - dist_v / pitch_v)
interp_linear_map = weight_u * weight_v



# Display element values from light field samples display pre-filtering.
# A sample belongs to an element when it falls within half a pitch of the element on both
# planes; each element then averages its samples (the count is clamped to 1 so that
# elements without samples get an all-zero row instead of a division by zero).
dist_a = np.abs(sample_coord_a.reshape(1, n_u, n_chunk, chunk_v) - element_coord_a.reshape(n_ab, 1, 1, 1))
dist_b = np.abs(sample_coord_b.reshape(1, n_u, n_chunk, chunk_v) - element_coord_b.reshape(n_ab, 1, 1, 1))
sample_incidence_per_element = np.logical_and(dist_a <= pitch_a/2, dist_b <= pitch_b/2)
sample_count_per_element = np.count_nonzero(sample_incidence_per_element, axis=(1, 2, 3)).astype(np.uint16)
naive_linear_map = sample_incidence_per_element.astype(np.float32) / np.maximum(sample_count_per_element, 1).reshape(n_ab, 1, 1, 1)



# Report, per map: the raw values, the number of non-zero weights per element,
# and the sum of weights per element.
print('\nInterpolation Linear Map:')
array_stats(interp_linear_map)
array_stats(np.count_nonzero(interp_linear_map, axis=(1, 2, 3)))
array_stats(np.sum(interp_linear_map, axis=(1, 2, 3)))

print('\nNaive Linear Map:')
array_stats(naive_linear_map)
array_stats(np.count_nonzero(naive_linear_map, axis=(1, 2, 3)))
array_stats(np.sum(naive_linear_map, axis=(1, 2, 3)))

In [None]:
%%time
print('Computing discrete focus distance linear maps...')

# Light field sample incidence on the pupil.
sample_incidence_on_pupil = np.abs(sample_coord_p) <= size_p/2

# One slot per discrete focus distance plus one extra (last) slot. This cell fills the first
# n_f slots only; the last slot stays zero here and is reserved for the continuous-focus maps.
reference_linear_map = np.zeros((n_f+1, n_r, n_u, n_chunk, chunk_v), dtype=np.float32)
simulation_linear_map = np.zeros((n_f+1, n_r, n_ab), dtype=np.float32)
projection_linear_map = np.zeros((n_f+1, n_ab, n_u, n_chunk, chunk_v), dtype=np.float32)
autocorrelation_linear_map = np.zeros((n_f+1, n_ab, n_ab), dtype=np.float32)

with tqdm(total=n_f * n_r) as pbar:
    # Load display element coordinates into the GPU.
    coord_a_GPU = cp.array(element_coord_a)
    coord_b_GPU = cp.array(element_coord_b)

    # For each focus distance...
    for index_f, z_f in enumerate(focus_distances):
        # The projection helpers work with the reciprocal of the focus distance.
        zeta_f = 1/z_f

        # Display element incidence on retina pixels.
        element_size_r = compute_element_size_r(pitch_a, pitch_b, z_a, z_b, zeta_f)
        element_coord_r = intersect_retina(element_coord_a, element_coord_b, z_a, z_b, z_r, zeta_f)
        dist_r = np.abs(element_coord_r.reshape(1, n_ab) - lattice_r.reshape(n_r, 1))
        element_incidence_per_pixel = dist_r < (pitch_r + element_size_r)/2
        element_count_per_pixel = np.count_nonzero(element_incidence_per_pixel, axis=1).astype(np.uint16)

        # Light field sample incidence on retina pixels (half-open pixel interval, pupil hits only).
        sample_coord_r = intersect_retina(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_r, zeta_f)
        coord_r = sample_coord_r.reshape(1, n_u, n_chunk, chunk_v) - lattice_r.reshape(n_r, 1, 1, 1)
        sample_incidence_per_pixel = np.logical_and(np.logical_and(coord_r > -pitch_r/2, coord_r <= pitch_r/2), sample_incidence_on_pupil)
        sample_count_per_pixel = np.count_nonzero(sample_incidence_per_pixel, axis=(1, 2, 3)).astype(np.uint16)

        # Load light field sample retina coordinates into the GPU.
        coord_r_GPU = cp.array(sample_coord_r)

        # Retina pixel values from light field samples (count clamped to 1 to avoid dividing by zero).
        reference_linear_map[index_f] = sample_incidence_per_pixel.astype(np.float32) / np.maximum(1, sample_count_per_pixel).reshape(n_r, 1, 1, 1)

        # For each retina pixel...
        for index_r in range(n_r):

            # Get incident display elements and incident light field samples.
            element_indices = np.nonzero(element_incidence_per_pixel[index_r])[0]
            sample_indices = np.nonzero(sample_incidence_per_pixel[index_r])

            # A pixel that receives no light field sample contributes nothing. Skipping it
            # keeps its entries at zero; averaging over an empty sample set would instead
            # write NaN into the simulation and autocorrelation maps.
            if sample_indices[0].size == 0:
                pbar.update()
                continue

            # Compute phi values for incident display elements and incident light field samples.
            # Result shape: (incident elements, incident samples).
            phi_values = compute_phi_GPU(coord_r_GPU[sample_indices], coord_a_GPU[element_indices, np.newaxis], coord_b_GPU[element_indices, np.newaxis], pitch_a, pitch_b, size_p, z_a, z_b, zeta_f).get()

            # Simulation: For each incident display element, take the average phi value over the incident light field samples.
            simulation_linear_map[index_f, index_r, element_indices] = np.average(phi_values, axis=1)

            # Projection: Phi values divided by sample count per pixel (times the pixel pitch omitted here).
            projection_linear_map[index_f, element_indices.reshape(-1, 1), sample_indices[0], sample_indices[1], sample_indices[2]] += phi_values / np.maximum(1, sample_count_per_pixel[index_r])

            # Autocorrelation: For each pair of incident display elements, accumulate the average of the product of sampled phi values (times the pixel pitch omitted here).
            autocorrelation_linear_map[index_f, element_indices.reshape(-1, 1), element_indices.reshape(1, -1)] += np.average(phi_values[:, np.newaxis, :] * phi_values[np.newaxis, :, :], axis=2)

            # Update progress bar.
            pbar.update()

    # Free GPU memory
    del coord_a_GPU
    del coord_b_GPU
    del coord_r_GPU
    mempool.free_all_blocks()

# Report, per map and for the discrete slots only: raw values, non-zero count per row, sum per row.
print('\nReference Linear Map:')
array_stats(reference_linear_map[:n_f])
array_stats(np.count_nonzero(reference_linear_map[:n_f], axis=(2, 3, 4)))
array_stats(np.sum(reference_linear_map[:n_f], axis=(2, 3, 4)))

print('\nSimulation Linear Map:')
array_stats(simulation_linear_map[:n_f])
array_stats(np.count_nonzero(simulation_linear_map[:n_f], axis=2))
array_stats(np.sum(simulation_linear_map[:n_f], axis=2))

print('\nProjection Linear Map:')
array_stats(projection_linear_map[:n_f])
array_stats(np.count_nonzero(projection_linear_map[:n_f], axis=(2, 3, 4)))
array_stats(np.sum(projection_linear_map[:n_f], axis=(2, 3, 4)))

print('\nAutocorrelation Linear Map:')
array_stats(autocorrelation_linear_map[:n_f])
array_stats(np.count_nonzero(autocorrelation_linear_map[:n_f], axis=2))
array_stats(np.sum(autocorrelation_linear_map[:n_f], axis=2))

In [None]:
%%time
print('Computing continuous focus distance linear maps...')

# Everything below accumulates into the last slot of each map with `+=`. Clear those slots
# first so that re-running this cell gives the same result instead of adding a second pass
# on top of the previous one.
reference_linear_map[-1] = 0.0
simulation_linear_map[-1] = 0.0
projection_linear_map[-1] = 0.0
autocorrelation_linear_map[-1] = 0.0

with tqdm(total=n_zeta_f * n_r) as pbar:
    # Load display element coordinates into the GPU.
    coord_a_GPU = cp.array(element_coord_a)
    coord_b_GPU = cp.array(element_coord_b)

    # For each focus setting, sampled uniformly in reciprocal focus distance between the
    # farthest and the nearest configured focus distance...
    for zeta_f in np.linspace(1.0/focus_distances[-1], 1.0/focus_distances[0], num=n_zeta_f):

        # Display element incidence on retina pixels.
        element_size_r = compute_element_size_r(pitch_a, pitch_b, z_a, z_b, zeta_f)
        element_coord_r = intersect_retina(element_coord_a, element_coord_b, z_a, z_b, z_r, zeta_f)
        dist_r = np.abs(element_coord_r.reshape(1, n_ab) - lattice_r.reshape(n_r, 1))
        element_incidence_per_pixel = dist_r < (pitch_r + element_size_r)/2
        element_count_per_pixel = np.count_nonzero(element_incidence_per_pixel, axis=1).astype(np.uint16)

        # Light field sample incidence on retina pixels (half-open pixel interval, pupil hits only).
        sample_coord_r = intersect_retina(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_r, zeta_f)
        coord_r = sample_coord_r.reshape(1, n_u, n_chunk, chunk_v) - lattice_r.reshape(n_r, 1, 1, 1)
        sample_incidence_per_pixel = np.logical_and(np.logical_and(coord_r > -pitch_r/2, coord_r <= pitch_r/2), sample_incidence_on_pupil)
        sample_count_per_pixel = np.count_nonzero(sample_incidence_per_pixel, axis=(1, 2, 3)).astype(np.uint16)

        # Load light field sample retina coordinates into the GPU.
        coord_r_GPU = cp.array(sample_coord_r)

        # Retina pixel values from light field samples, averaged over the focus settings.
        reference_linear_map[-1] += (sample_incidence_per_pixel.astype(np.float32) / np.maximum(1, sample_count_per_pixel).reshape(n_r, 1, 1, 1)) / n_zeta_f

        # For each retina pixel...
        for index_r in range(n_r):

            # Get incident display elements and incident light field samples.
            element_indices = np.nonzero(element_incidence_per_pixel[index_r])[0]
            sample_indices = np.nonzero(sample_incidence_per_pixel[index_r])

            # A pixel that receives no light field sample contributes nothing. Skipping it
            # avoids averaging over an empty sample set, which would add NaN to the
            # simulation and autocorrelation maps.
            if sample_indices[0].size == 0:
                pbar.update()
                continue

            # Compute phi values for incident display elements and incident light field samples.
            # Result shape: (incident elements, incident samples).
            phi_values = compute_phi_GPU(coord_r_GPU[sample_indices], coord_a_GPU[element_indices, np.newaxis], coord_b_GPU[element_indices, np.newaxis], pitch_a, pitch_b, size_p, z_a, z_b, zeta_f).get()

            # Simulation: For each incident display element, take the average phi value over the incident light field samples.
            simulation_linear_map[-1, index_r, element_indices] += np.average(phi_values, axis=1) / n_zeta_f

            # Projection: Phi values divided by sample count per pixel (divided by n_zeta_f and times the pixel pitch omitted here).
            projection_linear_map[-1, element_indices.reshape(-1, 1), sample_indices[0], sample_indices[1], sample_indices[2]] += phi_values / np.maximum(1, sample_count_per_pixel[index_r])

            # Autocorrelation: For each pair of incident display elements, accumulate the average of the product of sampled phi values (divided by n_zeta_f and times the pixel pitch omitted here).
            autocorrelation_linear_map[-1, element_indices.reshape(-1, 1), element_indices.reshape(1, -1)] += np.average(phi_values[:, np.newaxis, :] * phi_values[np.newaxis, :, :], axis=2)

            # Update progress bar.
            pbar.update()

    # Free GPU memory
    del coord_a_GPU
    del coord_b_GPU
    del coord_r_GPU
    mempool.free_all_blocks()

# Report, per map and for the continuous slot only: raw values, non-zero count per row, sum per row.
print('\nReference Linear Map:')
array_stats(reference_linear_map[-1])
array_stats(np.count_nonzero(reference_linear_map[-1], axis=(1, 2, 3)))
array_stats(np.sum(reference_linear_map[-1], axis=(1, 2, 3)))

print('\nSimulation Linear Map:')
array_stats(simulation_linear_map[-1])
array_stats(np.count_nonzero(simulation_linear_map[-1], axis=1))
array_stats(np.sum(simulation_linear_map[-1], axis=1))

print('\nProjection Linear Map:')
array_stats(projection_linear_map[-1])
array_stats(np.count_nonzero(projection_linear_map[-1], axis=(1, 2, 3)))
array_stats(np.sum(projection_linear_map[-1], axis=(1, 2, 3)))

print('\nAutocorrelation Linear Map:')
array_stats(autocorrelation_linear_map[-1])
array_stats(np.count_nonzero(autocorrelation_linear_map[-1], axis=1))
array_stats(np.sum(autocorrelation_linear_map[-1], axis=1))

In [None]:
%%time
print('Computing assignments to chunk groups...')

# Light field chunks touched by each display element under any of the three element maps.
chunk_incidence_per_element = (
    np.any(interp_linear_map, axis=(1, 3))
    | np.any(naive_linear_map, axis=(1, 3))
    | np.any(projection_linear_map, axis=(0, 2, 4))
)

# Light field chunks touched by each retina pixel under any focus setting.
chunk_incidence_per_pixel = np.any(reference_linear_map, axis=(0, 2, 4))

# Largest number of light field chunks touched by a single display element or retina pixel.
chunk_group_element = np.count_nonzero(chunk_incidence_per_element, axis=1).max()
chunk_group_pixel = np.count_nonzero(chunk_incidence_per_pixel, axis=1).max()
chunk_group = max(chunk_group_element, chunk_group_pixel)

# Number of windows of chunk_group consecutive chunks that fit in n_chunk chunks.
n_group = n_chunk - chunk_group + 1

print('\nLight field group size:', chunk_group)
print('Number of light field groups:', n_group)



# Each display element goes to the group starting at its first touched chunk,
# clamped to the last group. Row g of the mask marks the elements of group g.
element_assignment_per_group = np.equal(
    np.arange(n_group)[:, np.newaxis],
    np.minimum(np.argmax(chunk_incidence_per_element, axis=1), n_group - 1)[np.newaxis, :],
)

# How many display elements each group received.
element_count_per_group = np.count_nonzero(element_assignment_per_group, axis=1).astype(np.uint16)

# Indices of the display elements of each group.
element_indices_per_group = [np.flatnonzero(group_mask) for group_mask in element_assignment_per_group]

print('\nDisplay element assignment and count per light field group:')
array_stats(element_assignment_per_group)
array_stats(element_count_per_group)



# Same assignment rule for retina pixels.
pixel_assignment_per_group = np.equal(
    np.arange(n_group)[:, np.newaxis],
    np.minimum(np.argmax(chunk_incidence_per_pixel, axis=1), n_group - 1)[np.newaxis, :],
)

# How many retina pixels each group received.
pixel_count_per_group = np.count_nonzero(pixel_assignment_per_group, axis=1).astype(np.uint16)

# Indices of the retina pixels of each group.
pixel_indices_per_group = [np.flatnonzero(group_mask) for group_mask in pixel_assignment_per_group]

print('\nRetina pixel assignment and count per light field group:')
array_stats(pixel_assignment_per_group)
array_stats(pixel_count_per_group)

In [None]:
%%time
print('Processing the light field samples...')

# Outputs, indexed by (horizontal element or pixel, vertical element or pixel, colour channel).
# The maps computed earlier are 1-D and are applied separably along both image axes.
element_interp = np.zeros((n_ab, n_ab, 3), dtype=np.float32)
element_naive = np.zeros((n_ab, n_ab, 3), dtype=np.float32)
reference_image = np.zeros((n_f+1, n_r, n_r, 3), dtype=np.float32)
element_projected = np.zeros((n_f+1, n_ab, n_ab, 3), dtype=np.float32)

# Sampled light field of the selected scene, opened read-only.
# NOTE(review): presumably laid out as (n_u, n_chunk, chunk_v, n_u, n_chunk, chunk_v, 3),
# inferred from how it is sliced into sample_values below -- confirm against the dataset.
light_field = zarr.open('data/' + scene + '_sampled.zarr', mode='r')

with tqdm(total=n_group * n_group * (n_f+1)) as pbar:
    # Load initial values from sampled light field.
    # sample_values is a rolling window of chunk_group consecutive horizontal chunks. It is
    # pre-filled one position to the right so that the first shift in the loop below puts
    # chunks 0..chunk_group-2 in place.
    sample_values = np.empty((n_u, chunk_group, chunk_v, n_u, n_chunk, chunk_v, 3), dtype=np.float32)
    sample_values[:, 1:, :] = light_field[:, :chunk_group-1, :]

    for idx_group in range(n_group):
        slice_x = slice(idx_group, idx_group + chunk_group)
        
        # Load more values from sampled light field.
        # Shift the window by one chunk and read the newly needed last chunk from disk.
        sample_values[:, :-1, :] = sample_values[:, 1:, :]
        sample_values[:, -1, :] = light_field[:, idx_group+chunk_group-1, :]

        # Load horizontal linear maps into the GPU for this column of groups.
        # Each map is restricted to the chunks of the window and flattened over (u, chunk, v).
        interp_x_GPU = cp.array(interp_linear_map[:, :, slice_x, :].reshape(n_ab, n_u * chunk_group * chunk_v))
        naive_x_GPU = cp.array(naive_linear_map[:, :, slice_x, :].reshape(n_ab, n_u * chunk_group * chunk_v))
        reference_x_GPU = cp.array(reference_linear_map[:, :, :, slice_x, :].reshape(n_f+1, n_r, n_u * chunk_group * chunk_v))
        projection_x_GPU = cp.array(projection_linear_map[:, :, :, slice_x, :].reshape(n_f+1, n_ab, n_u * chunk_group * chunk_v))

        for idy_group in range(n_group):
            slice_y = slice(idy_group, idy_group + chunk_group)

            # Load sample values into the GPU for this group.
            # Shape: (horizontal samples, vertical samples, 3).
            sample_values_GPU = cp.array(sample_values[:, :, :, :, slice_y, :].reshape(n_u * chunk_group * chunk_v, n_u * chunk_group * chunk_v, 3))

            # Load vertical linear maps into the GPU for this group.
            interp_y_GPU = cp.array(interp_linear_map[:, :, slice_y, :].reshape(n_ab, n_u * chunk_group * chunk_v))
            naive_y_GPU = cp.array(naive_linear_map[:, :, slice_y, :].reshape(n_ab, n_u * chunk_group * chunk_v))
            reference_y_GPU = cp.array(reference_linear_map[:, :, :, slice_y, :].reshape(n_f+1, n_r, n_u * chunk_group * chunk_v))
            projection_y_GPU = cp.array(projection_linear_map[:, :, :, slice_y, :].reshape(n_f+1, n_ab, n_u * chunk_group * chunk_v))

            # Display elements assigned to this group.
            idx_element = element_indices_per_group[idx_group]
            idy_element = element_indices_per_group[idy_group]

            # Retina pixels assigned to this group.
            idx_pixel = pixel_indices_per_group[idx_group]
            idy_pixel = pixel_indices_per_group[idy_group]

            # In every product below the inner tensordot applies the vertical map and the
            # outer one the horizontal map, giving a (horizontal, vertical, 3) block.

            # Compute interpolated coefficients.
            element_interp[np.ix_(idx_element, idy_element)] = cp.tensordot(interp_x_GPU[idx_element], cp.tensordot(interp_y_GPU[idy_element], sample_values_GPU, axes=(1, 1)), axes=(1, 1)).get()

            # Compute naive coefficients.
            element_naive[np.ix_(idx_element, idy_element)] = cp.tensordot(naive_x_GPU[idx_element], cp.tensordot(naive_y_GPU[idy_element], sample_values_GPU, axes=(1, 1)), axes=(1, 1)).get()

            # For each discrete focus distance...
            # (the range also covers the extra last slot of the focus-indexed maps)
            for index_f in range(n_f+1):
                # Compute reference image.
                reference_image[index_f][np.ix_(idx_pixel, idy_pixel)] = cp.tensordot(reference_x_GPU[index_f, idx_pixel], cp.tensordot(reference_y_GPU[index_f, idy_pixel], sample_values_GPU, axes=(1, 1)), axes=(1, 1)).get()

                # Compute projection coefficients.
                element_projected[index_f][np.ix_(idx_element, idy_element)] = cp.tensordot(projection_x_GPU[index_f, idx_element], cp.tensordot(projection_y_GPU[index_f, idy_element], sample_values_GPU, axes=(1, 1)), axes=(1, 1)).get()

                # Update progress bar.
                pbar.update()

            # Free GPU memory
            del sample_values_GPU
            del interp_y_GPU
            del naive_y_GPU
            del reference_y_GPU
            del projection_y_GPU
            mempool.free_all_blocks()

        # Free GPU memory
        del interp_x_GPU
        del naive_x_GPU
        del reference_x_GPU
        del projection_x_GPU
        mempool.free_all_blocks()

    # Release the host-side window once all groups are processed.
    del sample_values

In [None]:
# Naive and projection display element coefficients.
# Stride used to thin the coefficient images before display, so that each preview has
# fewer than n_r samples per axis.
decimation = n_ab//n_r + 1

print(' \n \nInterpolation coefficients:')
array_stats(element_interp)
display_image(element_interp[::decimation, ::decimation])

print(' \n \nNaive coefficients:')
array_stats(element_naive)
display_image(element_naive[::decimation, ::decimation])

# Projected coefficients are shown normalized by the maximum of the displayed preview.
print(' \n \nProjected coefficients for continuous focus')
array_stats(element_projected[-1])
preview = element_projected[-1, ::decimation, ::decimation]
display_image(preview / np.amax(preview))

for index_f, z_f in enumerate(focus_distances):
    print(' \n \nProjected coefficients for focus distance at {}'.format(z_f))
    array_stats(element_projected[index_f])
    preview = element_projected[index_f, ::decimation, ::decimation]
    display_image(preview / np.amax(preview))

In [None]:
%%time
print('Computing optimal coefficients...')

# Number of multiplicative-rule iterations per focus setting (also sizes the progress bar).
n_iteration = 10

# Fixed seed: the coefficients start from random values, so seeding the generator makes
# the optimization reproducible from one run of the notebook to the next.
rng = cp.random.default_rng(0)

# One set of optimized coefficients per slot of the focus-indexed maps.
element_optimal = np.zeros((n_f+1, n_ab, n_ab, 3), dtype=np.float32)

with tqdm(total=(n_f+1) * n_iteration) as pbar:
    # For each focus distance...
    for index_f in range(n_f+1):
        # Load the autocorrelation linear map into the GPU.
        autocorrelation_GPU = cp.array(autocorrelation_linear_map[index_f])

        # Load the projection coefficients into the GPU.
        projected_GPU = cp.array(element_projected[index_f])

        # Initialize display elements with random coefficients.
        # random() draws from [0, 1), so 1 - random() lies in (0, 1] and is never zero.
        element_GPU = 1.0 - rng.random((n_ab, n_ab, 3), dtype=cp.float32)

        # For each multiplicative rule iteration...
        for iteration in range(n_iteration):
            # Apply the autocorrelation linear map on current coefficients (separably along both axes).
            temp_GPU = cp.tensordot(autocorrelation_GPU, cp.tensordot(autocorrelation_GPU, element_GPU, axes=(1, 1)), axes=(1, 1))
            # Divide the projected coefficients by the autocorrelated coefficients
            # (a float32 epsilon is added first to avoid dividing by zero).
            temp_GPU = cp.add(temp_GPU, cp.finfo(cp.float32).eps, out=temp_GPU)
            temp_GPU = cp.divide(projected_GPU, temp_GPU, out=temp_GPU)
            # Update current coefficients with the multiplicative rule.
            element_GPU = cp.multiply(element_GPU, temp_GPU, out=element_GPU)
            # Clip coefficients to interval [0, 1].
            element_GPU = cp.clip(element_GPU, 0.0, 1.0, out=element_GPU)
            # Update progress bar.
            pbar.update()

        # Store coefficients after the last iteration.
        element_optimal[index_f] = element_GPU.get()

        # Free GPU memory
        del autocorrelation_GPU
        del projected_GPU
        del element_GPU
        del temp_GPU
        mempool.free_all_blocks()

In [None]:
# Optimal display element coefficients.
# Stride used to thin the coefficient images before display, so that each preview has
# fewer than n_r samples per axis.
decimation = n_ab//n_r + 1

# Each preview is shown normalized by its own maximum.
print(' \n \nOptimal coefficients for continuous focus')
array_stats(element_optimal[-1])
preview = element_optimal[-1, ::decimation, ::decimation]
display_image(preview / np.amax(preview))

for index_f, z_f in enumerate(focus_distances):
    print(' \n \nOptimal coefficients for focus distance at {}'.format(z_f))
    array_stats(element_optimal[index_f])
    preview = element_optimal[index_f, ::decimation, ::decimation]
    display_image(preview / np.amax(preview))

In [None]:
%%time
print('Computing discrete focus retina images...')

def _simulate_and_save(simulation_GPU, coefficients, reference, reference_spectrum, max_spectral_log, tag):
    """Simulate the retina image produced by a set of display coefficients, then report and save it.

    Parameters
    ----------
    simulation_GPU : CuPy array, simulation linear map for the observer focus (retina pixels × display elements).
    coefficients : host array of display element coefficients (elements × elements × 3).
    reference : host array with the reference retina image used for the mean squared error.
    reference_spectrum : spectrum of the (clipped) reference image, used for the error spectrum.
    max_spectral_log : normalization constant shared by all spectrum images of this focus.
    tag : str, suffix of the saved files ``image-<tag>``, ``spectrum-<tag>`` and ``error-<tag>``.
    """
    # Apply the simulation map separably along both image axes.
    element_GPU = cp.array(coefficients)
    retina_image = cp.tensordot(simulation_GPU, cp.tensordot(simulation_GPU, element_GPU, axes=(1, 1)), axes=(1, 1)).get()

    # The error is measured before clipping; the displayed and saved image is clipped to [0, 1].
    print('Mean Squared Error:', np.mean(np.square(retina_image - reference)))
    array_stats(retina_image)
    retina_image = np.clip(retina_image, 0.0, 1.0)
    display_image(retina_image)
    save_image('image-{}'.format(tag), retina_image)

    # Spectrum and spectral error, both normalized by the reference maximum.
    spectrum = compute_spectrum(retina_image)
    save_image('spectrum-{}'.format(tag), np.clip(spectral_log(spectrum) / max_spectral_log, 0.0, 1.0), 'Spectral')
    save_image('error-{}'.format(tag), np.clip(spectral_log(spectrum - reference_spectrum) / max_spectral_log, 0.0, 1.0), 'cividis')

with tqdm(total=n_f) as pbar:
    # For each focus distance...
    for index_f, z_f in enumerate(focus_distances):
        # Load the simulation linear map into the GPU.
        simulation_GPU = cp.array(simulation_linear_map[index_f])

        print(' \n \nReference with observer focus distance at {}'.format(z_f))
        retina_image = reference_image[index_f]
        array_stats(retina_image)
        retina_image = np.clip(retina_image, 0.0, 1.0)
        display_image(retina_image)
        save_image('image-{}-f{}-reference'.format(scene, index_f+1), retina_image)

        # The reference spectrum fixes the normalization of every spectrum image of this focus.
        reference_spectrum = compute_spectrum(retina_image)
        max_spectral_log = np.amax(spectral_log(reference_spectrum))
        save_image('spectrum-{}-f{}-reference'.format(scene, index_f+1), spectral_log(reference_spectrum) / max_spectral_log, 'Spectral')

        # Common part of the file names: scene, observer focus index and display resolution.
        stem = '{}-f{}-{}x{}'.format(scene, index_f+1, n_a, n_b)

        print(' \n \nInterpolation with observer focus distance at {}'.format(z_f))
        _simulate_and_save(simulation_GPU, element_interp, reference_image[index_f], reference_spectrum, max_spectral_log, '{}-interp'.format(stem))

        print(' \n \nNaive with observer focus distance at {}'.format(z_f))
        _simulate_and_save(simulation_GPU, element_naive, reference_image[index_f], reference_spectrum, max_spectral_log, '{}-naive'.format(stem))

        print(' \n \nOptimal for continuous focus with observer focus distance at {}'.format(z_f))
        _simulate_and_save(simulation_GPU, element_optimal[-1], reference_image[index_f], reference_spectrum, max_spectral_log, '{}-oX'.format(stem))

        for target_index_f, target_z_f in enumerate(focus_distances):
            print(' \n \nOptimal for {} with observer focus distance at {}'.format(target_z_f, z_f))
            _simulate_and_save(simulation_GPU, element_optimal[target_index_f], reference_image[index_f], reference_spectrum, max_spectral_log, '{}-o{}'.format(stem, target_index_f+1))

        # Update progress bar.
        pbar.update()

        # Free GPU memory (the coefficient arrays are released when the helper returns).
        del simulation_GPU
        mempool.free_all_blocks()

In [None]:
%%time
print('Computing continuous focus retina images...')


def _render_continuous_focus_case(label, tag, element):
    """Simulate, report and save one display method under continuous observer focus.

    Reads these notebook-level names: ``simulation_GPU``, ``reference_image``,
    ``reference_spectrum``, ``max_spectral_log``, ``scene``, ``n_a``, ``n_b``.

    Parameters
    ----------
    label : str
        Method description inserted into the progress message.
    tag : str
        Suffix appended to the three output names (image, spectrum, error).
    element : array-like
        Host-side display elements for this method; copied to the GPU here.

    Returns
    -------
    tuple
        ``(image, image_spectrum)``: the retina image clipped to [0, 1] and
        the result of ``compute_spectrum`` on that clipped image.
    """
    print(' \n \n{} with observer continuous focus'.format(label))
    element_GPU = cp.array(element)
    # Contract axis 1 of the simulation map with axis 1 of the elements, then
    # axis 1 of the map with axis 1 of that intermediate result.
    # NOTE(review): for 2-D inputs this amounts to S @ E @ S.T — confirm shapes.
    image = cp.tensordot(simulation_GPU, cp.tensordot(simulation_GPU, element_GPU, axes=(1, 1)), axes=(1, 1)).get()
    # The error is measured on the raw simulation output, before clipping.
    print('Mean Squared Error:', np.mean(np.square(image - reference_image[-1])))
    array_stats(image)
    image = np.clip(image, 0.0, 1.0)
    display_image(image)
    save_image('image-{}-fX-{}x{}-{}'.format(scene, n_a, n_b, tag), image)

    # Spectrum of the clipped image, normalized by the reference's peak log-spectrum.
    # The third save_image argument is presumably a colormap name.
    image_spectrum = compute_spectrum(image)
    save_image('spectrum-{}-fX-{}x{}-{}'.format(scene, n_a, n_b, tag), np.clip(spectral_log(image_spectrum) / max_spectral_log, 0.0, 1.0), 'Spectral')
    save_image('error-{}-fX-{}x{}-{}'.format(scene, n_a, n_b, tag), np.clip(spectral_log(image_spectrum - reference_spectrum) / max_spectral_log, 0.0, 1.0), 'cividis')
    return image, image_spectrum


# Load the simulation linear map into the GPU.
simulation_GPU = cp.array(simulation_linear_map[-1])

try:
    print(' \n \nReference with observer continuous focus')
    retina_image = reference_image[-1]
    array_stats(retina_image)
    retina_image = np.clip(retina_image, 0.0, 1.0)
    display_image(retina_image)
    save_image('image-{}-fX-reference'.format(scene), retina_image)

    # These two globals are (re)assigned here and read by every case below.
    reference_spectrum = compute_spectrum(retina_image)
    max_spectral_log = np.amax(spectral_log(reference_spectrum))
    save_image('spectrum-{}-fX-reference'.format(scene), spectral_log(reference_spectrum) / max_spectral_log, 'Spectral')

    # Each call rebinds `retina_image` and `spectrum`, so after the cell they
    # hold the results of the last case evaluated.
    retina_image, spectrum = _render_continuous_focus_case('Interpolation', 'interp', element_interp)
    retina_image, spectrum = _render_continuous_focus_case('Naive', 'naive', element_naive)
    retina_image, spectrum = _render_continuous_focus_case('Optimal for continuous focus', 'oX', element_optimal[-1])

    # One case per discrete focus distance the elements were optimized for.
    for target_index_f, target_z_f in enumerate(focus_distances):
        retina_image, spectrum = _render_continuous_focus_case(
            'Optimal for {}'.format(target_z_f),
            'o{}'.format(target_index_f+1),
            element_optimal[target_index_f],
        )
finally:
    # Free GPU memory, even when a step above raises or the cell is interrupted.
    # The per-case element copies are locals of the helper, so only the
    # simulation map has to be dropped explicitly here.
    del simulation_GPU
    mempool.free_all_blocks()