In [None]:
%load_ext line_profiler

In [None]:
import datashader as ds
import holoviews as hv
import numpy as np
import pandas as pd
from holoviews.operation import decimate
from holoviews.operation.datashader import aggregate, datashade, dynspread, shade

hv.notebook_extension("bokeh")
%matplotlib inline
import functools

import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import skimage
from IPython.display import clear_output, display
from ipywidgets import fixed, interact, interact_manual, interactive
from matplotlib.colors import hex2color

# from bokeh.io import push_notebook, show, output_notebook
# from bokeh.layouts import row
# from bokeh.plotting import figure

In [None]:
# Class-level defaults for the HoloViews / datashader operations imported above.
shade.cmap = "#30a2da"  # to match HV Bokeh default
decimate.max_samples = 1000
dynspread.threshold = 0.5
dynspread.max_px = 20

In [None]:
# Open the ND2 acquisition that every later cell reads from.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
nd2_path = "/home/jqs1/scratch/fidelity/171018/20171018_TrxnError_ID.nd2"
frames = nd2reader.ND2Reader(nd2_path)

In [None]:
# Emission wavelength per fluorescent channel name.
# NOTE(review): values look like nanometres -- confirm. There is no "BF" entry.
emission_wavelengths = {
    "MCHERRY": 583,
    "GFP": 508,
    "CY5": 670,
    "BFP": 448,
    "CFP": 480,
}

In [None]:
# Display color (hex string) used to tint each channel, keyed by channel name.
channel_colors = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
)

In [None]:
# Channel names as recorded in the file metadata, and how many there are.
channels = frames.metadata["channels"]
n_channels = len(channels)

# RGB tuple for each channel, in the same order as `channels`.
# A channel name missing from `channel_colors` raises KeyError here.
colors = []
for name in channels:
    colors.append(hex2color(channel_colors[name]))

In [None]:
# Build one control row per channel: a solo button, an enable toggle and a
# color picker, laid out side by side.
channel_widgets = []
for channel in frames.metadata["channels"]:
    solo_button = widgets.Button(description="S", layout=widgets.Layout(width="10%"))
    enabled_button = widgets.ToggleButton(description=channel, value=True)
    # update_solo() reads this attribute to find the toggle paired with the
    # solo button that was clicked.
    solo_button._button_to_enable = enabled_button
    color_picker = widgets.ColorPicker(concise=True, value=channel_colors[channel])
    channel_box = widgets.HBox([solo_button, enabled_button, color_picker])
    channel_widgets.append([solo_button, enabled_button, color_picker, channel_box])
# Transpose the per-channel rows into one tuple per widget kind.
solo_buttons, enabled_buttons, color_pickers, channel_boxes = zip(*channel_widgets)
channels_box = widgets.VBox(channel_boxes)

# Sliders over the "t" and "v" axes of the file. `description` is the trait
# that renders the caption next to a slider (`label` is not an IntSlider trait).
# continuous_update=False redraws only when the handle is released.
t_slider = widgets.IntSlider(
    description="t",
    min=0,
    max=frames.sizes["t"] - 1,
    step=1,
    value=0,
    continuous_update=False,
)
v_slider = widgets.IntSlider(
    description="v",
    min=0,
    max=frames.sizes["v"] - 1,
    step=1,
    value=0,
    continuous_update=False,
)
slider_box = widgets.VBox([v_slider, t_slider])
control_box = widgets.HBox([channels_box, slider_box])
output = widgets.Output()  # update_image() draws the figure into this widget
main_box = widgets.VBox([control_box, output])
display(main_box)

# Only referenced by a commented-out normalization in a later cell.
# NOTE(review): presumably the full-scale value of a 14-bit camera -- confirm.
max_val = 2**14


@functools.lru_cache(maxsize=50)
def get_cached_image(c=None, t=None, v=None):
    """Return the 2D frame for indices ``c``, ``t`` and ``v``.

    Results are memoized (up to 50 distinct index combinations), so a repeated
    request returns the same object without calling ``frames.get_frame_2D``
    again. Callers must therefore not modify the returned frame in place.
    """
    frame = frames.get_frame_2D(v=v, t=t, c=c)
    return frame


def update_image(change):
    """Redraw the composite image for the current widget state.

    Every enabled channel is fetched for the indices selected on ``t_slider``
    and ``v_slider``, divided by its 99.9th percentile, clipped to [0, 1],
    tinted with the color currently chosen in that channel's color picker and
    screen-blended with the other enabled channels. When no channel is
    enabled a white placeholder of the same size is shown instead.

    Parameters
    ----------
    change
        Change notification passed by the observed widget, or ``None`` for a
        manual redraw. It is not inspected: all state is read from the widgets.
    """
    with output:
        clear_output(wait=True)
        plt.figure(figsize=(8, 8))
        layers = []
        for i in range(n_channels):
            if not enabled_buttons[i].value:
                continue  # disabled channels contribute nothing; skip the work
            raw = get_cached_image(c=i, t=t_slider.value, v=v_slider.value)
            scale = np.percentile(raw, 99.9)
            if scale > 0:
                scaled = raw[:, :, np.newaxis] / scale
                np.clip(scaled, 0, 1, out=scaled)  # clip in place
            else:
                # Blank channel: dividing by a zero percentile would fill the
                # composite with NaN/inf, so render it as black instead.
                scaled = np.zeros(raw.shape + (1,))
            # Read the picker at draw time so that color changes, which
            # trigger this callback, actually alter the rendering.
            tint = np.array(hex2color(color_pickers[i].value))
            layers.append(scaled * tint)
        if layers:
            img = layers[0]
            for layer in layers[1:]:
                img = 1 - (1 - img) * (1 - layer)  # screen blend
        else:
            # White placeholder with the same height/width as the frames.
            frame_shape = get_cached_image(
                c=0, t=t_slider.value, v=v_slider.value
            ).shape
            img = np.ones(frame_shape + (3,))
        plt.imshow(img)
        plt.show()


# Draw once with the initial widget state, then redraw whenever either
# slider's value changes.
update_image(None)

for slider in (t_slider, v_slider):
    slider.observe(update_image, names="value")


def update_solo(solo_button):
    """Click handler for a channel's solo button.

    If the clicked button's channel is already the only enabled one, every
    channel is re-enabled. Otherwise that channel becomes the only enabled
    one. The toggle paired with the button is read from
    ``solo_button._button_to_enable``.
    """
    target = solo_button._button_to_enable
    enabled_count = sum(toggle.value for toggle in enabled_buttons)
    already_soloed = target.value and enabled_count == 1
    for toggle in enabled_buttons:
        if already_soloed:
            toggle.value = True
        else:
            toggle.value = toggle == target
    # update_image(None)


# Each solo button toggles between "only this channel" and "all channels".
for button in solo_buttons:
    button.on_click(update_solo)

# Any change to an enable toggle or a color picker triggers a redraw.
for control in (*enabled_buttons, *color_pickers):
    control.observe(update_image, names="value")

# ui = interactive(f,
#          t=t_slider,
#          v=v_slider)
# output = plot.children[-1]
# output.layout.height = '700px'
# display(ui);

In [None]:
# Line-profile one manual redraw of update_image (uses the line_profiler
# extension loaded in the first cell).
%lprun -f update_image update_image(None)

In [None]:
# Load the frame at t=0, v=0 for every channel in the file. The count comes
# from n_channels (not a hard-coded 5) because the next cell indexes this
# list with range(n_channels).
channel_imgs = [frames.get_frame_2D(c=i, t=0, v=0) for i in range(n_channels)]

In [None]:
# Normalize each channel by its 99.9th percentile and clip to [0, 1].
# (An earlier variant divided by max_val instead of the percentile.)
scaled_imgs = []
for idx in range(n_channels):
    plane = channel_imgs[idx]
    scaled = plane[:, :, np.newaxis] / np.percentile(plane, 99.9)
    np.clip(scaled, 0, 1, out=scaled)  # clip in place
    scaled_imgs.append(scaled)

# Tint every normalized channel with its RGB color.
colored_imgs = [scaled * np.array(rgb) for scaled, rgb in zip(scaled_imgs, colors)]

# Screen-blend channels 1..n-1; channel 0 is left out of the composite.
img = colored_imgs[1]
for layer in colored_imgs[2:]:
    img = 1 - (1 - img) * (1 - layer)

# Alternatives tried for the composite (kept for reference, not executed):
#   - plain sum of the colored images, optionally divided by the enabled count
#   - weighted mix with channel 0: 0.7*img + 0.3*colored_imgs[0]
#   - screen blend with channel 0: 1 - (1 - img)*(1 - colored_imgs[0])

In [None]:
# Show the screen-blended composite built in the previous cell. Explicit axes
# and a title make the figure self-describing; plt.show() suppresses the
# AxesImage repr that a bare plt.imshow(...) leaves as cell output.
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("Composite: " + " + ".join(channels[1:]))
plt.show()

In [None]:
# Show the tinted image of the channel at index 4 on its own, titled with the
# channel's name so the figure is identifiable when skimming the notebook.
fig, ax = plt.subplots()
ax.imshow(colored_imgs[4])
ax.set_title(channels[4])
plt.show()