In [None]:
import functools

import datashader as ds
import holoviews as hv
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import skimage
import zarr
from bokeh.io import output_notebook, push_notebook, show
from bokeh.models import WheelZoomTool
from holoviews.operation import decimate
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    regrid,
    shade,
)
from holoviews.streams import Stream, param
from IPython.display import clear_output, display
from ipywidgets import fixed, interact, interact_manual, interactive
from matplotlib.colors import hex2color
from numcodecs import Blosc, Delta

# from bokeh.layouts import row
# from bokeh.plotting import figure
from tqdm import tnrange, tqdm_notebook

In [None]:
# Load the line_profiler IPython extension.
%load_ext line_profiler
# Activate HoloViews notebook output using the Bokeh backend.
hv.notebook_extension("bokeh")
# Handle to the Bokeh renderer; not referenced again in the cells visible here.
renderer = hv.renderer("bokeh")
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Open the Zarr array read-only. composite_image indexes it as
# frames_z[v, c, t, :, :].
# NOTE(review): absolute, user-specific path -- adjust for your environment.
zarr_store_path = "/home/jqs1/scratch/fidelity/test/171018_.zarr"
frames_z = zarr.open_array(zarr_store_path, mode="r")

In [None]:
# Open the ND2 file. `frames` supplies the axis sizes (frames.sizes) and the
# channel names (frames.metadata["channels"]) used by the cells below.
# NOTE(review): absolute, user-specific path -- adjust for your environment.
nd2_path = "/home/jqs1/scratch/fidelity/171018/20171018_TrxnError_ID.nd2"
frames = nd2reader.ND2Reader(nd2_path)

In [None]:
def _stack_timepoints(v, c):
    """Stack all timepoints for index ``v`` and channel ``c`` along a new leading axis."""
    timepoints = [
        frames.get_frame_2D(c=c, t=t, v=v) for t in range(frames.sizes["t"])
    ]
    return np.stack(timepoints)


def _stack_channels(v):
    """Stack the per-channel time series for index ``v`` along a new leading axis."""
    per_channel = [_stack_timepoints(v, c) for c in range(frames.sizes["c"])]
    return np.stack(per_channel)


# Load the whole ND2 file into one in-memory array whose leading axes are
# (v, c, t), followed by the axes of a single 2D frame. tnrange draws a
# progress bar over the outermost (v) loop.
frames_ary = np.stack([_stack_channels(v) for v in tnrange(frames.sizes["v"])])

In [None]:
# Display the dimensions of the stacked array; its leading axes are (v, c, t)
# by construction in the cell that builds frames_ary.
frames_ary.shape

In [None]:
# Emission wavelength per fluorophore name (presumably in nanometres -- TODO confirm).
# Not referenced by the other cells visible in this notebook.
emission_wavelengths = {
    "MCHERRY": 583,
    "GFP": 508,
    "CY5": 670,
    "BFP": 448,
    "CFP": 480,
}

In [None]:
# Display color (hex RGB string) for each channel name. Used both to tint the
# composite image and as the initial value of each channel's color picker.
channel_colors = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
)

In [None]:
# Channel names as recorded in the ND2 metadata, in acquisition order.
channels = frames.metadata["channels"]
n_channels = len(channels)
# Per-channel tint as the value returned by hex2color, in the same order as
# `channels`. Raises KeyError if a channel has no entry in channel_colors.
colors = list(map(hex2color, (channel_colors[name] for name in channels)))

In [None]:
%%output size=250

channel_boxes = []
channel_widgets = []
for channel in frames.metadata["channels"]:
    solo_button = widgets.Button(description="S", layout=widgets.Layout(width="10%"))
    enabled_button = widgets.ToggleButton(description=channel, value=True)
    solo_button._button_to_enable = enabled_button
    color_picker = widgets.ColorPicker(concise=True, value=channel_colors[channel])
    channel_box = widgets.HBox([solo_button, enabled_button, color_picker])
    channel_widgets.append([solo_button, enabled_button, color_picker, channel_box])
solo_buttons, enabled_buttons, color_pickers, channel_boxes = zip(*channel_widgets)
channels_box = widgets.VBox(channel_boxes)
t_slider = widgets.IntSlider(
    label="t",
    min=0,
    max=frames.sizes["t"] - 1,
    step=1,
    value=0,
    continuous_update=False,
)
v_slider = widgets.IntSlider(
    min=0, max=frames.sizes["v"] - 1, step=1, value=0, continuous_update=False
)
slider_box = widgets.VBox([v_slider, t_slider])
control_box = widgets.HBox([channels_box, slider_box])
output = widgets.Output()
main_box = widgets.VBox([control_box, output])
display(main_box)

max_val = 2**14

Frame = Stream.define("Frame", t=0, v=0)
frame = Frame()
DisplaySettings = Stream.define(
    "DisplaySettings", channel_enabled=np.array([True] * n_channels)
)
display_settings = DisplaySettings()


def composite_image(t, v, channel_enabled):
    """Render one frame as a tinted multi-channel composite.

    Each enabled channel is read from ``frames_z[v, c, t]``, divided by its
    99.9th percentile, clipped to [0, 1], multiplied by ``colors[c]`` and
    merged with the other enabled channels using a screen blend
    (``1 - (1 - a) * (1 - b)``). Disabled channels are skipped entirely, so
    they cost neither a Zarr read nor a percentile computation.

    Parameters
    ----------
    t, v : int
        Indices along the third and first axes of ``frames_z``.
    channel_enabled : sequence of bool
        One flag per channel, indexed like ``colors``.

    Returns
    -------
    hv.RGB
        The composite with bounds (-1, -1, 1, 1). A solid white image of the
        frame size is returned when no channel is enabled.
    """
    tinted_imgs = []
    for c in range(n_channels):
        if not channel_enabled[c]:
            continue
        channel_img = frames_z[v, c, t, :, :]
        scale = np.percentile(channel_img, 99.9)
        if scale == 0:
            # Avoid 0/0 -> NaN (and x/0 -> inf) on an (almost) empty frame:
            # with a unit scale, zero pixels stay 0 and anything >= 1 clips to 1.
            scale = 1.0
        scaled_img = channel_img[:, :, np.newaxis] / scale
        np.clip(scaled_img, 0, 1, out=scaled_img)  # clip in place
        tinted_imgs.append(scaled_img * np.array(colors[c]))
    if not tinted_imgs:
        # White placeholder with the frame's height/width and three color planes.
        tinted_imgs = [np.ones(tuple(frames_z.shape[-2:]) + (3,))]
    img = tinted_imgs[0]
    for other in tinted_imgs[1:]:
        img = 1 - (1 - img) * (1 - other)
    return hv.RGB(img, bounds=(-1, -1, 1, 1))


def _on_t_change(change):
    """Push the t slider's new value into the `frame` stream."""
    frame.event(t=change["new"])


def _on_v_change(change):
    """Push the v slider's new value into the `frame` stream."""
    frame.event(v=change["new"])


# Forward slider value changes to the stream that drives the image viewer.
t_slider.observe(_on_t_change, names="value")
v_slider.observe(_on_v_change, names="value")


def update_enabled_channels(change):
    """Publish the current on/off state of every channel toggle.

    Registered as an observer on each toggle's ``value``. The ``change``
    payload is ignored; the state of all toggles is re-read each time and
    sent through the ``display_settings`` stream.
    """
    toggle_states = [toggle.value for toggle in enabled_buttons]
    display_settings.event(channel_enabled=np.array(toggle_states))


def update_solo(solo_button):
    """Click handler for a channel's solo ("S") button.

    If the clicked channel is already the only enabled one, re-enable every
    channel. Otherwise enable only the clicked channel. The toggles are
    updated one at a time, so observers registered on their ``value`` are
    notified once per toggle whose value actually changes.
    """
    target = solo_button._button_to_enable
    n_enabled = sum(toggle.value for toggle in enabled_buttons)
    # Decided once, before any toggle is modified.
    restore_all = target.value and n_enabled == 1
    for toggle in enabled_buttons:
        toggle.value = True if restore_all else toggle == target


# Hook up each channel row: the solo button triggers update_solo, and the
# enable toggle republishes the channel mask whenever its value changes.
for solo_button, enabled_button in zip(solo_buttons, enabled_buttons):
    solo_button.on_click(update_solo)
    enabled_button.observe(update_enabled_channels, names="value")

# NOTE: the color pickers are not observed, so changing a picker does not
# affect the rendered image (composite_image tints with the fixed `colors`).

# The image is recomputed whenever either stream emits an event. regrid() must
# stay the last expression so the viewer is displayed as the cell's output.
image_viewer = hv.DynamicMap(composite_image, streams=[frame, display_settings])
regrid(image_viewer)