In [1]:
import os
os.environ['CUPY_ACCELERATORS'] = 'cutensor'
import numpy as np
import cupy as cp
from scipy import fft
import zarr
from matplotlib import pyplot, colors, cm
from PIL import Image
from IPython.display import display
from tqdm.notebook import tqdm, trange

# Handles to CuPy's default device memory pool and pinned (page-locked) host memory pool.
# NOTE(review): not used by the cells shown here; presumably kept for inspecting/freeing GPU memory
# interactively (e.g. mempool.free_all_blocks()) — confirm or remove.
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

def bytesize_string(nbytes):
    """Format a byte count using the largest binary unit (B, KiB, ..., YiB) that keeps the value >= 1.

    Counts below 1 KiB (including 0) are reported in bytes; counts of 1024 YiB or more stay in YiB.
    The number is rendered with plain float formatting, e.g. ``bytesize_string(1536) == '1.5 KiB'``.
    """
    units = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']
    size = float(nbytes)
    order_of_magnitude = 0
    # Dividing by 1024 (a power of two) is exact, so this matches scaling by 1/2**(10*k).
    # Starting from bytes (instead of searching for the first value < 1) keeps 0 and tiny counts in 'B'.
    while size >= 1024.0 and order_of_magnitude < len(units) - 1:
        size /= 1024.0
        order_of_magnitude += 1
    return '{} {}'.format(size, units[order_of_magnitude])

def array_stats(a):
    """Print a one-line summary of an array: shape × dtype = byte size, then its min, max and mean."""
    summary = '{} × {} = {} | min: {}, max: {}, avg: {}'.format(
        a.shape,
        a.dtype,
        bytesize_string(a.nbytes),
        np.amin(a),
        np.amax(a),
        np.average(a),
    )
    print(summary)

def image_to_uint8(values, color_map=None):
    """Convert a 2-D array indexed as values[x, y] with values in [0, 1] into an 8-bit image array for PIL.

    If ``color_map`` is given it is looked up with ``pyplot.get_cmap`` and applied first (giving RGBA).
    The x/y axes are swapped and the rows flipped so that y increases upwards in the resulting image.
    Values are clipped to [0, 1] before scaling so out-of-range inputs saturate instead of wrapping
    around in the uint8 cast.
    """
    if color_map is not None:
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still supported.
        values = pyplot.get_cmap(color_map)(values)
    pixels = np.flip(np.swapaxes(values, 0, 1), axis=0)
    return (np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)

def display_image(values, color_map=None):
    """Show ``values`` inline in the notebook (conversion as in ``image_to_uint8``)."""
    display(Image.fromarray(image_to_uint8(values, color_map)))

def save_image(name, values, color_map=None):
    """Save ``values`` as ``figures/<name>.png`` (conversion as in ``image_to_uint8``)."""
    Image.fromarray(image_to_uint8(values, color_map)).save('figures/{}.png'.format(name))

In [2]:
# Focus range: n_f focus settings, sampled uniformly over [0, 5] and divided by 1000.
# NOTE(review): presumably zeta_F is the eye's focus in diopters expressed per mm (0-5 D -> 0-0.005 / mm),
# matching the mm units of the plane distances below; zeta_f = 0 then means focus at infinity — confirm.
n_f = 100
zeta_F = np.linspace(0.0, 5.0, num=n_f) / 1000

In [3]:
# Retina plane: distance z_r from the pupil plane, n_r pixels spanning a width of size_r.
# NOTE(review): the saved outputs further down (reference map shape (100, 4096, 20, 27, 40) and the
# 1,638,400-step progress bar = 4 × 100 × 4096) were produced with n_r = 4096, not the value set here,
# so they are stale relative to this cell — re-run the notebook top to bottom to refresh them.
z_r = 1.0
n_r = 1024
size_r = 2.0
pitch_r = size_r / n_r

# Field of view: full angle (degrees) subtended by a retina of width size_r at distance z_r.
fov_factor = size_r / z_r
fov = 2 * np.arctan(fov_factor / 2) * 180/np.pi

# Pupil plane at z_p with aperture width size_p.
z_p = 0.0
size_p = 8.0

In [4]:
# Plane U: n_u light-field samples spanning size_u at depth z_u.
z_u = 0.0
n_u = 20
size_u = 10.0
pitch_u = size_u / n_u

# Plane V: n_v light-field samples spanning size_v at depth z_v.
z_v = 265.0
n_v = 1080
size_v = 540.0
pitch_v = size_v / n_v

# View chunks: plane V is processed as n_chunk groups of chunk_v samples.
chunk_v = 40
n_chunk = n_v//chunk_v
# The V lattice is later reshaped to (n_chunk, chunk_v); fail here with a clear message rather than there.
assert n_chunk * chunk_v == n_v, 'n_v ({}) must be a multiple of chunk_v ({})'.format(n_v, chunk_v)

print('[Sampled Light Field]')
print('Plane U:    {:4} × {:4.2f} mm = {:5} mm    z_u = {:5} mm'.format(n_u, pitch_u, size_u, z_u))
print('Plane V:    {:4} × {:4.2f} mm = {:5} mm    z_v = {:5} mm'.format(n_v, pitch_v, size_v, z_v))
print('          ({:2}×{:2})'.format(n_chunk, chunk_v))

# Retina resolution estimate: width covered by the field of view at a given depth, divided by the sample
# pitch there; for focus at infinity the angular sampling |z_u - z_v| / max(pitch) of the two planes is used.
print('\n[Retina Resolution Estimate with FoV = {:4.1f}°]'.format(fov))
print('Focus at U: {:7.2f}'.format(fov_factor * z_u / pitch_u))
print('Focus at V: {:7.2f}'.format(fov_factor * z_v / pitch_v))
print('Focus at ∞: {:7.2f}'.format(fov_factor * abs(z_u - z_v) / max(pitch_u, pitch_v)))

# Rays: every (u, v) pair in both transverse dimensions, hence the square.
print('\nNumber of rays: {:,}'.format((n_u * n_v)**2))

[Sampled Light Field]
Plane U:      20 × 0.50 mm =  10.0 mm    z_u =   0.0 mm
Plane V:    1080 × 0.50 mm = 540.0 mm    z_v = 265.0 mm
          (27×40)

[Retina Resolution Estimate with FoV = 90.0°]
Focus at U:    0.00
Focus at V: 1060.00
Focus at ∞: 1060.00

Number of rays: 466,560,000


In [5]:
# Display resolutions: n_res configurations, scaling the element counts of both layers by 1, 2, ..., n_res.
n_res = 4

# Plane A
z_a = 8.0
n_a = np.arange(1, n_res+1, dtype=np.uint32) * 12
size_a = 24.0
pitch_a = size_a / n_a

# Plane B
z_b = 136.0
n_b = np.arange(1, n_res+1, dtype=np.uint32) * 140
size_b = 280.0
pitch_b = size_b / n_b

for index_res in range(n_res):
    # Python ints: in uint32 arithmetic (n_a * n_b)**2 silently wraps once n_a * n_b > 65535
    # (from n_res = 7 upwards with these base sizes).
    n_a_res = int(n_a[index_res])
    n_b_res = int(n_b[index_res])
    print('\n[Display Resolution {} x {}]'.format(n_a_res, n_b_res))
    print('Plane A:    {:4} × {:4.2f} mm = {:5} mm    z_a = {:5} mm'.format(n_a_res, pitch_a[index_res], size_a, z_a))
    print('Plane B:    {:4} × {:4.2f} mm = {:5} mm    z_b = {:5} mm'.format(n_b_res, pitch_b[index_res], size_b, z_b))
    print('[Retina Resolution Estimate with FoV = {:4.1f}°]'.format(fov))
    print('Focus at A: {}'.format(fov_factor * z_a / pitch_a[index_res]))
    print('Focus at B: {}'.format(fov_factor * z_b / pitch_b[index_res]))
    print('Focus at ∞: {}'.format(fov_factor * np.abs(z_a - z_b) / np.maximum(pitch_a[index_res], pitch_b[index_res])))
    print('Number of elements: {}'.format((n_a_res * n_b_res)**2))


[Display Resolution 12 x 140]
Plane A:      12 × 2.00 mm =  24.0 mm    z_a =   8.0 mm
Plane B:     140 × 2.00 mm = 280.0 mm    z_b = 136.0 mm
[Retina Resolution Estimate with FoV = 90.0°]
Focus at A: 8.0
Focus at B: 136.0
Focus at ∞: 128.0
Number of elements: 2822400

[Display Resolution 24 x 280]
Plane A:      24 × 1.00 mm =  24.0 mm    z_a =   8.0 mm
Plane B:     280 × 1.00 mm = 280.0 mm    z_b = 136.0 mm
[Retina Resolution Estimate with FoV = 90.0°]
Focus at A: 16.0
Focus at B: 272.0
Focus at ∞: 256.0
Number of elements: 45158400

[Display Resolution 36 x 420]
Plane A:      36 × 0.67 mm =  24.0 mm    z_a =   8.0 mm
Plane B:     420 × 0.67 mm = 280.0 mm    z_b = 136.0 mm
[Retina Resolution Estimate with FoV = 90.0°]
Focus at A: 24.0
Focus at B: 408.0
Focus at ∞: 384.0
Number of elements: 228614400

[Display Resolution 48 x 560]
Plane A:      48 × 0.50 mm =  24.0 mm    z_a =   8.0 mm
Plane B:     560 × 0.50 mm = 280.0 mm    z_b = 136.0 mm
[Retina Resolution Estimate with FoV = 90.0°]

In [6]:
def sampling_lattice(n, pitch=1.0):
    """Return n sample positions spaced by `pitch` and centred on zero.

    The positions are built from a float32 index ramp (the final dtype follows NumPy promotion
    with `pitch`); e.g. n=4, pitch=1 gives [-1.5, -0.5, 0.5, 1.5].
    """
    half_span = (n - 1)/2
    centered_indices = np.arange(n, dtype=np.float32) - half_span
    return centered_indices * pitch



def intersect_plane(x_1, x_2, z_1, z_2, z_out):
    """Transverse coordinate at depth z_out of the straight line through (z_1, x_1) and (z_2, x_2).

    Works element-wise with NumPy broadcasting; z_1 must differ from z_2.
    """
    weight_1 = (z_2 - z_out)/(z_2 - z_1)
    weight_2 = (z_1 - z_out)/(z_1 - z_2)
    return weight_1 * x_1 + weight_2 * x_2

def intersect_retina(x_1, x_2, z_1, z_2, z_r, zeta_f):
    """Retina coordinate of the ray through (z_1, x_1) and (z_2, x_2) for an eye focused at zeta_f.

    Writing the ray as x(z) = x_p + t*z, this evaluates to z_r * (zeta_f * x_p + t);
    with zeta_f = 0 it reduces to z_r * t.
    """
    weight_1 = (z_2 * zeta_f - 1)/(z_2 - z_1)
    weight_2 = (z_1 * zeta_f - 1)/(z_1 - z_2)
    return z_r * (weight_1 * x_1 + weight_2 * x_2)

def intersect_plane_from_eye(x_p, x_r, z_out, zeta_f):
    """Coordinate at depth z_out of the ray through pupil point x_p that lands on retina coordinate x_r.

    x_r is taken for a unit retina distance, i.e. intersect_retina's value divided by z_r: with
    x_r = zeta_f * x_p + t this returns x_p + t * z_out.
    """
    pupil_weight = 1 - z_out * zeta_f
    return pupil_weight * x_p + z_out * x_r



def compute_element_size_p(pitch_a, pitch_b, z_a, z_b):
    """Width of the pupil-plane (z = 0) footprint of one display element.

    An element pairs one pitch_a cell on plane A with one pitch_b cell on plane B; the footprint width is
    the sum of the absolute intersect_plane weights at z_out = 0, each times its cell pitch.
    """
    contribution_a = np.abs(z_b/(z_b - z_a) * pitch_a)
    contribution_b = np.abs(z_a/(z_a - z_b) * pitch_b)
    return contribution_a + contribution_b

def compute_element_size_r(pitch_a, pitch_b, z_a, z_b, zeta_f):
    """Retina footprint width of one display element for an eye focused at zeta_f.

    Uses the intersect_retina weights without its z_r factor, i.e. the width for a unit retina
    distance; multiply by z_r to compare with intersect_retina coordinates.
    """
    contribution_a = np.abs((z_b * zeta_f - 1)/(z_b - z_a) * pitch_a)
    contribution_b = np.abs((z_a * zeta_f - 1)/(z_a - z_b) * pitch_b)
    return contribution_a + contribution_b



def compute_phi_GPU(x_r, x_a, x_b, pitch_a, pitch_b, size_p, z_a, z_b, zeta_f):
    # Fraction of the pupil aperture (intersection length divided by size_p) through which rays that reach
    # retina coordinate x_r pass both the plane-A cell centred at x_a (width pitch_a) and the plane-B cell
    # centred at x_b (width pitch_b), for an eye focused at zeta_f. Evaluated on the GPU with CuPy; the
    # inputs broadcast against each other (e.g. x_a, x_b as column vectors and x_r as a row of samples
    # give an (elements, samples) result).
    # NOTE(review): the projection x_p = (x - z * x_r) / (1 - z * zeta_f) below treats x_r as a retina
    # coordinate for a unit retina distance (intersect_retina's value divided by z_r); pass x_r
    # accordingly when z_r != 1.
    # Several intermediates are overwritten in place (out=...) to avoid extra GPU allocations.
    # Ratios for projection of plane A and plane B to the pupil plane.
    ratio_a = cp.reciprocal(1.0 - z_a * zeta_f, dtype=cp.float32)
    ratio_b = cp.reciprocal(1.0 - z_b * zeta_f, dtype=cp.float32)
    # Project display samples to the pupil plane.
    x_p_from_a = ratio_a * (x_a - z_a * x_r)
    x_p_from_b = ratio_b * (x_b - z_b * x_r)
    # Projected cell widths on the pupil (full widths; halved below to get the interval ends).
    proj_size_a = cp.abs(ratio_a * pitch_a)
    proj_size_b = cp.abs(ratio_b * pitch_b)
    # Projection A as an interval [lower_bound_a, upper_bound_a] (lower bound overwrites x_p_from_a).
    lower_bound_a = cp.subtract(x_p_from_a, proj_size_a/2, out=x_p_from_a)
    upper_bound_a = cp.add(lower_bound_a, proj_size_a)
    # Projection B as an interval [lower_bound_b, upper_bound_b] (lower bound overwrites x_p_from_b).
    lower_bound_b = cp.subtract(x_p_from_b, proj_size_b/2, out=x_p_from_b)
    upper_bound_b = cp.add(lower_bound_b, proj_size_b)
    # Intersect the projections: clipping both ends of A into B gives [max(lows), min(highs)], and
    # collapses to a zero-length interval when A and B do not overlap.
    lower_bound = cp.clip(lower_bound_a, lower_bound_b, upper_bound_b, out=lower_bound_a)
    upper_bound = cp.clip(upper_bound_a, lower_bound_b, upper_bound_b, out=upper_bound_a)
    # Intersect with pupil [-size_p/2, size_p/2] the same way.
    lower_bound = cp.clip(lower_bound, -size_p/2, size_p/2, out=lower_bound)
    upper_bound = cp.clip(upper_bound, -size_p/2, size_p/2, out=upper_bound)
    # Return the intersection length, normalized by the pupil size.
    return cp.divide(cp.subtract(upper_bound, lower_bound, out=upper_bound), size_p, out=upper_bound)

In [7]:
# Sample positions on the retina, on each display plane (one lattice per resolution) and on the
# light-field planes; plane V is reshaped into n_chunk rows of chunk_v samples.
lattice_r = sampling_lattice(n_r, pitch_r)

lattice_a = [sampling_lattice(n, pitch) for n, pitch in zip(n_a, pitch_a)]
lattice_b = [sampling_lattice(n, pitch) for n, pitch in zip(n_b, pitch_b)]

lattice_u = sampling_lattice(n_u, pitch_u)
lattice_v = sampling_lattice(n_v, pitch_v).reshape(n_chunk, chunk_v)

In [8]:
%%time
print('Computing active elements...')

n_ab = []
element_coord_a = []
element_coord_b = []

for index_res in range(n_res):
    print('\n[Display Resolution {} x {} = {}]'.format(n_a[index_res], n_b[index_res], n_a[index_res] * n_b[index_res]))
    
    # Display element projection size on the pupil.
    element_size_p = compute_element_size_p(pitch_a[index_res], pitch_b[index_res], z_a, z_b)

    # Display element coordinates on the pupil.
    element_coord_p = intersect_plane(lattice_a[index_res].reshape(n_a[index_res], 1), lattice_b[index_res].reshape(1, n_b[index_res]), z_a, z_b, z_p)

    # Display element incidence on the pupil.
    element_incidence_on_pupil = np.abs(element_coord_p) <= (size_p + element_size_p)/2

    # Active display elements.
    element_active_a, element_active_b = np.nonzero(element_incidence_on_pupil)
    n_ab.append(len(element_active_a))

    # Active display element coordinates.
    element_coord_a.append(lattice_a[index_res][element_active_a])
    element_coord_b.append(lattice_b[index_res][element_active_b])

    print('Active display elements:', n_ab[index_res])

Computing active elements...

[Display Resolution 12 x 140]
Active display elements: 680

[Display Resolution 24 x 280]
Active display elements: 2416

[Display Resolution 36 x 420]
Active display elements: 5198

[Display Resolution 48 x 560]
Active display elements: 9056
Wall time: 1.96 ms


In [9]:
%%time
print('Computing linear interpolation and naive...')

# Light field sample coodinates on the pupil and on planes A and B.
sample_coord_p = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_p)
sample_coord_a = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_a)
sample_coord_b = intersect_plane(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_b)

with tqdm(total=n_res * 2) as pbar:
    for index_res in range(n_res):
        print('\n[Display Resolution {} x {}]'.format(n_a[index_res], n_b[index_res]))
        
        # Display element coordinates on the pupil and on planes U and V.
        element_coord_p = intersect_plane(element_coord_a[index_res], element_coord_b[index_res], z_a, z_b, z_p)
        element_coord_u = intersect_plane(element_coord_a[index_res], element_coord_b[index_res], z_a, z_b, z_u)
        element_coord_v = intersect_plane(element_coord_a[index_res], element_coord_b[index_res], z_a, z_b, z_v)

        # Display element values from light field samples linear interpolation.
        dist_u = np.abs(element_coord_u.reshape(n_ab[index_res], 1, 1, 1) - lattice_u.reshape(1, n_u,       1,       1))
        dist_v = np.abs(element_coord_v.reshape(n_ab[index_res], 1, 1, 1) - lattice_v.reshape(1,   1, n_chunk, chunk_v))
        weight_u = np.maximum(0.0, 1.0 - dist_u / pitch_u)
        weight_v = np.maximum(0.0, 1.0 - dist_v / pitch_v)
        interpolation = weight_u * weight_v

        print('Interpolation Linear Map:')
        print(interpolation.shape)
        #array_stats(interpolation)
        #array_stats(np.count_nonzero(interpolation, axis=(1, 2, 3)))
        #array_stats(np.sum(interpolation, axis=(1, 2, 3)))

        zarr.open('data/matrices/interpolation-{}x{}.zarr'.format(n_a[index_res], n_b[index_res]), mode='w',
                  shape=(n_ab[index_res], n_u, n_chunk, chunk_v), chunks=(n_ab[index_res], n_u, 1, chunk_v), dtype=np.float32)[:] = interpolation
        pbar.update()

        # Display element values from light field samples display pre-filtering.
        dist_a = np.abs(sample_coord_a.reshape(1, n_u, n_chunk, chunk_v) - element_coord_a[index_res].reshape(n_ab[index_res], 1, 1, 1))
        dist_b = np.abs(sample_coord_b.reshape(1, n_u, n_chunk, chunk_v) - element_coord_b[index_res].reshape(n_ab[index_res], 1, 1, 1))
        sample_incidence_per_element = np.logical_and(dist_a <= pitch_a[index_res]/2, dist_b <= pitch_b[index_res]/2)
        sample_count_per_element = np.count_nonzero(sample_incidence_per_element, axis=(1, 2, 3)).astype(np.uint16)
        naive = sample_incidence_per_element.astype(np.float32) / np.maximum(sample_count_per_element, 1).reshape(n_ab[index_res], 1, 1, 1)

        print('Naive Linear Map:')
        print(naive.shape)
        #array_stats(naive)
        #array_stats(np.count_nonzero(naive, axis=(1, 2, 3)))
        #array_stats(np.sum(naive, axis=(1, 2, 3)))

        zarr.open('data/matrices/naive-{}x{}.zarr'.format(n_a[index_res], n_b[index_res]), mode='w',
                  shape=(n_ab[index_res], n_u, n_chunk, chunk_v), chunks=(n_ab[index_res], n_u, 1, chunk_v), dtype=np.float32)[:] = naive
        pbar.update()

Computing linear interpolation and naive...


  0%|          | 0/8 [00:00<?, ?it/s]


[Display Resolution 12 x 140]
Interpolation Linear Map:
(680, 20, 27, 40)
Naive Linear Map:
(680, 20, 27, 40)

[Display Resolution 24 x 280]
Interpolation Linear Map:
(2416, 20, 27, 40)
Naive Linear Map:
(2416, 20, 27, 40)

[Display Resolution 36 x 420]
Interpolation Linear Map:
(5198, 20, 27, 40)
Naive Linear Map:
(5198, 20, 27, 40)

[Display Resolution 48 x 560]
Interpolation Linear Map:
(9056, 20, 27, 40)
Naive Linear Map:
(9056, 20, 27, 40)
Wall time: 6.06 s


In [10]:
%%time
print('Computing reference linear maps...')

reference = zarr.open('data/matrices/reference.zarr', mode='w', shape=(n_f, n_r, n_u, n_chunk, chunk_v), chunks=(1, n_r, n_u, 1, chunk_v), dtype=np.float32)

# Light field sample incidence on the pupil.
sample_incidence_on_pupil = np.abs(sample_coord_p) <= size_p/2

with tqdm(total=n_f) as pbar:
    # For each focus distance...
    for index_f, zeta_f in enumerate(zeta_F):

        # Light field sample incidence on retina pixels.
        sample_coord_r = intersect_retina(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_r, zeta_f)
        coord_r = sample_coord_r.reshape(1, n_u, n_chunk, chunk_v) - lattice_r.reshape(n_r, 1, 1, 1)
        sample_incidence_per_pixel = np.logical_and(np.logical_and(coord_r > -pitch_r/2, coord_r <= pitch_r/2), sample_incidence_on_pupil)
        sample_count_per_pixel = np.count_nonzero(sample_incidence_per_pixel, axis=(1, 2, 3)).astype(np.uint16)

        # Retina pixel values from light field samples.
        reference[index_f] = sample_incidence_per_pixel.astype(np.float32) / np.maximum(1, sample_count_per_pixel).reshape(n_r, 1, 1, 1)

        # Update progress bar.
        pbar.update()

print('\nReference Linear Map:')
print(reference.shape)
#array_stats(reference)
#array_stats(np.count_nonzero(reference, axis=(2, 3, 4)))
#array_stats(np.sum(reference, axis=(2, 3, 4)))

Computing reference linear maps...


  0%|          | 0/100 [00:00<?, ?it/s]


Reference Linear Map:
(100, 4096, 20, 27, 40)
Wall time: 1min 21s


In [11]:
%%time
print('Computing retinal linear maps...')

with tqdm(total=n_res * n_f * n_r) as pbar:
    # For each resolution...
    for index_res in range(n_res):
        # For each focus distance...
        for index_f, zeta_f in enumerate(zeta_F):

            # Display element incidence on retina pixels.
            element_size_r = compute_element_size_r(pitch_a[index_res], pitch_b[index_res], z_a, z_b, zeta_f)
            element_coord_r = intersect_retina(element_coord_a[index_res], element_coord_b[index_res], z_a, z_b, z_r, zeta_f)
            dist_r = np.abs(element_coord_r.reshape(1, n_ab[index_res]) - lattice_r.reshape(n_r, 1))
            element_incidence_per_pixel = dist_r < (pitch_r + element_size_r)/2
            element_count_per_pixel = np.count_nonzero(element_incidence_per_pixel, axis=1).astype(np.uint16)

            # Load display element coordinates into the GPU.
            coord_a_GPU = cp.array(element_coord_a[index_res])
            coord_b_GPU = cp.array(element_coord_b[index_res])

            # Light field sample incidence on retina pixels.
            sample_coord_r = intersect_retina(lattice_u.reshape(n_u, 1, 1), lattice_v.reshape(1, n_chunk, chunk_v), z_u, z_v, z_r, zeta_f)
            coord_r = sample_coord_r.reshape(1, n_u, n_chunk, chunk_v) - lattice_r.reshape(n_r, 1, 1, 1)
            sample_incidence_per_pixel = np.logical_and(np.logical_and(coord_r > -pitch_r/2, coord_r <= pitch_r/2), sample_incidence_on_pupil)
            sample_count_per_pixel = np.count_nonzero(sample_incidence_per_pixel, axis=(1, 2, 3)).astype(np.uint16)

            # Load light field sample retina coordinates into the GPU.
            coord_r_GPU = cp.array(sample_coord_r)

            # Linear maps for current focus distance.
            simulation = np.zeros((n_r, n_ab[index_res]), dtype=np.float32)
            projection = np.zeros((n_ab[index_res], n_u, n_chunk, chunk_v), dtype=np.float32)
            autocorrelation = np.zeros((n_ab[index_res], n_ab[index_res]), dtype=np.float32)

            # For each retina pixel...
            for index_r in range(n_r):
                # Get incident display elements and incident light field samples.
                element_indices = np.nonzero(element_incidence_per_pixel[index_r])[0]
                sample_indices = np.nonzero(sample_incidence_per_pixel[index_r])

                # Compute phi values for incident display elements and incident light field samples.
                phi_values = compute_phi_GPU(coord_r_GPU[sample_indices], coord_a_GPU[element_indices, np.newaxis], coord_b_GPU[element_indices, np.newaxis], pitch_a[index_res], pitch_b[index_res], size_p, z_a, z_b, zeta_f).get()

                # Simulation: For each incident display element, take the average phi value over the incident light field samples.
                simulation[index_r, element_indices] = np.average(phi_values, axis=1)

                # Projection: Phi values divided by sample count per pixel (times the pixel pitch omitted here).
                projection[element_indices.reshape(-1, 1), sample_indices[0], sample_indices[1], sample_indices[2]] += phi_values / np.maximum(1, sample_count_per_pixel[index_r])

                # Autocorrelation: For each pair of incident display elements, accumulate the average of the product of sampled phi values (times the pixel pitch omitted here).
                autocorrelation[element_indices.reshape(-1, 1), element_indices.reshape(1, -1)] += np.average(phi_values[:, np.newaxis, :] * phi_values[np.newaxis, :, :], axis=2)

                # Update progress bar.
                pbar.update()

            # Store linear maps.
            zarr.open('data/matrices/simulation-{}x{}.zarr'.format(n_a[index_res], n_b[index_res]), mode='a',
                      shape=(n_f, n_r, n_ab[index_res]), chunks=(1, n_r, n_ab[index_res]), dtype=np.float32)[index_f] = simulation
            zarr.open('data/matrices/projection-{}x{}.zarr'.format(n_a[index_res], n_b[index_res]), mode='a',
                      shape=(n_f, n_ab[index_res], n_u, n_chunk, chunk_v), chunks=(1, n_ab[index_res], n_u, 1, chunk_v), dtype=np.float32)[index_f] = projection
            zarr.open('data/matrices/autocorrelation-{}x{}.zarr'.format(n_a[index_res], n_b[index_res]), mode='a',
                      shape=(n_f, n_ab[index_res], n_ab[index_res]), chunks=(1, n_ab[index_res], n_ab[index_res]), dtype=np.float32)[index_f] = autocorrelation

        print('\n[Display Resolution {} x {}]'.format(n_a[index_res], n_b[index_res]))

        print('Simulation Linear Map:')
        print(simulation.shape)
        #array_stats(simulation)
        #array_stats(np.count_nonzero(simulation, axis=2))
        #array_stats(np.sum(simulation, axis=2))

        print('Projection Linear Map:')
        print(projection.shape)
        #array_stats(projection)
        #array_stats(np.count_nonzero(projection, axis=(2, 3, 4)))
        #array_stats(np.sum(projection, axis=(2, 3, 4)))

        print('Autocorrelation Linear Map:')
        print(autocorrelation.shape)
        #array_stats(autocorrelation)
        #array_stats(np.count_nonzero(autocorrelation, axis=2))
        #array_stats(np.sum(autocorrelation, axis=2))

Computing retinal linear maps...


  0%|          | 0/1638400 [00:00<?, ?it/s]

  avg = a.mean(axis)
  ret = um.true_divide(



[Display Resolution 12 x 140]
Simulation Linear Map:
(4096, 680)
Projection Linear Map:
(680, 20, 27, 40)
Autocorrelation Linear Map:
(680, 680)

[Display Resolution 24 x 280]
Simulation Linear Map:
(4096, 2416)
Projection Linear Map:
(2416, 20, 27, 40)
Autocorrelation Linear Map:
(2416, 2416)

[Display Resolution 36 x 420]
Simulation Linear Map:
(4096, 5198)
Projection Linear Map:
(5198, 20, 27, 40)
Autocorrelation Linear Map:
(5198, 5198)

[Display Resolution 48 x 560]
Simulation Linear Map:
(4096, 9056)
Projection Linear Map:
(9056, 20, 27, 40)
Autocorrelation Linear Map:
(9056, 9056)
Wall time: 46min 37s
