In [1]:
import os
os.environ['CUPY_ACCELERATORS'] = 'cutensor'
import numpy as np
import cupy as cp
from scipy import fft
import zarr
from matplotlib import pyplot, colors, cm
from PIL import Image
from IPython.display import display
from tqdm.notebook import tqdm, trange

# CuPy's default device memory pool and pinned (page-locked host) memory pool.
# The processing cells below call mempool.free_all_blocks() to release cached device blocks.
mempool, pinned_mempool = cp.get_default_memory_pool(), cp.get_default_pinned_memory_pool()

def bytesize_string(nbytes):
    """Format a byte count using the largest binary unit (B, KiB, ..., YiB) that keeps the value >= 1.

    The value is printed as a float, e.g. 52 -> '52.0 B', 17680 -> '17.265625 KiB'.
    Counts below one byte (including 0) are reported in bytes; counts of 1024 YiB or more
    stay in YiB.
    """
    units = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']
    size = float(nbytes)
    index = 0
    # Dividing by 1024 (a power of two) is exact in floating point, so no rounding is introduced.
    while size >= 1024 and index < len(units) - 1:
        size /= 1024
        index += 1
    return '{} {}'.format(size, units[index])

def array_stats(a):
    """Print a one-line summary of array `a`: shape × dtype = memory size | min, max and mean."""
    size = bytesize_string(a.nbytes)
    low, high, mean = np.amin(a), np.amax(a), np.average(a)
    template = '{} × {} = {} | min: {}, max: {}, avg: {}'
    print(template.format(a.shape, a.dtype, size, low, high, mean))

def display_image(values, color_map=None):
    """Show an array inline in the notebook as an 8-bit image.

    `values` is indexed [x, y] (optionally with a trailing channel axis). It is transposed
    and flipped vertically, so the first axis runs left-to-right and the second runs
    bottom-to-top on screen. Values are expected in [0, 1]; anything outside is clipped.
    If `color_map` is given, a scalar field is first mapped through that matplotlib colormap.
    """
    if color_map is not None:
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains available.
        values = pyplot.get_cmap(color_map)(values)
    oriented = np.flip(np.swapaxes(values, 0, 1), axis=0)
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap around (e.g. 1.01 -> 1).
    pixels = np.clip(oriented, 0.0, 1.0) * 255.0
    display(Image.fromarray(pixels.astype(np.uint8)))
    
def save_image(name, values, color_map=None):
    """Save an array as the 8-bit PNG 'figures/<name>.png'.

    Uses the same orientation as the inline display: `values` is indexed [x, y] (optionally
    with a trailing channel axis), transposed and flipped so y runs bottom-to-top. Values
    are expected in [0, 1]; anything outside is clipped. If `color_map` is given, a scalar
    field is first mapped through that matplotlib colormap. `name` may contain '/'-separated
    subdirectories; missing directories are created.
    """
    if color_map is not None:
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains available.
        values = pyplot.get_cmap(color_map)(values)
    oriented = np.flip(np.swapaxes(values, 0, 1), axis=0)
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap around (e.g. 1.01 -> 1).
    pixels = np.clip(oriented, 0.0, 1.0) * 255.0
    path = 'figures/{}.png'.format(name)
    # PIL does not create parent directories, so nested names would otherwise fail to save.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    Image.fromarray(pixels.astype(np.uint8)).save(path)

In [2]:
# Focus range: n_f evenly spaced values of zeta_F from 0 to 5/1000 (units not stated here).
n_f = 100
zeta_F = np.linspace(0.0, 5.0, num=n_f) / 1000

# Every index multiplied by 10. A platform integer dtype is required: with np.uint8 the
# product overflowed from index 26 on (26 * 10 = 260 wrapped around to 4).
# NOTE(review): values reach 10 * (n_f - 1) = 990, outside the n_f focus indices;
# confirm the intended selection (e.g. np.arange(0, n_f, 10)) before indexing with these.
selected_indices = np.arange(n_f, dtype=np.intp) * 10

In [3]:
# Retina resolution: reference images are n_r × n_r pixels.
n_r = 1024

# Light field sample counts along the u and v axes; v is split into chunks of chunk_v samples.
n_u = 20
n_v = 1080
chunk_v = 40
# Floor division would silently drop trailing v samples if n_v were not a multiple of chunk_v.
assert n_v % chunk_v == 0, 'n_v ({}) must be a multiple of chunk_v ({})'.format(n_v, chunk_v)
n_chunk = n_v//chunk_v

# Number of display resolutions; resolution i is n_a[i] × n_b[i].
n_res = 4

n_a = np.arange(1, n_res+1, dtype=np.uint32) * 12
n_b = np.arange(1, n_res+1, dtype=np.uint32) * 140

In [4]:
# Reference matrix, shared by all display resolutions (opened read-only).
reference = zarr.open('data/matrices/reference.zarr', mode='r')

# One read-only matrix per display resolution for each of the three methods.
interpolation, naive, projection = [], [], []

for index_res in range(n_res):
    res_a, res_b = n_a[index_res], n_b[index_res]
    interpolation.append(zarr.open('data/matrices/interpolation-{}x{}.zarr'.format(res_a, res_b), mode='r'))
    naive.append(zarr.open('data/matrices/naive-{}x{}.zarr'.format(res_a, res_b), mode='r'))
    projection.append(zarr.open('data/matrices/continuous-projection-{}x{}.zarr'.format(res_a, res_b), mode='r'))

# Display element count per resolution: first axis of each interpolation matrix.
n_ab = [matrix.shape[0] for matrix in interpolation]

In [5]:
%%time
print('Computing assignments to chunk groups...')

# Number of light field groups. Group g is the window of `chunk_group` consecutive light
# field chunks g .. g + chunk_group - 1; windows start at every chunk that leaves room for
# a full window.
chunk_group = 2
n_group = n_chunk - (chunk_group - 1)

print('\nLight field group size:', chunk_group)
print('Number of light field groups:', n_group)


def _assign_chunk_groups(chunk_incidence, n_group):
    """Assign each item (row of a boolean (n_items, n_chunk) incidence matrix) to one light field group.

    An item goes to the group starting at its first incident chunk, clamped to the last
    group (n_group - 1); items with no incidence at all go to group 0.

    Returns:
        assignment -- bool (n_group, n_items), True where the item belongs to the group.
        count      -- uint16 (n_group,), number of items per group.
        indices    -- list of n_group int arrays holding each group's item indices.
        span       -- largest number of chunks, counted from the first chunk of its group,
                      that any item touches. The group windows cover every incident chunk
                      only if span <= chunk_group.
    """
    n_items, n_chunks = chunk_incidence.shape
    first_chunk = np.argmax(chunk_incidence, axis=1)
    # Last incident chunk = first True when scanning each row from the end.
    last_chunk = (n_chunks - 1) - np.argmax(chunk_incidence[:, ::-1], axis=1)
    group_of_item = np.minimum(first_chunk, n_group - 1)

    # Unlike a plain count of incident chunks, the span also catches non-contiguous
    # incidence (e.g. chunks 3 and 5), which a 2-chunk window starting at chunk 3 misses.
    required = np.where(np.any(chunk_incidence, axis=1), last_chunk - group_of_item + 1, 0)
    span = int(np.amax(required, initial=0))

    assignment = group_of_item.reshape(1, n_items) == np.arange(n_group).reshape(n_group, 1)
    count = np.count_nonzero(assignment, axis=1).astype(np.uint16)
    indices = [np.nonzero(row)[0] for row in assignment]
    return assignment, count, indices, span


with tqdm(total=1 + n_res) as pbar:

    # Light field chunk incidence on retina pixels: `reference` reduced over axes 0, 2 and 4,
    # leaving (retina pixel, chunk) — assumes axis 3 is the chunk axis, as the processing
    # cell's `reference.oindex[:, idx_pixel, :, slice_x, :]` indexing suggests.
    chunk_incidence_per_pixel = np.any(reference, axis=(0, 2, 4))
    (pixel_assignment_per_group, pixel_count_per_group,
     pixel_indices_per_group, chunk_group_pixel) = _assign_chunk_groups(chunk_incidence_per_pixel, n_group)

    if chunk_group_pixel > chunk_group:
        print('ERROR! Chunk incidence of {} is bigger than chunk group of {}'.format(chunk_group_pixel, chunk_group))

    print('\nRetina pixel assignment and count per light field group:')
    array_stats(pixel_assignment_per_group)
    array_stats(pixel_count_per_group)

    # Update progress bar.
    pbar.update()

    # Display element indices per light field group, one list per display resolution.
    element_indices_per_group = []

    for index_res in range(n_res):
        print('\n[Display Resolution {} x {}]'.format(n_a[index_res], n_b[index_res]))

        # Light field chunk incidence on display elements: an element touches a chunk if
        # any of the three matrices is nonzero there.
        chunk_incidence_per_element = np.logical_or(np.logical_or(np.any(interpolation[index_res], axis=(1, 3)), np.any(naive[index_res], axis=(1, 3))), np.any(projection[index_res], axis=(1, 3)))
        (element_assignment_per_group, element_count_per_group,
         element_indices, chunk_group_element) = _assign_chunk_groups(chunk_incidence_per_element, n_group)

        if chunk_group_element > chunk_group:
            print('ERROR! Chunk incidence of {} is bigger than chunk group of {}'.format(chunk_group_element, chunk_group))

        element_indices_per_group.append(element_indices)

        print('Display element assignment and count per light field group:')
        array_stats(element_assignment_per_group)
        array_stats(element_count_per_group)

        # Update progress bar.
        pbar.update()

Computing assignments to chunk groups...

Light field group size: 2
Number of light field groups: 26


  0%|          | 0/5 [00:00<?, ?it/s]


Retina pixel assignment and count per light field group:
(26, 1024) × bool = 26.0 KiB | min: False, max: True, avg: 0.038461538461538464
(26,) × uint16 = 52.0 B | min: 36, max: 60, avg: 39.38461538461539

[Display Resolution 12 x 140]
Display element assignment and count per light field group:
(26, 680) × bool = 17.265625 KiB | min: False, max: True, avg: 0.038461538461538464
(26,) × uint16 = 52.0 B | min: 21, max: 42, avg: 26.153846153846153

[Display Resolution 24 x 280]
Display element assignment and count per light field group:
(26, 2416) × bool = 61.34375 KiB | min: False, max: True, avg: 0.038461538461538464
(26,) × uint16 = 52.0 B | min: 82, max: 156, avg: 92.92307692307692

[Display Resolution 36 x 420]
Display element assignment and count per light field group:
(26, 5198) × bool = 131.98046875 KiB | min: False, max: True, avg: 0.038461538461538464
(26,) × uint16 = 52.0 B | min: 185, max: 330, avg: 199.92307692307693

[Display Resolution 48 x 560]
Display element assignment an

In [6]:
%%time

# Scenes whose sampled light fields are processed, in order.
scenes = ['car', 'chess', 'dragon', 'sponza']
n_scenes = len(scenes)

# Host-side result buffers, allocated once and reused for every scene. Every retina pixel and
# display element belongs to exactly one group (assignment cell), so the group loops below
# overwrite every entry for each scene.
reference_images = np.zeros((n_f, n_r, n_r, 3), dtype=np.float32)

interpolation_coef = [np.zeros((n_ab[index_res], n_ab[index_res], 3), dtype=np.float32) for index_res in range(n_res)]
naive_coef = [np.zeros((n_ab[index_res], n_ab[index_res], 3), dtype=np.float32) for index_res in range(n_res)]
projection_coef = [np.zeros((n_ab[index_res], n_ab[index_res], 3), dtype=np.float32) for index_res in range(n_res)]

# Iterations of the multiplicative update rule for the continuous coefficients.
n_iter = 50
# Seeded so the random initial continuous coefficients (and hence all results) are reproducible.
random_seed = 0
rng = cp.random.default_rng(random_seed)

mse_interpolation =      np.zeros((n_res, n_f), dtype=np.float32)
mse_naive =              np.zeros((n_res, n_f), dtype=np.float32)
mse_continuous = np.zeros((n_iter, n_res, n_f), dtype=np.float32)


def _group_product(matrix, idx_element, idy_element, slice_x, slice_y, samples_GPU):
    """Apply a separable linear map to one group's light field samples on the GPU.

    Rows `idx_element` / `idy_element` of `matrix`, restricted to the chunk windows
    `slice_x` / `slice_y`, are contracted with `samples_GPU` (indexed [x, y, channel]).
    Returns a CuPy array of shape (len(idx_element), len(idy_element), channels).
    """
    lm_x = cp.array(matrix.oindex[idx_element, :, slice_x, :].reshape(-1, samples_GPU.shape[0]))
    lm_y = cp.array(matrix.oindex[idy_element, :, slice_y, :].reshape(-1, samples_GPU.shape[1]))
    return cp.einsum('hx,vy,xyc->hvc', lm_x, lm_y, samples_GPU)


def _retinal_image(simulation_f, coefficients):
    """Apply one focus setting's simulation matrix along both element axes; returns a host array."""
    return cp.tensordot(simulation_f, cp.tensordot(simulation_f, coefficients, axes=(1, 1)), axes=(1, 1)).get()


for scene in scenes:
    print('Processing the light field samples from scene "{}"...'.format(scene))

    # Load scene.
    light_field = zarr.open('data/' + scene + '_sampled.zarr', mode='r')

    # Rolling host buffer of `chunk_group` consecutive chunks along x (axis 1). The first
    # chunk_group-1 chunks are pre-loaded into the tail; each x-group shifts the buffer left
    # by one slot and loads the next chunk into the last slot.
    sample_values = np.empty((n_u, chunk_group, chunk_v, n_u, n_chunk, chunk_v, 3), dtype=np.float32)
    sample_values[:, 1:, :] = light_field[:, :chunk_group-1, :]

    with tqdm(total=n_group * n_group * (1 + n_res)) as pbar:
        for idx_group in range(n_group):
            slice_x = slice(idx_group, idx_group + chunk_group)

            # Load more values from sampled light field.
            sample_values[:, :-1, :] = sample_values[:, 1:, :]
            sample_values[:, -1, :] = light_field[:, idx_group+chunk_group-1, :]

            for idy_group in range(n_group):
                slice_y = slice(idy_group, idy_group + chunk_group)

                # Load sample values into the GPU for this group, flattened to (x, y, channel).
                sample_values_GPU = cp.array(sample_values[:, :, :, :, slice_y, :].reshape(n_u * chunk_group * chunk_v, n_u * chunk_group * chunk_v, 3))

                # Retina pixels assigned to this group.
                idx_pixel = pixel_indices_per_group[idx_group]
                idy_pixel = pixel_indices_per_group[idy_group]

                # Compute reference image for every focus value at once.
                lm_x_GPU = cp.array(reference.oindex[:, idx_pixel, :, slice_x, :].reshape(n_f, -1, n_u * chunk_group * chunk_v))
                lm_y_GPU = cp.array(reference.oindex[:, idy_pixel, :, slice_y, :].reshape(n_f, -1, n_u * chunk_group * chunk_v))
                reference_images[np.ix_(range(n_f), idx_pixel, idy_pixel)] = cp.clip(cp.einsum('fhx,fvy,xyc->fhvc', lm_x_GPU, lm_y_GPU, sample_values_GPU), 0.0, 1.0).get()

                # Update progress bar.
                pbar.update()

                for index_res in range(n_res):
                    # Display elements assigned to this group.
                    idx_element = element_indices_per_group[index_res][idx_group]
                    idy_element = element_indices_per_group[index_res][idy_group]
                    block = np.ix_(idx_element, idy_element)

                    # Interpolated and naive coefficients, clipped to [0, 1].
                    interpolation_coef[index_res][block] = cp.clip(_group_product(interpolation[index_res], idx_element, idy_element, slice_x, slice_y, sample_values_GPU), 0.0, 1.0).get()
                    naive_coef[index_res][block] = cp.clip(_group_product(naive[index_res], idx_element, idy_element, slice_x, slice_y, sample_values_GPU), 0.0, 1.0).get()

                    # Continuous projection, left unclipped: it is the numerator of the
                    # multiplicative update below.
                    projection_coef[index_res][block] = _group_product(projection[index_res], idx_element, idy_element, slice_x, slice_y, sample_values_GPU).get()

                    # Update progress bar.
                    pbar.update()

                # Free GPU memory
                del sample_values_GPU
                del lm_x_GPU
                del lm_y_GPU
                mempool.free_all_blocks()

        zarr.open('data/{}/reference.zarr'.format(scene), mode='w', shape=(n_f, n_r, n_r, 3), chunks=(1, n_r, n_r, 3), dtype=np.float32)[:] = reference_images
        for index_f in range(n_f):
            save_image('{}/reference/f{}-reference'.format(scene, index_f), reference_images[index_f])

    print('Simulating retinal images...')
    with tqdm(total=n_res * n_f * (2 + n_iter)) as pbar:
        for index_res in range(n_res):
            simulation = cp.array(zarr.open('data/matrices/simulation-{}x{}.zarr'.format(n_a[index_res], n_b[index_res]), mode='r'))
            # Display element count at this resolution. Kept in its own name so the global
            # per-resolution list `n_ab` is not overwritten with an int (which broke re-runs).
            n_elements = simulation.shape[-1]

            # Interpolation coefficients and retinal images.
            coefficients = cp.array(interpolation_coef[index_res])
            for index_f in range(n_f):
                retinal_image = _retinal_image(simulation[index_f], coefficients)
                mse_interpolation[index_res, index_f] = np.mean(np.square(reference_images[index_f] - retinal_image))
                save_image('{}/{}x{}/f{}-interpolation'.format(scene, n_a[index_res], n_b[index_res], index_f), retinal_image)
                pbar.update()

            # Naive coefficients and retinal images.
            coefficients = cp.array(naive_coef[index_res])
            for index_f in range(n_f):
                retinal_image = _retinal_image(simulation[index_f], coefficients)
                mse_naive[index_res, index_f] = np.mean(np.square(reference_images[index_f] - retinal_image))
                save_image('{}/{}x{}/f{}-naive'.format(scene, n_a[index_res], n_b[index_res], index_f), retinal_image)
                pbar.update()

            # Continuous coefficients: random start in (0, 1], refined by the multiplicative rule.
            coefficients = 1.0 - rng.random((n_elements, n_elements, 3), dtype=cp.float32)
            proj_GPU = cp.array(projection_coef[index_res])
            autocorrelation = cp.array(zarr.open('data/matrices/continuous-autocorrelation-{}x{}.zarr'.format(n_a[index_res], n_b[index_res]), mode='r'))
            for iteration in range(n_iter):
                # Apply the autocorrelation linear map on current coefficients.
                temp = cp.tensordot(autocorrelation, cp.tensordot(autocorrelation, coefficients, axes=(1, 1)), axes=(1, 1))
                # Divide the projected coefficients by the autocorrelated coefficients (eps avoids division by zero).
                temp = cp.add(temp, cp.finfo(cp.float32).eps, out=temp)
                temp = cp.divide(proj_GPU, temp, out=temp)
                # Update current coefficients with the multiplicative rule.
                coefficients = cp.multiply(coefficients, temp, out=coefficients)
                # Clip coefficients to interval [0, 1].
                coefficients = cp.clip(coefficients, 0.0, 1.0, out=coefficients)

                # Continuous retinal images; MSE tracked per iteration, images saved only for the last.
                for index_f in range(n_f):
                    retinal_image = _retinal_image(simulation[index_f], coefficients)
                    mse_continuous[iteration, index_res, index_f] = np.mean(np.square(reference_images[index_f] - retinal_image))
                    if iteration == (n_iter-1):
                        save_image('{}/{}x{}/f{}-continuous'.format(scene, n_a[index_res], n_b[index_res], index_f), retinal_image)
                    pbar.update()

            # Free GPU memory
            del simulation
            del coefficients
            del proj_GPU
            del autocorrelation
            del temp
            mempool.free_all_blocks()

        zarr.open('data/{}/mse-interpolation.zarr'.format(scene), mode='w', shape=(n_res, n_f), chunks=(n_res, n_f), dtype=np.float32)[:] = mse_interpolation
        zarr.open('data/{}/mse-naive.zarr'.format(scene), mode='w', shape=(n_res, n_f), chunks=(n_res, n_f), dtype=np.float32)[:] = mse_naive
        zarr.open('data/{}/mse-continuous.zarr'.format(scene), mode='w', shape=(n_iter, n_res, n_f), chunks=(1, n_res, n_f), dtype=np.float32)[:] = mse_continuous

Processing the light field samples from scene "car"...


  0%|          | 0/3380 [00:00<?, ?it/s]

Simulating retinal images...


  0%|          | 0/1200 [00:00<?, ?it/s]

Processing the light field samples from scene "chess"...


  0%|          | 0/3380 [00:00<?, ?it/s]

Simulating retinal images...


  0%|          | 0/1200 [00:00<?, ?it/s]

Processing the light field samples from scene "dragon"...


  0%|          | 0/3380 [00:00<?, ?it/s]

Simulating retinal images...


  0%|          | 0/1200 [00:00<?, ?it/s]

Processing the light field samples from scene "sponza"...


  0%|          | 0/3380 [00:00<?, ?it/s]

Simulating retinal images...


  0%|          | 0/1200 [00:00<?, ?it/s]

Wall time: 12h 22min 20s
