In [1]:
import numpy as np
import numpy.typing as npt
import cv2
import sys
import os

from skimage.color import xyz2rgb
from matplotlib import pyplot as plt
from ipywidgets import HBox, VBox
import ipywidgets as widgets

In [2]:
# Spectral sampling of the band images: NUM_IMAGES bands starting at
# WAVELEN_BASE nm and spaced WAVELEN_INC nm apart (400, 410, ..., 700 nm).
WAVELEN_BASE = 400
WAVELEN_INC = 10
NUM_IMAGES = 31

# Relative spectral power of CIE illuminant D65 at each wavelength in WAVELENS
# (normalised to 100.0 at 560 nm). The CAVE bands approximate reflectances, so
# they are weighted by an illuminant to obtain the light reaching the viewer.
D65 = np.array([
    82.7549, 91.486, 93.4318, 86.6823, 104.865, 117.008, 117.812, 114.861, 115.923, 108.811, 109.354,
    107.802, 104.79, 107.689, 104.405, 104.046, 100.0, 96.3342, 95.788, 88.6856, 90.0062, 89.5991,
    87.6987, 83.2886, 83.6992, 80.0268, 80.2146, 82.2778, 78.2842, 69.7213, 71.6091
])
WAVELENS = WAVELEN_BASE + WAVELEN_INC * np.arange(NUM_IMAGES)

In [3]:
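# Analytic approximation of the CIE 1931 XYZ color matching functions
# (multi-lobe piecewise-Gaussian fit of Wyman, Sloan & Shirley, JCGT 2013).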

def gauss(x: npt.ArrayLike, mu: float, sigma1: float, sigma2: float) -> npt.NDArray:
    """Unnormalised asymmetric Gaussian: peak 1 at ``mu``.

    Uses width ``sigma1`` for samples below ``mu`` and ``sigma2`` for samples at
    or above it. Works element-wise on scalars and arrays.
    """
    samples = np.asarray(x)
    width = np.where(samples < mu, sigma1, sigma2)
    return np.exp(-0.5 * (samples - mu) ** 2 / width**2)

def wavelen_to_x(lam: npt.ArrayLike) -> npt.NDArray:
    """Approximate CIE 1931 x-bar color matching function.

    ``lam`` is a wavelength or an array of wavelengths in nm (the notebook passes
    the whole WAVELENS array). The result is element-wise, with the shape of
    ``np.asarray(lam)``. It is a sum of three asymmetric Gaussian lobes, the
    third of which is subtracted.
    """
    return (
        1.056 * gauss(lam, 599.8, 37.9, 31.0)
        + 0.362 * gauss(lam, 442.0, 16.0, 26.7)
        - 0.065 * gauss(lam, 501.1, 20.4, 26.2))

def wavelen_to_y(lam: npt.ArrayLike) -> npt.ArrayLike:
    """Approximate CIE 1931 y-bar (luminance) color matching function, in nm."""
    main_lobe = 0.821 * gauss(lam, 568.8, 46.9, 40.5)
    side_lobe = 0.286 * gauss(lam, 530.9, 16.3, 31.1)
    return main_lobe + side_lobe

def wavelen_to_z(lam: npt.ArrayLike) -> npt.ArrayLike:
    """Approximate CIE 1931 z-bar color matching function, in nm."""
    main_lobe = 1.217 * gauss(lam, 437.0, 11.8, 36.0)
    side_lobe = 0.681 * gauss(lam, 459.0, 26.0, 13.8)
    return main_lobe + side_lobe

def wavelen_to_xyz(lam: npt.ArrayLike) -> npt.NDArray:
    """Evaluate the X, Y and Z color matching functions at ``lam``.

    Row 0 is X, row 1 is Y and row 2 is Z; the remaining axes follow the shape of ``lam``.
    """
    matching_functions = (wavelen_to_x, wavelen_to_y, wavelen_to_z)
    return np.stack([cmf(lam) for cmf in matching_functions])

# Color matching functions sampled at the band wavelengths: rows are X, Y, Z and
# columns follow WAVELENS, so the shape is (3, NUM_IMAGES).
xyz_of_wavelen = wavelen_to_xyz(lam=WAVELENS)

In [4]:
def load_image(name, data_dir="data"):
    """Load every spectral band of image ``name`` into the global ``im`` and return it.

    Reads ``{data_dir}/{name}/{name}_01.png`` through ``{name}_{NUM_IMAGES:02d}.png``
    and stacks them along a new first axis. Pixel values are divided by 65535;
    NOTE(review): this assumes 16-bit PNG bands - confirm for every dataset.
    For single-channel bands the result is presumably (NUM_IMAGES, H, W), which
    is the layout ``calc_image`` expects.

    The globals ``im`` and ``image_name`` are only updated after all bands have
    been read. A failed load therefore leaves the previous image intact, and
    ``update`` will retry because ``image_name`` still differs from the selection.

    Raises:
        FileNotFoundError: if any band file is missing or cannot be decoded.
    """
    global im, image_name
    bands = []
    for chan in range(1, 1 + NUM_IMAGES):
        path = f"{data_dir}/{name}/{name}_{chan:02d}.png"
        band = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if band is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read spectral band {path!r}")
        bands.append(band)
    im = np.array(bands).astype(float) / 65535.0
    image_name = name
    return im
image_name = ''

In [5]:
def _scale_to_unit_max(values):
    """Divide ``values`` by its maximum; return it unchanged if the maximum is not positive."""
    peak = values.max()
    # An all-black render (e.g. every slider at 0) would otherwise become 0/0 = NaN.
    return values / peak if peak > 0 else values


def calc_image(muls, image=None):
    """Render the spectral cube as an RGB image lit by D65 and filtered per band.

    Args:
        muls: per-band multipliers, one per entry of WAVELENS (the slider values).
        image: band stack to render, laid out (bands, H, W); defaults to the
            global ``im`` set by ``load_image``.

    Returns:
        An (H, W, 3) RGB array whose brightest value is scaled to 1, or an
        unscaled (all-zero) array when the filtered image is completely dark.
    """
    if image is None:
        image = im
    # Per-band XYZ weights: CMF sample x illuminant power x user filter -> (3, bands).
    weights = xyz_of_wavelen * (np.asarray(muls, dtype=float) * D65)
    # Contract the band axis: (3, bands) x (bands, H, W) -> (3, H, W) -> (H, W, 3).
    xyz = np.tensordot(weights, image, axes=([1], [0])).transpose(1, 2, 0)
    rgb = xyz2rgb(_scale_to_unit_max(xyz))
    return _scale_to_unit_max(rgb)

In [6]:
def plot_image(muls):
    """Draw the filtered render of the current image on the active matplotlib axes."""
    rendered = calc_image(muls)
    plt.imshow(rendered)

In [7]:
# One filter slider per band; the description doubles as the key in widget_dict.
sliders = [widgets.FloatSlider(min=0.0, max=1.0, value=1.0, continuous_update=False, description=f'{nm} nm')
           for nm in WAVELENS]
enable = widgets.Checkbox(value=True, description="Enable filter")
# Offer only visible sub-directories of data/ (one per image). Stray files
# (e.g. .DS_Store or a downloaded archive) would otherwise be listed and fail to load.
image_names = sorted(
    entry for entry in os.listdir('data/')
    if not entry.startswith('.') and os.path.isdir(os.path.join('data', entry))
)
image_chooser = widgets.Dropdown(options=image_names)
widget_dict = {'enable': enable, 'image': image_chooser, **{s.description: s for s in sliders}}

In [8]:
def update(**kwargs):
    """Redraw the preview whenever a control changes.

    ``kwargs`` carries the current value of every widget in ``widget_dict``:
    ``image`` (dataset name), ``enable`` (bool) and one entry per slider keyed by
    its description. When the filter is disabled, all bands get weight 1.
    """
    chosen = kwargs['image']
    if chosen != image_name:
        load_image(chosen)
    muls = (np.asarray([s.value for s in sliders]) if kwargs['enable']
            else np.ones((NUM_IMAGES,)))
    plt.figure(figsize=(10, 10))
    plot_image(muls)

demo = widgets.interactive_output(update, widget_dict)

# Top row: rendered image next to the image chooser and filter toggle.
# Bottom row: the band sliders in three columns (10, 10 and the remaining ones).
layout = VBox([
    HBox([demo, VBox([image_chooser, enable])]),
    HBox([VBox(sliders[:10]), VBox(sliders[10:20]), VBox(sliders[20:])]),
])

display(layout)

VBox(children=(HBox(children=(Output(), VBox(children=(Dropdown(options=('balloons_ms', 'beads_ms', 'cd_ms', '…