In [None]:
import os

import datashader as ds
import holoviews as hv
import numpy as np
import pandas as pd
from holoviews.operation import decimate
from holoviews.operation.datashader import aggregate, datashade, dynspread, shade

hv.notebook_extension("bokeh")
%matplotlib inline
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import skimage
from IPython.display import clear_output, display
from ipywidgets import fixed, interact, interact_manual, interactive
from matplotlib.colors import hex2color

# from colour import wavelength_to_XYZ
# from colour import XYZ_to_sRGB
# from colour.notation.triplet import RGB_to_HEX
# from bokeh.io import push_notebook, show, output_notebook
# from bokeh.layouts import row
# from bokeh.plotting import figure

In [None]:
# Display defaults for the HoloViews/Datashader operations imported above.
# Each value is set on the operation class itself (not on an instance),
# applied in the order listed.
_operation_defaults = [
    (decimate, "max_samples", 1000),
    (dynspread, "max_px", 20),
    (dynspread, "threshold", 0.5),
    (shade, "cmap", "#30a2da"),  # to match HV Bokeh default
]
for _operation, _attr, _value in _operation_defaults:
    setattr(_operation, _attr, _value)

In [None]:
# Raw acquisition to browse. Defaults to the original author's scratch path;
# set the ND2_PATH environment variable to open a different file without
# editing the notebook. The reader is kept open because the viewer below
# reads frames from it on every redraw.
ND2_PATH = os.environ.get(
    "ND2_PATH", "/home/jqs1/scratch/fidelity/171018/20171018_TrxnError_ID.nd2"
)
frames = nd2reader.ND2Reader(ND2_PATH)

In [None]:
# Nominal emission wavelength per fluorescence channel, presumably in nm
# (TODO confirm). Only referenced by commented-out colour-conversion code.
emission_wavelengths = dict(
    MCHERRY=583,
    GFP=508,
    CY5=670,
    BFP=448,
    CFP=480,
)

In [None]:
# Default display colour (hex) for each channel name found in the ND2
# metadata. Every channel in the file must have an entry here, because the
# viewer looks colours up by channel name.
channel_colors = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
)

In [None]:
# Channel names in the order the ND2 metadata lists them.
channels = frames.metadata["channels"]
# Default (r, g, b) float colour for each of those channels, looked up by name.
colors = [hex2color(channel_colors[name]) for name in channels]
n_channels = len(channels)

In [None]:
# XYZ_to_sRGB(wavelength_to_XYZ(emission_wavelengths['MCHERRY']))

In [None]:
# Viewer controls: one row per channel (visibility toggle + colour picker),
# beside sliders for nd2reader's "v" axis (presumably multipoint / field of
# view) and "t" axis. The rendered image goes into an Output widget below.
channel_boxes = []
for channel in channels:
    enabled = widgets.ToggleButton(description=channel, value=True)
    color = widgets.ColorPicker(concise=True, value=channel_colors[channel])
    channel_box = widgets.HBox([enabled, color])
    channel_boxes.append(channel_box)
channels_box = widgets.VBox(channel_boxes)
# IntSlider has no `label` trait; `description` is what renders the caption.
# continuous_update=False: redraw only on release, since every redraw reads
# each channel's frame from the ND2 file.
t_slider = widgets.IntSlider(
    description="t",
    min=0,
    max=frames.sizes["t"] - 1,
    step=1,
    value=0,
    continuous_update=False,
)
v_slider = widgets.IntSlider(
    description="v",
    min=0,
    max=frames.sizes["v"] - 1,
    step=1,
    value=0,
    continuous_update=False,
)
slider_box = widgets.VBox([v_slider, t_slider])
control_box = widgets.HBox([channels_box, slider_box])
output = widgets.Output()
main_box = widgets.VBox([control_box, output])
display(main_box)

# Raw intensity that maps to full brightness when normalising each channel.
# 2**14 presumably matches a 14-bit camera (TODO confirm against the
# acquisition settings). Brighter pixels are clipped to full brightness.
max_val = 2**14


def _composite_channels(channel_imgs, channel_enabled, channel_rgbs, max_val):
    """Blend per-channel 2-D intensity images into one RGB image in [0, 1].

    Channel 0 is the base layer (presumably brightfield, "BF"; the channel
    order is assumed, not checked). Each other enabled channel is normalised
    by ``max_val``, clipped to [0, 1] and tinted with its colour. Those tinted
    layers are averaged. If channel 0 is enabled, its tinted image is then
    screen-blended with that average: ``1 - (1 - a) * (1 - b)``.

    Parameters
    ----------
    channel_imgs : sequence of 2-D array-like
        One image per channel, all of the same shape.
    channel_enabled : sequence of bool
        Visibility flag per channel.
    channel_rgbs : sequence of (r, g, b) float triples
        Display colour per channel, components in [0, 1].
    max_val : float
        Raw intensity that maps to full brightness.

    Returns
    -------
    numpy.ndarray
        Float array of shape (H, W, 3). It is all zeros when no channel is
        enabled.

    Raises
    ------
    ValueError
        If ``channel_imgs`` is empty.
    """
    if len(channel_imgs) == 0:
        raise ValueError("need at least one channel image")

    def tint(i):
        # Normalise and clip *before* tinting/blending: the screen blend only
        # stays within [0, 1] when both inputs are within [0, 1].
        norm = np.clip(np.asarray(channel_imgs[i], dtype=float) / max_val, 0.0, 1.0)
        return norm[:, :, np.newaxis] * np.asarray(channel_rgbs[i], dtype=float)

    overlay = [tint(i) for i in range(1, len(channel_imgs)) if channel_enabled[i]]
    if overlay:
        img = np.mean(overlay, axis=0)
    else:
        # No overlay channel enabled: start from black (screen-blending black
        # with the base layer leaves the base layer unchanged).
        img = np.zeros(np.shape(channel_imgs[0]) + (3,))
    if channel_enabled[0]:
        img = 1 - (1 - img) * (1 - tint(0))  # screen blend
    return img


def update_image(change):
    """Redraw the composite image for the current control settings.

    Registered as an ipywidgets ``observe`` callback. The ``change``
    notification is ignored, so ``None`` is passed for the initial draw.
    Everything runs inside ``with output:`` so that the figure, and any
    error traceback, lands in the viewer's Output widget.
    """
    with output:
        clear_output(wait=True)
        channel_enabled = [box.children[0].value for box in channel_boxes]
        # Read colours from the pickers so that choosing a new colour in the
        # UI actually changes the rendering.
        channel_rgbs = [hex2color(box.children[1].value) for box in channel_boxes]
        # One frame per channel shown in the UI.
        channel_imgs = [
            frames.get_frame_2D(c=i, t=t_slider.value, v=v_slider.value)
            for i in range(len(channel_boxes))
        ]
        img = _composite_channels(channel_imgs, channel_enabled, channel_rgbs, max_val)
        _, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(img)
        ax.set_title("v={}, t={}".format(v_slider.value, t_slider.value))
        plt.show()


# Draw once so the output area is populated before any control is touched.
update_image(None)

# Redraw whenever a control's value changes. Registration order: the time
# slider, the position slider, then each channel's visibility toggle followed
# by its colour picker.
_watched_controls = [t_slider, v_slider]
for channel_box in channel_boxes:
    _watched_controls.extend(channel_box.children[:2])
for _control in _watched_controls:
    _control.observe(update_image, names="value")

In [None]:
# Peak raw intensity of channel 0 at the currently selected (t, v), for
# comparison with max_val. The frame is fetched explicitly because the viewer
# callback keeps its images local, so no `channel_imgs` exists at notebook
# scope.
frames.get_frame_2D(c=0, t=t_slider.value, v=v_slider.value).max()

In [None]:
# Normalisation ceiling used by the viewer (2**14). Reuse the constant rather
# than repeating the literal, so this check follows any change to it.
max_val