# Imports

In [None]:
import functools
import itertools
import operator
import warnings
from collections import Counter, defaultdict
from collections.abc import Mapping, Sequence
from copy import copy, deepcopy
from functools import partial, reduce
from itertools import zip_longest

import datashader as ds
import holoviews as hv
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import peakutils
import scipy.interpolate
import scipy.stats
import skimage
import skimage.morphology

# from bokeh.layouts import row
# from bokeh.plotting import figure
import tqdm
import zarr
from bokeh.io import output_notebook, push_notebook, show
from bokeh.models import WheelZoomTool
from holoviews.operation import decimate
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    regrid,
    shade,
)
from holoviews.streams import Stream, param
from IPython.display import clear_output, display
from ipywidgets import fixed, interact, interact_manual, interactive
from matplotlib.colors import hex2color
from more_itertools import rstrip
from tqdm import tnrange, tqdm, tqdm_notebook

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from processing import *
from trench_detection import *
from trench_detection import _label_for_trenches
from trench_segmentation import *
from util import *

In [None]:
%load_ext line_profiler
# Load the HoloViews bokeh plotting backend for this notebook.
hv.notebook_extension("bokeh")
# renderer = hv.renderer('bokeh')
%matplotlib inline
# `tqdm` is the class here (rebound by `from tqdm import tqdm` above); setting
# its class-level monitor_interval to 0 disables tqdm's monitor thread (per
# tqdm docs).
tqdm.monitor_interval = 0

# Config

In [None]:
# warnings.filterwarnings('ignore', 'Conversion of the second argument of issubdtype')
# warnings.filterwarnings('ignore', category=FutureWarning)
# Show FutureWarnings (first occurrence per location, the "default" action)
# instead of silencing them as the commented-out filters above did.
warnings.filterwarnings("default", category=FutureWarning)

In [None]:
# store = zarr.DirectoryStore('/home/jqs1/scratch/fidelity/171018/171018.zarr')
# Experiment data store; later cells read its "raw" and "quantized" groups.
store = zarr.DirectoryStore(
    "/home/jqs1/scratch/fidelity/171214/transcriptionerror_timelapse.zarr"
)
# NOTE(review): open_group's default mode is read/write-create, so a wrong
# path silently creates an empty group rather than failing.
root_group = zarr.open_group(store=store)

In [None]:
# Default display color (hex RGB string) for each channel name; seeds the
# color pickers of the multichannel viewer. Key order is preserved.
channel_to_color = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
    YFP="#f5eb00",
)

# Compute

In [None]:
def display_plot_browser(plots, stream=None):
    """Show a widget tree for ``plots`` and populate each of its leaves.

    ``plots`` is a (nested) mapping of displayable objects, or -- when
    ``stream`` is given -- a callable that receives the stream's current
    parameters as keyword arguments and returns such a mapping.

    Returns the top-level widget produced by ``plot_browser``.
    """
    outputs = {}
    snapshot = plots if stream is None else plots(**stream.contents)
    browser = plot_browser(snapshot, outputs)
    display(browser)
    display_plot_browser_contents(plots, outputs, stream=stream)
    return browser


def display_plot_browser_contents(plots, to_display, stream=None):
    """Fill the Output widgets recorded by ``plot_browser`` with content.

    Parameters
    ----------
    plots
        Same object given to ``display_plot_browser``: a (nested) mapping
        when ``stream`` is None, otherwise a callable taking the stream's
        parameters as keyword arguments and returning such a mapping.
    to_display : dict
        ``{output_widget: (obj, path)}`` as filled in by ``plot_browser``;
        ``path`` is the key path used with ``recursive_getattr`` to re-fetch
        the leaf from a freshly evaluated ``plots(**kwargs)``.
    stream : optional
        HoloViews stream. When given, HoloViews elements are wrapped in a
        DynamicMap driven by the stream, and any other object is re-rendered
        by a subscriber on each stream event.
    """
    for output, (obj, path) in to_display.items():
        if stream is None:
            # Static mode: show the object once.
            with output:
                display(obj)
        else:
            # print(path, obj.__class__)
            if isinstance(
                obj, hv.core.dimension.ViewableElement
            ):  # TODO: is this the right comparison?
                # preprocess: make into DynamicMaps, regrid Images, etc.
                # NOTE(review): bare Image/RGB/Raster leaves are skipped here, so
                # their Output widget stays empty -- confirm this is intended.
                if isinstance(obj, (hv.Image, hv.RGB, hv.Raster)):
                    continue

                # Re-evaluate `plots` with the stream values, pick this leaf by
                # its key path, and apply `regrid` to any raster elements in it.
                def callback(p, **kwargs):
                    plot = recursive_getattr(plots(**kwargs), p)
                    return plot.map(
                        regrid,
                        lambda obj: isinstance(obj, (hv.Image, hv.RGB, hv.Raster)),
                    ).collate()

                dmap = hv.DynamicMap(
                    partial(callback, path), streams=[stream]
                )  # .collate()
                # dmap = obj.map(regrid, lambda obj: isinstance(obj, (hv.Image, hv.RGB, hv.Raster))).collate()
                # dmap = regrid(dmap)
                with output:
                    display(dmap)
            else:
                # normal python type
                # Render once now, then clear and redisplay on every stream event.
                def callback(o, p, **kwargs):
                    with o:
                        clear_output()
                        display(recursive_getattr(plots(**kwargs), p))
                        # display((p, kwargs, np.random.random(), recursive_getattr(plots(**kwargs), p)))
                        # display((p, np.random.random()))

                callback(output, path, **stream.contents)
                # we need to do the partial trick, or else output is only bound to the last output of the for loop
                stream.add_subscriber(partial(callback, output, path))


def plot_browser(plots, to_display=None, path=()):
    """Build a nested ipywidgets browser for a (nested) mapping of plots.

    Non-mapping values become leaves: a bold label beside an empty
    ``widgets.Output``. Mapping values become collapsible ``Accordion``
    sections, built recursively. Widgets are returned empty; every leaf's
    Output widget is recorded in ``to_display`` so the caller can fill it.

    Parameters
    ----------
    plots : Mapping
        Keys are leaf/section titles; values are nested mappings or the
        objects to display.
    to_display : dict, optional
        Filled in place with ``{output_widget: (obj, path)}`` for each leaf,
        where ``path`` is the tuple of keys leading to ``obj``. A fresh dict
        is used when omitted.
    path : tuple
        Key path of ``plots`` within the top-level mapping (recursion state).

    Returns
    -------
    widgets.VBox

    Raises
    ------
    NotImplementedError
        If ``plots`` is not a Mapping.
    """
    if not isinstance(plots, Mapping):
        raise NotImplementedError
    if to_display is None:
        # The default used to be indexed directly, raising TypeError on None.
        to_display = {}
    # Partition keys into leaves and nested sections, preserving key order.
    singleton_children = [k for k, v in plots.items() if not isinstance(v, Mapping)]
    nested_children = [k for k, v in plots.items() if isinstance(v, Mapping)]
    children = []
    for k in singleton_children:
        label = widgets.HTML("<b>{}</b>".format(k))
        output = widgets.Output()
        to_display[output] = (plots[k], path + (k,))
        children.append(widgets.HBox([label, output]))
    if nested_children:
        accordion_children = [
            plot_browser(plots[k], to_display=to_display, path=path + (k,))
            for k in nested_children
        ]
        accordion = widgets.Accordion(children=accordion_children)
        for i, k in enumerate(nested_children):
            # Titles must be strings; keys may be e.g. integer positions.
            accordion.set_title(i, str(k))
        children.append(accordion)
    return widgets.VBox(children)


# Browse the per-position diagnostics in diag_pos, following the v slider.
# NOTE(review): FrameStream and diag_pos are defined in later cells, so this
# cell only runs after those. When v is not in diag_pos the lambda returns
# None, which plot_browser rejects with NotImplementedError.
frame_stream = FrameStream()
display(frame_browser(root_group["quantized"], frame_stream))
display_plot_browser(lambda t, v: diag_pos[v] if v in diag_pos else None, frame_stream)
# s2 = FrameStream()
# display(frame_browser(root_group['quantized'], s2))
# display_plot_browser(lambda t, v: {'a': {'b': {'c': 2}}, 'x': 1} if v in diag_pos else None, s2);

In [None]:
# Composite multichannel viewer for the quantized images; keeps the returned
# FrameStream (current t/v selection) for use by other cells.
frame_stream = big_image_viewer(root_group["quantized"])

In [None]:
%%time
# Run get_trenches on one frame of each of the first 20 positions, passing
# diag_pos[pos] as that position's diagnostics collector.
# The frame is root_group["raw"][str(pos)][0, 30] -- presumably channel 0,
# timepoint 30, matching the [c, t, ...] indexing used in big_image_viewer.
# NOTE(review): fail_silently presumably calls the lambda immediately and
# swallows its exceptions -- confirm; if it deferred the call, every lambda
# would see the final value of `pos`.
diag_pos = tree()
trench_points_pos = {
    pos: fail_silently(
        lambda: get_trenches(
            root_group["raw"][str(pos)][0, 30], diagnostics=diag_pos[pos]
        )
    )
    for pos in range(20)
}

In [None]:
# a = diag_pos[0]['label_1']['trench_ends']['image_with_trenches']
# Scratch exploration of one diagnostic plot (position 0, "hough_1" step of
# the trench-rotation diagnostics).
a = diag_pos[0]["label_1"]["trench_rotation"]["hough_1"]["abs_diff_h"]

In [None]:
a.key

In [None]:
a.keys()

In [None]:
?a.map

In [None]:
isinstance(a.VLine, hv.Element)

In [None]:
a.keys()

In [None]:
a.traverse(lambda x: x)

In [None]:
def cb(pos):
    """Return the ``abs_diff_h`` diagnostic plot recorded for position ``pos``."""
    return diag_pos[pos]["label_1"]["trench_rotation"]["hough_1"]["abs_diff_h"]


# Step through positions 0-9 with a `pos` slider.
d = hv.DynamicMap(cb, kdims=["pos"]).redim.values(pos=range(10))

In [None]:
# Display the pos-indexed DynamicMap as-is.
d

In [None]:
# Collated version (expected to lift nested overlays out of the DynamicMap).
d.collate()

In [None]:
# Datashade only the Curve elements of the single-position plot.
a.collate().map(datashade, hv.Curve)

In [None]:
def cb1(pos):
    """Return just the Curve layer of position ``pos``'s ``abs_diff_h`` plot."""
    diag_plot = diag_pos[pos]["label_1"]["trench_rotation"]["hough_1"]["abs_diff_h"]
    return diag_plot.Curve


d1 = hv.DynamicMap(cb1, kdims=["pos"]).redim.values(pos=range(10))


def cb2(pos):
    """Return just the VLine layer of position ``pos``'s ``abs_diff_h`` plot."""
    diag_plot = diag_pos[pos]["label_1"]["trench_rotation"]["hough_1"]["abs_diff_h"]
    return diag_plot.VLine


d2 = hv.DynamicMap(cb2, kdims=["pos"]).redim.values(pos=range(10))
# Datashade the curve layer, then overlay the unshaded vertical line on top.
e = datashade(d1) * d2
e

In [None]:
from holoviews.plotting.util import split_dmap_overlay

In [None]:
# `hvpu` was never bound; inspect the function imported by name instead.
split_dmap_overlay

In [None]:
# Earlier datashading attempts on the DynamicMap, kept for reference:
# d.map(datashade, hv.Curve)
# datashade(d)
d

In [None]:
def _abs_diff_h_plot(pos):
    """Return the ``abs_diff_h`` diagnostic plot recorded for position ``pos``."""
    return diag_pos[pos]["label_1"]["trench_rotation"]["hough_1"]["abs_diff_h"]


# Separate DynamicMaps for the Curve and VLine layers over positions 0-9.
d1 = hv.DynamicMap(lambda pos: _abs_diff_h_plot(pos).Curve, kdims=["pos"])
d1 = d1.redim.values(pos=range(10))
d2 = hv.DynamicMap(lambda pos: _abs_diff_h_plot(pos).VLine, kdims=["pos"])
d2 = d2.redim.values(pos=range(10))

In [None]:
z = d1 + d2
# NOTE(review): the collated result is discarded, so the line below displays
# the uncollated layout -- presumably `z = z.collate()` was intended.
z.collate()
z

In [None]:
d["Curve.I"]

In [None]:
b = a.collate()

In [None]:
a.traverse(lambda x: x.path)

In [None]:
a

In [None]:
%%time
# Focus score of every frame in frame_series (via image_sharpness); plotted
# later as sharpness_plot. NOTE(review): frame_series is not defined in any
# visible cell -- presumably created elsewhere in the notebook.
sharpness = np.array(
    [image_sharpness(frame_series[i]) for i in range(frame_series.shape[0])]
)

# Image viewers

In [None]:
def select(xs, mask):
    """Return the items of ``xs`` whose matching ``mask`` entry is truthy.

    Parameters
    ----------
    xs : sequence
        Items to filter.
    mask : sequence
        Truthy/falsy flags, one per item of ``xs``.

    Returns
    -------
    list
        Selected items in their original order.

    Raises
    ------
    ValueError
        If ``xs`` and ``mask`` differ in length.
    """
    if len(xs) != len(mask):
        raise ValueError("mask length does not match")
    return [item for item, keep in zip(xs, mask) if keep]


def composite_channels(imgs, hexcolors, scale=True):
    """Blend channel images, each tinted by a hex color string such as "#e22400".

    The hex strings are converted to RGB triples with matplotlib's
    ``hex2color``; blending is delegated to ``_composite_channels``.
    """
    rgb_colors = list(map(hex2color, hexcolors))
    return _composite_channels(imgs, rgb_colors, scale=scale)


def _composite_channels(channel_imgs, colors, scale=True):
    if len(channel_imgs) != len(colors):
        raise ValueError("expecting equal numbers of channels and colors")
    num_channels = len(channel_imgs)
    if scale:
        scaled_imgs = [
            channel_imgs[i][:, :, np.newaxis] / np.percentile(channel_imgs[i], 99.9)
            for i in range(num_channels)
        ]
        for scaled_img in scaled_imgs:
            np.clip(scaled_img, 0, 1, scaled_img)  # clip in place
    else:
        scaled_imgs = channel_imgs
    imgs_to_combine = [
        scaled_imgs[i] * np.array(colors[i]) for i in range(num_channels)
    ]
    if not len(imgs_to_combine):
        imgs_to_combine = [np.ones(colored_imgs[0].shape)]  # white placeholder
    img = imgs_to_combine[0]
    for img2 in imgs_to_combine[1:]:
        img = 1 - (1 - img) * (1 - img2)
    return img


def multichannel_selector(frames):
    """Build per-channel display controls plus a stream that reports them.

    For each channel in ``frames.attrs["metadata"]["channels"]`` a row is
    created with a solo button ("S"), an enable toggle labelled with the
    channel name, and a color picker seeded from ``channel_to_color``.
    Clicking solo enables only that channel; clicking it when that channel
    is already the only one enabled re-enables every channel.

    Returns
    -------
    (widgets.VBox, Stream)
        The box of channel rows, and a ``DisplaySettings`` stream whose
        ``channel_enabled`` (list of bool) and ``channel_colors`` (list of
        hex strings) parameters are updated from the widgets.

    Raises
    ------
    ValueError
        If the metadata lists no channels.
    KeyError
        If a channel has no entry in ``channel_to_color``.
    """
    channels = frames.attrs["metadata"]["channels"]
    if not channels:
        # Previously surfaced as an opaque tuple-unpacking error.
        raise ValueError("frames metadata lists no channels")
    channel_enabled = [True] * len(channels)
    channel_colors = [channel_to_color[channel] for channel in channels]
    solo_buttons, enabled_buttons, color_pickers, channel_boxes = [], [], [], []
    for channel, enabled, color in zip(channels, channel_enabled, channel_colors):
        solo_button = widgets.Button(
            description="S", layout=widgets.Layout(width="10%")
        )
        enabled_button = widgets.ToggleButton(description=channel, value=enabled)
        # Lets the solo click handler find the toggle it controls.
        solo_button._button_to_enable = enabled_button
        color_picker = widgets.ColorPicker(concise=True, value=color)
        solo_buttons.append(solo_button)
        enabled_buttons.append(enabled_button)
        color_pickers.append(color_picker)
        channel_boxes.append(widgets.HBox([solo_button, enabled_button, color_picker]))
    channels_box = widgets.VBox(channel_boxes)
    DisplaySettings = Stream.define(
        "DisplaySettings",
        channel_enabled=channel_enabled,
        channel_colors=channel_colors,
    )
    display_settings_stream = DisplaySettings()

    def update_enabled_channels(change):
        # Publish the complete on/off list whenever any toggle changes value.
        channel_enabled = [button.value for button in enabled_buttons]
        display_settings_stream.event(channel_enabled=channel_enabled)

    def update_solo(solo_button):
        # Toggle value changes below trigger update_enabled_channels via the
        # observers registered further down.
        if (
            solo_button._button_to_enable.value
            and sum([b.value for b in enabled_buttons]) == 1
        ):
            for enabled_button in enabled_buttons:
                enabled_button.value = True
        else:
            for enabled_button in enabled_buttons:
                enabled_button.value = enabled_button is solo_button._button_to_enable

    for solo_button in solo_buttons:
        solo_button.on_click(update_solo)
    for enabled_button in enabled_buttons:
        enabled_button.observe(update_enabled_channels, names="value")

    def update_channel_colors(change):
        channel_colors = [color_picker.value for color_picker in color_pickers]
        display_settings_stream.event(channel_colors=channel_colors)

    for color_picker in color_pickers:
        color_picker.observe(update_channel_colors, names="value")
    return channels_box, display_settings_stream

In [None]:
# Stream carrying the selected timepoint `t` and field of view `v` (both
# default 0); driven by the slider widgets defined below.
FrameStream = Stream.define("Frame", t=0, v=0)


def timepoints_browser(frames, frame_stream):
    """Return a ``< [t slider] >`` row that drives ``frame_stream``'s ``t``.

    The slider covers the valid indices ``0 .. n - 1``, where ``n`` is
    ``len(frames.attrs["metadata"]["frames"])``. The ``<`` / ``>`` buttons
    step the slider by one, and every slider change is sent to
    ``frame_stream`` as a ``t`` event.
    """
    num_timepoints = len(frames.attrs["metadata"]["frames"])
    # IntSlider's max is inclusive: the last valid index is n - 1 (kept >= 0
    # so an empty frame list still gives a valid slider).
    max_t = max(num_timepoints - 1, 0)
    # play_buttons = widgets.Play(interval=10, min=0, max=num_timepoints, step=1)
    back_step_button = widgets.Button(
        description="<", layout=widgets.Layout(width="10%")
    )
    forward_step_button = widgets.Button(
        description=">", layout=widgets.Layout(width="10%")
    )
    # `description` is the IntSlider label trait (`label=` is not a trait).
    t_slider = widgets.IntSlider(
        description="t", min=0, max=max_t, step=1, value=0, continuous_update=False
    )

    def step(delta):
        # Move the slider by `delta`, staying within its range.
        t_slider.value = min(max(t_slider.value + delta, t_slider.min), t_slider.max)

    back_step_button.on_click(lambda _button: step(-1))
    forward_step_button.on_click(lambda _button: step(1))
    t_slider.observe(lambda change: frame_stream.event(t=change["new"]), names="value")
    return widgets.HBox([back_step_button, t_slider, forward_step_button])


def frame_browser(frames, frame_stream):
    """Return stacked ``v`` (field of view) and ``t`` (timepoint) sliders.

    Each slider covers the valid indices ``0 .. n - 1`` for the lengths of
    ``frames.attrs["metadata"]["fields_of_view"]`` and ``["frames"]``
    respectively. Every slider change is sent to ``frame_stream`` as a
    ``v`` or ``t`` event.
    """
    metadata = frames.attrs["metadata"]
    num_timepoints = len(metadata["frames"])
    num_fovs = len(metadata["fields_of_view"])
    # IntSlider's max is inclusive, so the last valid index is n - 1; the
    # previous max=n let the viewers index one past the end. Clamp at 0 so
    # empty metadata still yields a valid slider.
    t_slider = widgets.IntSlider(
        description="t",
        min=0,
        max=max(num_timepoints - 1, 0),
        step=1,
        value=0,
        continuous_update=False,
    )
    v_slider = widgets.IntSlider(
        description="v",
        min=0,
        max=max(num_fovs - 1, 0),
        step=1,
        value=0,
        continuous_update=False,
    )
    t_slider.observe(lambda change: frame_stream.event(t=change["new"]), names="value")
    v_slider.observe(lambda change: frame_stream.event(v=change["new"]), names="value")
    return widgets.VBox([v_slider, t_slider])


def big_image_viewer(positions):
    """Display a composite multichannel viewer with channel and t/v controls.

    ``positions`` is indexed by position name (``positions[str(v)]``), and
    each entry is indexed ``[c, t, :, :]``. The enabled channels of the
    selected frame are composited, wrapped in ``RevRGB``, regridded and shown
    at 500x500 below the controls.

    Returns the ``FrameStream`` that tracks the selected ``t`` and ``v``.
    """
    # TODO: the channel count comes from positions[0]'s shape, not from the
    # metadata channel list used by multichannel_selector.
    num_channels = positions[0].shape[0]
    frame_stream = FrameStream()
    slider_box = frame_browser(positions, frame_stream)
    channels_box, display_settings_stream = multichannel_selector(positions)

    def image_callback(t, v, channel_enabled, channel_colors):
        stack = positions[str(v)]
        imgs = [stack[c, t, :, :] for c in range(num_channels) if channel_enabled[c]]
        colors = select(channel_colors, channel_enabled)
        return RevRGB(composite_channels(imgs, colors))

    dmap = hv.DynamicMap(
        image_callback, streams=[frame_stream, display_settings_stream]
    )
    image = regrid(dmap).opts(plot={"width": 500, "height": 500})
    output = widgets.Output()
    controls = widgets.HBox([channels_box, slider_box])
    display(widgets.VBox([controls, output]))
    with output:
        display(image)
    return frame_stream


# big_image_viewer(root_group['quantized'])

In [None]:
def overlay_image_viewer(frames):
    """Display a composite multichannel viewer for a single stacked array.

    ``frames`` is indexed ``[v, c, t, :, :]`` and carries the metadata read
    by ``frame_browser`` and ``multichannel_selector``. The enabled channels
    are composited, flipped vertically and shown as a regridded ``hv.RGB``
    with pixel-unit bounds. Returns None.
    """
    frame_stream = FrameStream()
    slider_box = frame_browser(frames, frame_stream)
    channels_box, display_settings_stream = multichannel_selector(frames)

    def image_callback(t, v, channel_enabled, channel_colors):
        imgs = [
            frames[v, c, t, :, :]
            for c in range(frames.shape[1])
            if channel_enabled[c]
        ]
        rgb = composite_channels(imgs, select(channel_colors, channel_enabled))
        height, width = rgb.shape[0], rgb.shape[1]
        return hv.RGB(rgb[::-1], bounds=(0, 0, width, height))

    dmap = hv.DynamicMap(
        image_callback, streams=[frame_stream, display_settings_stream]
    )
    image = regrid(dmap).opts(plot={"width": 500, "height": 500})
    output = widgets.Output()
    display(widgets.VBox([widgets.HBox([channels_box, slider_box]), output]))
    with output:
        display(image)


# NOTE(review): frames_z is not defined in any visible cell -- presumably
# created elsewhere; it must support [v, c, t, :, :] indexing and carry
# attrs["metadata"] with "channels", "frames" and "fields_of_view".
overlay_image_viewer(frames_z)

In [None]:
%%opts Image (cmap='viridis')
def scrubber_callback(t, y):
    """Return a crosshair for the kymograph: time line, time label, row line.

    Draws a faint white vertical line at ``int(t) + 0.5``, a "t=..." label
    offset by 1/20 of the global ``kymo``'s size, and a horizontal line at
    ``y``.
    """
    t = int(t)
    white_line = dict(color="white", alpha=0.3, line_width=2)
    v_line = hv.VLine(t + 0.5)(style=dict(white_line))
    label_x = t + kymo.shape[1] / 20
    label_y = kymo.shape[0] / 20
    t_label = hv.Text(label_x, label_y, "t={:}".format(t))(
        style={"color": "white", "text_alpha": 0.8}
    )
    h_line = hv.HLine(y)(style=dict(white_line))
    return v_line * t_label * h_line


def cross_section_callback(t, y):
    """Return a faint white horizontal line at row ``int(y)``; ``t`` is unused."""
    row = int(y)
    return hv.HLine(row)(style={"color": "white", "alpha": 0.3, "line_width": 2})


def sharpness_scrubber_callback(t):
    """Return a faint black vertical marker at ``int(t) + 0.5``."""
    position = int(t) + 0.5
    return hv.VLine(position)(style={"color": "black", "alpha": 0.3, "line_width": 2})


def trench_thumbnail_callback(img_series, t):
    """Return a thumbnail Image of trench ``trench_idx`` at timepoint ``t``.

    Reads the module-level ``trench_points`` and ``trench_idx``. The
    thumbnail array is flipped vertically, given pixel-unit bounds and key
    dimensions ``x2``/``y``, and displayed 200 px wide with an inverted
    y axis.
    """
    frame = img_series[int(t)]
    thumb_array = get_trench_thumbnail(frame, trench_points, trench_idx)
    rows, cols = thumb_array.shape[0], thumb_array.shape[1]
    thumb = hv.Image(thumb_array[::-1], bounds=(0, 0, cols, rows), kdims=["x2", "y"])
    # TODO: unclear why hv.Image.opts(plot={'invert_axes': True}) had no effect.
    return thumb.opts(plot={"invert_axes": False, "invert_yaxis": True, "width": 200})


# Pointer streams; x is renamed to `t` so it maps onto the callbacks' `t`.
pointerx = hv.streams.PointerX(x=0).rename(x="t")
pointery = hv.streams.PointerY(y=0)

# HoloViews bounds order is (left, bottom, right, top); the images below use
# (0, 0, width, height) in pixel units.
trench_idx = 15
img_series = frame_series_k1
bbox_ul, bbox_lr = get_trench_bbox(
    trench_points, trench_idx, *get_img_limits(img_series[0])
)
kymo = extract_kymograph(
    img_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
kymo_img = hv.Image(kymo[::-1], bounds=(0, 0, kymo.shape[1], kymo.shape[0])).opts(
    plot={"width": 700, "height": 300, "invert_yaxis": True}
)
# kymo_img = hv.Raster(kymo).opts(plot={'width': 700, 'height': 300, 'yaxis': 'left'})
scrubber_line = hv.DynamicMap(scrubber_callback, streams=[pointerx, pointery])
cross_section_line = hv.DynamicMap(cross_section_callback, streams=[pointerx, pointery])
trench_thumbnail_img = hv.DynamicMap(
    partial(trench_thumbnail_callback, img_series), streams=[pointerx]
)  # .opts(plot={'invert_axes': True})
sharpness_plot = hv.Curve(sharpness, kdims=["x"], vdims=["s"]).opts(
    plot={"width": 700, "height": 100}
)
sharpness_scrubber = hv.DynamicMap(sharpness_scrubber_callback, streams=[pointerx])
# pointery.source = kymo_img
# SEE: http://holoviews.org/reference/elements/bokeh/Distribution.html
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)

# Stream + toggle button controlling whether the trench line is drawn on the thumbnail.
KymoOverlayStream = Stream.define("KymoOverlayStream", overlay_enabled=True)
kymo_overlay_stream = KymoOverlayStream()
kymo_overlay_button = widgets.ToggleButton(description="Overlay", value=True)


def update_kymo_overlay(change):
    """Forward the Overlay toggle's new value to ``kymo_overlay_stream``."""
    enabled = change["new"]
    kymo_overlay_stream.event(overlay_enabled=enabled)


kymo_overlay_button.observe(update_kymo_overlay, names="value")


def get_thumbnail_overlay(t, overlay_enabled):
    """Return an Overlay holding the trench line, or an empty Overlay.

    ``t`` is unused (it comes from the pointer stream). The line joins the
    trench's two points from ``trench_points``, each offset by ``bbox_ul``.
    """
    # because of https://github.com/ioam/holoviews/issues/1388, overlay_enabled must be True initially
    start = trench_points[0][trench_idx] - bbox_ul
    end = trench_points[1][trench_idx] - bbox_ul
    trench_line = hv.Curve([start, end]).opts(
        style={"color": "white", "alpha": 0.5, "line_width": 1.5},
        plot={"yaxis": None, "shared_axes": True},
    )
    return hv.Overlay([trench_line] if overlay_enabled else [])


thumbnail_overlay = hv.DynamicMap(
    get_thumbnail_overlay, streams=[pointerx, kymo_overlay_stream]
)  # .opts(plot={'invert_yaxis': True})

display(kymo_overlay_button)
# Adjoint layout: the kymograph with its scrubber is the main plot; the first
# `<<` adjoins the trench thumbnail (right side), the second adjoins the
# sharpness trace (top).
(
    kymo_img * scrubber_line
    << (trench_thumbnail_img * thumbnail_overlay * cross_section_line)
    << (sharpness_plot * sharpness_scrubber)
)
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)
# trench_thumbnail_img * thumbnail_overlay
# trench_thumbnail_img * thumbnail_overlay * cross_section_line

In [None]:
def trench_peaks_callback(t, y):
    """Plot column ``t`` of ``kymo`` with a point at each detected endpoint.

    ``y`` is unused. Reads the module-level ``kymo`` and ``kymo_endpoints``.
    """
    col = int(t)
    profile = hv.Curve(kymo[:, col]).opts(plot={"width": 500})
    markers = [
        hv.Points([(x, kymo[int(x), col])]).opts(style={"size": 6})
        for x in kymo_endpoints[col]
    ]
    return hv.Overlay([profile, *markers])


# NOTE(review): kymo_endpoints is computed in the following cell; run it first.
hv.DynamicMap(trench_peaks_callback, streams=[pointerx, pointery])

In [None]:
%%time


def find_kymograph_cell_endpoints(kymograph, thresh=0.2, min_dist=3):
    """Return, for each column of ``kymograph``, the indices of its peaks.

    Peaks come from ``peakutils.indexes`` with ``thres=thresh`` (a threshold
    normalized to the column's value range, per peakutils docs) and
    ``min_dist``. The result is a list with one index array per column.
    Sub-pixel refinement via ``peakutils.interpolate`` was tried and left out.
    """
    return [
        peakutils.indexes(kymograph[:, col], thres=thresh, min_dist=min_dist)
        for col in range(kymograph.shape[1])
    ]


kymo_endpoints = find_kymograph_cell_endpoints(kymo, thresh=0.2, min_dist=5)

In [None]:
# Overlay rows 220 onward of the first fifth of the kymograph's columns.
for t in range(kymo.shape[1] // 5):
    plt.plot(kymo[220:, t])

# Old

In [None]:
# %%opts Image (cmap='viridis')
# Older kymograph + thumbnail layout. NOTE(review): this reassigns the
# img_series, kymo, kymo_img and trench_thumbnail_img globals used by the
# callbacks above, and `trench_thumbnail` is not defined in any visible cell
# (the current callback is trench_thumbnail_callback) -- presumably this cell
# predates that rename; confirm before running.
img_series = frame_series
kymo = extract_kymograph(
    img_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
kymo_img = hv.Image(kymo, bounds=(0, 0, kymo.shape[1], kymo.shape[0])).opts(
    plot={"width": 700, "height": 300}
)
trench_thumbnail_img = hv.DynamicMap(
    partial(trench_thumbnail, img_series), streams=[pointerx]
)
(
    kymo_img.opts(plot={"invert_axes": True}) * scrubber_line
    + trench_thumbnail_img * cross_section_line
)

In [None]:
# handle out-of-bounds/negative t/x values
# focus quality score
# show focus quality as a bar above kymograph
# accurate horizontal line position on thumbnail
# draw cross-section line through thumbnail
# synchronize zoom/pan of multiple kymograph viewers
# 3-up thumbnail viewer with crosshairs/endpoint correspondences (links) synchronized with kymograph viewer
# clicking through a track advances 3-up by one frame
# if you press END it finishes track and rewinds to the first unfinished track
# automatic tracking: always identify bottom-most endpoint?
# can we work out all other correspondences from this? (look at t=14)