In [None]:
import numpy as np
import pandas as pd
import holoviews as hv
import datashader as ds
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    shade,
    regrid,
)
from holoviews.operation import decimate
from holoviews.streams import Stream, param
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display, clear_output
import nd2reader
from numcodecs import Blosc, Delta
import zarr
import skimage
import skimage.morphology
import sklearn
from sklearn.cluster import DBSCAN

# from sklearn import metrics
# from sklearn.datasets.samples_generator import make_blobs
from sklearn.preprocessing import StandardScaler
import peakutils
import scipy.stats
import scipy.interpolate
from holoborodko_diff import holo_diff
import functools
from functools import partial
from itertools import zip_longest
from collections import Counter, defaultdict
from bokeh.models import WheelZoomTool
from bokeh.io import push_notebook, show, output_notebook

# from bokeh.layouts import row
# from bokeh.plotting import figure
from tqdm import tnrange, tqdm_notebook

In [None]:
%load_ext line_profiler
# Notebook configuration: line_profiler (provides %lprun), Bokeh backend for
# HoloViews, and inline matplotlib figures.
hv.notebook_extension("bokeh")
# NOTE(review): `renderer` is not referenced anywhere else in this notebook.
renderer = hv.renderer("bokeh")
%matplotlib inline

In [None]:
# Open the converted movie read-only. Indexed below as frames_z[v, c, t, :, :]
# (field of view, channel, timepoint, image axes).
# NOTE(review): hard-coded absolute path; consider a DATA_DIR config constant.
frames_z = zarr.open_array("/home/jqs1/scratch/fidelity/test/171018.zarr", mode="r")

In [None]:
# frames = nd2reader.ND2Reader('/home/jqs1/scratch/fidelity/171018/20171018_TrxnError_ID.nd2')

In [None]:
# Display colour (hex RGB) per microscope channel name; the names must match
# the "channels" entry in the zarr metadata.
channel_colors = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
    YFP="#f5eb00",
)

In [None]:
%%output size=250

channels = frames_z.attrs["metadata"]["channels"]
n_channels = len(channels)
colors = [hex2color(channel_colors[channel]) for channel in channels]
num_timepoints = len(frames_z.attrs["metadata"]["frames"])
num_fovs = len(frames_z.attrs["metadata"]["fields_of_view"])

channel_boxes = []
channel_widgets = []
for channel in channels:
    solo_button = widgets.Button(description="S", layout=widgets.Layout(width="10%"))
    enabled_button = widgets.ToggleButton(description=channel, value=True)
    solo_button._button_to_enable = enabled_button
    color_picker = widgets.ColorPicker(concise=True, value=channel_colors[channel])
    channel_box = widgets.HBox([solo_button, enabled_button, color_picker])
    channel_widgets.append([solo_button, enabled_button, color_picker, channel_box])
solo_buttons, enabled_buttons, color_pickers, channel_boxes = zip(*channel_widgets)
channels_box = widgets.VBox(channel_boxes)
t_slider = widgets.IntSlider(
    label="t", min=0, max=num_timepoints, step=1, value=0, continuous_update=False
)
v_slider = widgets.IntSlider(
    min=0, max=num_fovs, step=1, value=0, continuous_update=False
)
slider_box = widgets.VBox([v_slider, t_slider])
control_box = widgets.HBox([channels_box, slider_box])
output = widgets.Output()
main_box = widgets.VBox([control_box, output])
display(main_box)

# NOTE(review): max_val is not referenced by the code in this cell; presumably
# the intended 14-bit display range — confirm or remove.
max_val = 2**14

# Stream carrying the currently selected timepoint (t) and field of view (v).
Frame = Stream.define("Frame", t=0, v=0)
frame = Frame()
# Stream carrying the per-channel on/off flags (all channels on initially).
DisplaySettings = Stream.define(
    "DisplaySettings", channel_enabled=np.array([True] * n_channels)
)
display_settings = DisplaySettings()


def composite_image(t, v, channel_enabled):
    """Render frame (t, v) as an RGB composite of the enabled channels.

    Each enabled channel c is read from ``frames_z[v, c, t]`` and divided by its
    99.9th percentile. It is then clipped to [0, 1] and tinted with ``colors[c]``.
    The tinted channels are combined by screen blending, ``1 - prod(1 - img_i)``.
    If no channel is enabled, a white placeholder of the frame size is shown.

    Parameters
    ----------
    t, v : int
        Timepoint and field-of-view index (from the ``frame`` stream).
    channel_enabled : sequence of bool
        Per-channel on/off flags (from ``display_settings``).

    Returns
    -------
    hv.RGB
        The composite with bounds (-1, -1, 1, 1).
    """
    height, width = frames_z.shape[-2:]
    tinted = []
    for c in range(n_channels):
        if not channel_enabled[c]:
            continue  # hidden channel: skip the (slow) zarr read entirely
        channel_img = frames_z[v, c, t, :, :]
        scale = np.percentile(channel_img, 99.9)
        if not scale > 0:
            # Blank channel: dividing by 0 would fill the image with NaN/inf.
            scale = 1.0
        scaled = np.clip(channel_img[:, :, np.newaxis] / scale, 0, 1)
        tinted.append(scaled * np.array(colors[c]))
    if not tinted:
        tinted = [np.ones((height, width, 3))]  # white placeholder
    img = tinted[0]
    for img2 in tinted[1:]:
        img = 1 - (1 - img) * (1 - img2)  # screen blend
    return hv.RGB(img, bounds=(-1, -1, 1, 1))  # .opts(plot={'size': 250}, tools=[''])


# Forward slider changes into the `frame` stream so the DynamicMap re-renders.
t_slider.observe(lambda change: frame.event(t=change["new"]), names="value")
v_slider.observe(lambda change: frame.event(v=change["new"]), names="value")


def update_enabled_channels(change):
    """Push the on/off state of every channel toggle into ``display_settings``.

    ``change`` (the traitlets change dict) is ignored; the full state is re-read.
    """
    enabled_flags = np.array([toggle.value for toggle in enabled_buttons])
    display_settings.event(channel_enabled=enabled_flags)


def update_solo(solo_button):
    """Click handler for a channel's "S" (solo) button.

    If the button's own channel is already the only one enabled, re-enable all
    channels. Otherwise enable just that channel. Setting each toggle's value
    triggers its observers.
    """
    target = solo_button._button_to_enable
    already_soloed = target.value and sum(b.value for b in enabled_buttons) == 1
    for toggle in enabled_buttons:
        toggle.value = True if already_soloed else (toggle == target)


# Connect the per-channel widgets to their handlers.
for solo_button in solo_buttons:
    solo_button.on_click(update_solo)

for enabled_button in enabled_buttons:
    enabled_button.observe(update_enabled_channels, names="value")
# for color_picker in color_pickers:
#    color_picker.observe(update_image, names='value')

# hv.DynamicMap(composite_image, kdims=['t', 'v', 'channel_enabled']).select(t=0,v=0,channel_enabled=np.array([True,False,False,False,False]))
# NOTE(review): this global `image_viewer` is later shadowed by `def image_viewer`
# in the viewer-refactor cell further down.
image_viewer = hv.DynamicMap(composite_image, streams=[frame, display_settings])
regrid(image_viewer)

# Trench detection

In [None]:
def _standardize_cluster_labels(X, fit):
    mean = defaultdict(lambda: 0)
    count = defaultdict(lambda: 0)
    for i in range(len(fit.labels_)):
        mean[fit.labels_[i]] += X[i]
        count[fit.labels_[i]] += 1
    for k, v in mean.items():
        mean[k] = v / count[k]
    label_mapping = dict(zip(mean.keys(), np.lexsort(list(mean.values()))))
    for i, old_label in enumerate(fit.labels_):
        fit.labels_[i] = label_mapping[old_label]


def cluster_binary_image(bin_img):
    """Split the foreground pixels of a binary image into two spatial clusters.

    Returns
    -------
    X : numpy.ndarray of shape (n_foreground, ndim)
        Indices of the True pixels, one row per pixel, as given by ``np.where``.
    fit : sklearn.cluster.MiniBatchKMeans
        The fitted model (fitted on standardised coordinates), with ``labels_``
        renumbered by ``_standardize_cluster_labels``.
    """
    coords = np.array(np.where(bin_img)).T
    standardized = StandardScaler().fit_transform(coords.astype(np.float32))
    kmeans = sklearn.cluster.MiniBatchKMeans(
        n_clusters=2, init="k-means++", n_init=10, max_no_improvement=10, verbose=0
    )
    kmeans.fit(standardized)
    _standardize_cluster_labels(coords, kmeans)
    return coords, kmeans


def label_binary_image(bin_img):
    """Return an int8 image with 0 for background and cluster label + 1 for foreground.

    The two clusters come from ``cluster_binary_image``.
    """
    coords, fit = cluster_binary_image(bin_img)
    label_img = np.zeros_like(bin_img, dtype=np.int8)  # TODO: fixing dtype
    # scatter every foreground pixel's (shifted) cluster label in one assignment
    label_img[coords[:, 0], coords[:, 1]] = fit.labels_ + 1
    return label_img


def drop_rare_labels(labels, min_fraction=0.01):
    """Return the distinct labels that make up more than ``min_fraction`` of ``labels``.

    Labels are returned in first-seen order. An empty input returns [].

    Fixes over the previous version:
    - the total was ``sum(counter)``, which sums the label *values*, not the counts;
    - ``Counter.iteritems`` does not exist in Python 3;
    - a debug ``print`` of every fraction has been removed.
    """
    counter = Counter(labels)
    total = sum(counter.values())
    return [label for label, count in counter.items() if count / total > min_fraction]

In [None]:
def get_img_limits(img):
    """Return the (x, y) extents of an image as ((x_min, x_max), (y_min, y_max)).

    In this notebook a point is (x, y) = (column, row): images are indexed as
    ``img[ys, xs]`` and plotted with ``plt.scatter(x, y)``. So x spans the
    columns (``img.shape[1]``) and y spans the rows (``img.shape[0]``). The
    previous version used ``img.shape[0]`` for x, which only worked for square
    images. Only the first two axes are used, so stacked or colour images are
    accepted.
    """
    x_min = y_min = 0
    y_max, x_max = img.shape[:2]
    # TODO: what convention should we use, should max be inclusive??
    # (currently exclusive: x_max is the column count, one past the last index)
    x_lim = (x_min, x_max)
    y_lim = (y_min, y_max)
    return x_lim, y_lim


def detect_rotation(bin_img):
    """Estimate the dominant line orientation of a binary image with a Hough transform.

    A coarse pass over skimage's default angle grid picks the angle whose
    accumulator changes most between neighbouring angle columns (variance over
    distance of ``np.diff`` along the angle axis). A second pass refines that
    estimate on 200 angles around it. Peaks of the refined accumulator column
    give the Hough distances of the detected lines.

    Returns
    -------
    rotation : float
        ``pi/2`` minus the refined Hough angle, in radians.
    peaks : numpy.ndarray
        Hough distances ``d`` of the peaks along the refined angle column.
    """
    h, theta, d = skimage.transform.hough_line(bin_img)
    diff_var = np.diff(h.astype(np.int32), axis=1).var(axis=0)
    angle1 = theta[diff_var.argmax()]
    # Refine within +/-10 % of the coarse angle, but never narrower than one
    # coarse grid step. A purely relative window collapses to a single angle
    # when angle1 == 0 (the default grid contains 0), so no refinement happened.
    coarse_step = np.abs(np.diff(theta)).min() if len(theta) > 1 else 0.0
    half_width = max(0.1 * abs(angle1), coarse_step)
    h2, theta2, d2 = skimage.transform.hough_line(
        bin_img, theta=np.linspace(angle1 - half_width, angle1 + half_width, 200)
    )
    diff_var2 = np.diff(h2.astype(np.int32), axis=1).var(axis=0)
    theta_idx2 = diff_var2.argmax()
    angle2 = theta2[theta_idx2]
    d_profile = h2[:, theta_idx2].astype(np.int32)
    peak_idxs = peakutils.indexes(d_profile, thres=0.4, min_dist=5)
    peaks = d2[peak_idxs]
    # The unused FFT (`freqs`) and line spacing (`spacing`) computations were
    # removed. `scipy.stats.mode(...).mode[0]` raises on SciPy >= 1.11, where
    # mode returns a scalar; use get_rough_spacing(peaks) if the spacing is needed.
    return np.pi / 2 - angle2, peaks


def get_rough_spacing(dists):
    """Return the most common integer gap between consecutive values of ``dists``.

    Gaps are truncated to int before the mode is taken.

    Raises
    ------
    ValueError
        If fewer than two values are given (there is no gap to measure).
    """
    dists = np.asarray(dists)
    if dists.size < 2:
        raise ValueError("need at least two distances to estimate a spacing")
    gaps = np.diff(dists).astype(int)
    # scipy.stats.mode returns a scalar mode on SciPy >= 1.11 (keepdims=False)
    # and a length-1 array on older versions. ravel()[0] works with both, while
    # `.mode[0]` raises IndexError on the scalar.
    return np.ravel(scipy.stats.mode(gaps).mode)[0]


def point_linspace(anchor0, anchor1, num_points):
    """Lazily yield the interior points of ``num_points`` evenly spaced points
    from ``anchor0`` to ``anchor1``; both endpoints are excluded."""
    fractions = np.linspace(0, 1, num_points)
    yield from ((1 - s) * anchor0 + s * anchor1 for s in fractions[1:-1])


def coords_along(x0, x1):
    """Integer pixel coordinates sampled along the segment x0 -> x1.

    The number of samples is the segment length truncated to int. Samples are
    truncated to integers and the two endpoints are dropped.

    Returns
    -------
    (xs, ys) : tuple of int ndarrays, each of length max(length - 2, 0)
    """
    n_samples = int(np.sqrt(np.sum((x1 - x0) ** 2)))
    interior = slice(1, -1)
    xs, ys = (
        np.linspace(x0[axis], x1[axis], n_samples).astype(np.int_)[interior]
        for axis in (0, 1)
    )
    return xs, ys


def edge_point(x0, theta, x_lim, y_lim):
    """Return where a ray from ``x0`` at angle ``theta`` meets the box boundary.

    The ray points along (-cos(theta), sin(theta)) in (x, y) coordinates; both
    intersection formulas below move along that direction. So theta = 0 heads
    toward x_min, and theta = pi/2 heads toward y_max. The box is
    [x_min, x_max] x [y_min, y_max].

    Parameters
    ----------
    x0 : array-like of length 2
        Start point (x, y). Presumably inside the box — TODO confirm callers
        never pass points outside it.
    theta : float
        Ray angle in radians, reduced modulo 2*pi.
    x_lim, y_lim : (min, max) pairs

    Returns
    -------
    numpy.ndarray of shape (2,)
        The boundary point hit by the ray.
    """
    x_min, x_max = x_lim
    y_min, y_max = y_lim
    theta = theta % (2 * np.pi)
    # Pick the box corner in the quadrant the ray heads toward.
    if 0 <= theta < np.pi / 2:
        corner_x, corner_y = x_min, y_max
    elif np.pi / 2 <= theta < np.pi:
        corner_x, corner_y = x_max, y_max
    elif np.pi <= theta < 3 / 2 * np.pi:
        corner_x, corner_y = x_max, y_min
    elif 3 / 2 * np.pi <= theta <= 2 * np.pi:
        corner_x, corner_y = x_min, y_min
    # Angle (same convention as theta) of the direction from x0 to that corner.
    angle_to_corner = np.arctan2(corner_y - x0[1], x0[0] - corner_x) % (2 * np.pi)
    # Comparing theta with the corner angle tells whether the ray passes the
    # corner on its y = const side or on its x = const side.
    if (
        (theta >= angle_to_corner and 0 <= theta < np.pi / 2)
        or (theta < angle_to_corner and np.pi / 2 <= theta < np.pi)
        or (theta >= angle_to_corner and np.pi <= theta < 3 / 2 * np.pi)
        or (theta < angle_to_corner and 3 / 2 * np.pi <= theta < 2 * np.pi)
    ):
        # top/bottom: the ray hits the y = corner_y edge
        # NOTE(review): tan(theta) == 0 here (theta a multiple of pi) would divide
        # by zero; this seems reachable only when x0 already lies on that edge.
        x1 = np.array([x0[0] - (corner_y - x0[1]) / np.tan(theta), corner_y])
    else:
        # left/right: the ray hits the x = corner_x edge
        x1 = np.array([corner_x, x0[1] - (corner_x - x0[0]) * np.tan(theta)])
    return x1

In [None]:
def line_array(
    anchors, theta, x_lim, y_lim, start=None, stop=None, bidirectional=False
):
    """Yield line segments cast from each anchor toward the box boundary.

    For each anchor, a ray at angle ``theta`` (``edge_point`` convention) runs to
    the boundary of ``x_lim`` x ``y_lim``. It is optionally trimmed to distances
    [start, stop] from the anchor. Degenerate (zero-length) segments are skipped.

    Parameters
    ----------
    anchors : iterable of length-2 arrays
        Consumed exactly once, so generators are fine.
    theta : float
        Ray angle in radians.
    x_lim, y_lim : (min, max) pairs
    start, stop : float or None
        Distance along the ray where the segment begins/ends. None or 0 means
        "at the anchor" for ``start`` and "at the boundary" for ``stop``.
    bidirectional : bool
        If True, also cast the opposite ray (theta + pi) from the same anchor
        and yield ``(forward_start, forward_end, backward_end)``. An anchor is
        skipped if either direction is degenerate, so the two ends always
        belong to the same anchor.

    Yields
    ------
    (begin, end), or (forward_start, forward_end, backward_end) if bidirectional.

    Raises
    ------
    ValueError
        When iteration starts, if start < 0 or a nonzero stop is below start.
    """
    if start is None:
        start = 0
    if stop is None:
        stop = 0  # 0 means "no limit": run all the way to the boundary
    if start < 0 or (stop and stop < start):
        raise ValueError("need stop >= start >= 0")
    if bidirectional:
        # Pair both directions per anchor. The previous version zipped two
        # independently filtered streams, which misaligned anchors whenever one
        # direction was skipped, and it consumed generator anchors twice.
        for anchor in anchors:
            forward = _line_segment(anchor, theta, x_lim, y_lim, start, stop)
            backward = _line_segment(anchor, theta + np.pi, x_lim, y_lim, start, stop)
            if forward is None or backward is None:
                continue
            yield forward[0], forward[1], backward[1]
        return
    for anchor in anchors:
        segment = _line_segment(anchor, theta, x_lim, y_lim, start, stop)
        if segment is not None:
            yield segment


def _line_segment(anchor, theta, x_lim, y_lim, start, stop):
    """Return ``(begin, end)`` of one ray trimmed to [start, stop], or None if degenerate."""
    theta = theta % (2 * np.pi)
    edge = edge_point(anchor, theta, x_lim, y_lim)
    direction = edge - anchor
    max_length = np.sqrt((direction**2).sum())
    if max_length == 0:
        return None  # anchor already on the boundary in this direction
    begin, end = anchor, edge
    if start:
        begin = min(start / max_length, 1) * direction + anchor
    if stop:
        end = min(stop / max_length, 1) * direction + anchor
    if np.array_equal(begin, end):
        return None
    return begin, end


def get_anchors(theta, x_lim, y_lim):
    """Return the two boundary points of the line through the box centre at angle ``theta``.

    The first point is reached along ``theta`` and the second along
    ``theta + pi``, both via ``edge_point``.
    """
    lower = np.array([x_lim[0], y_lim[0]])
    upper = np.array([x_lim[1], y_lim[1]])
    center = lower + (upper - lower) / 2
    forward = edge_point(center, theta, x_lim, y_lim)
    backward = edge_point(center, theta + np.pi, x_lim, y_lim)
    return forward, backward

In [None]:
def detect_trench_region(bin_img, theta, num_anchors=40, margin=3):
    """Find the anchor whose perpendicular cross-section best cuts across the trenches.

    Anchors are spaced evenly along the line through the image centre at angle
    ``theta`` (via ``get_anchors``/``point_linspace``), and ``margin`` anchors
    are dropped at each end. For each anchor, ``bin_img`` is sampled along the
    full-width line through it, perpendicular to ``theta``. The anchor whose
    cross-section has the highest variance (most foreground/background
    alternation) is returned; ties go to the first anchor.

    Parameters
    ----------
    bin_img : 2-D bool array
    theta : float
        Trench orientation in radians (as returned by ``detect_rotation``).
    num_anchors : int
        Point count passed to ``point_linspace`` (its endpoints are excluded).
    margin : int
        Number of anchors discarded at each end, away from the image border.

    Returns
    -------
    numpy.ndarray of shape (2,)
        The selected anchor (x, y).

    Raises
    ------
    ValueError
        If no anchor yields a non-empty cross-section.
    """
    x_lim, y_lim = get_img_limits(bin_img)
    anchor0, anchor1 = get_anchors(theta, x_lim, y_lim)
    anchors = list(point_linspace(anchor0, anchor1, num_anchors))
    if margin > 0:
        anchors = anchors[margin:-margin]
    perpendicular = np.pi / 2 + theta
    best_anchor, best_var = None, -np.inf
    for anchor in anchors:
        # Query one anchor at a time so each cross-section stays paired with its
        # anchor. line_array skips degenerate lines, which previously shifted the
        # argmax index relative to `anchors`.
        segment = next(
            line_array([anchor], perpendicular, x_lim, y_lim, bidirectional=True),
            None,
        )
        if segment is None:
            continue
        _, forward_end, backward_end = segment
        xs, ys = coords_along(forward_end, backward_end)
        if len(xs) == 0:
            continue  # empty cross-section: its NaN variance would win argmax
        cross_section_var = bin_img[ys, xs].var()
        if cross_section_var > best_var:
            best_anchor, best_var = anchor, cross_section_var
    if best_anchor is None:
        raise ValueError("no usable cross-section found for trench region detection")
    return best_anchor

In [None]:
# FROM: https://stackoverflow.com/questions/23815327/numpy-one-liner-for-combining-unequal-length-np-array-to-a-matrixor-2d-array
def stack_jagged(arys, fill=0):
    """Stack sequences of unequal length into a 2-D array, one row per sequence.

    Shorter rows are right-padded with ``fill``.
    """
    columns = zip_longest(*arys, fillvalue=fill)
    return np.transpose(np.array(list(columns)))


def detect_periodic_peaks(signal, plot=True, num_periods=100):
    """Estimate the period and phase offset of a roughly periodic peak train.

    1. Peaks are found with ``peakutils.indexes`` (thres=0.2, min_dist=5) and
       refined to sub-sample positions with ``peakutils.interpolate``.
    2. The period is searched on a coarse grid, from the 10th percentile to the
       maximum peak spacing, and then on a fine grid within +/-2 % of the
       coarse optimum. The objective is the IQR of the peak positions modulo
       the candidate period, divided by that period; lower is better.
    3. The offset is the phase in [0, period] whose comb of sample positions
       has the largest summed signal.

    Parameters
    ----------
    signal : 1-D array
    plot : bool
        Draw the four diagnostic figures (default True, as before).
    num_periods : int
        Grid size for the period and offset searches.

    Returns
    -------
    (period, offset) : floats

    Raises
    ------
    ValueError
        If fewer than two peaks are found, so no spacing can be measured.
    """
    signal = np.asarray(signal)
    positions = np.arange(len(signal))
    idxs = peakutils.indexes(signal, thres=0.2, min_dist=5)
    if len(idxs) < 2:
        raise ValueError("need at least two peaks to estimate a period")
    xs = peakutils.interpolate(positions, signal, ind=idxs)
    dxs = np.diff(xs)

    # coarse period search
    periods = np.linspace(np.percentile(dxs, 10), dxs.max(), num_periods)
    # std = ((dxs + periods[:,np.newaxis]/2) % periods[:,np.newaxis]).std(axis=1)
    spread = scipy.stats.iqr(xs % periods[:, np.newaxis], axis=1) / periods
    period_idx = spread.argmin()
    period = periods[period_idx]

    # fine period search around the coarse optimum
    periods2 = np.linspace(period * 0.98, period * 1.02, num_periods)
    spread2 = scipy.stats.iqr(xs % periods2[:, np.newaxis], axis=1) / periods2
    period_idx2 = spread2.argmin()
    period2 = periods2[period_idx2]

    # phase search: sum the signal on a comb with spacing period2
    offsets = np.linspace(0, period2, num_periods)
    offset_idxs = (
        np.arange(0, len(signal) - period2, period2) + offsets[:, np.newaxis]
    ).astype(np.int_)
    # float rounding in arange can push the last tooth onto len(signal)
    offset_idxs = np.minimum(offset_idxs, len(signal) - 1)
    objective = signal[offset_idxs].sum(axis=1)
    offset_idx = objective.argmax()
    offset = offsets[offset_idx]

    if plot:
        plt.figure(figsize=(16, 8))
        plt.plot(positions, signal)
        plt.scatter(xs, signal[xs.astype(np.int_)], c="r")
        plt.figure()
        plt.plot(periods, spread)
        plt.scatter([period], [spread[period_idx]], c="r")
        plt.figure()
        plt.plot(periods2, spread2)
        plt.scatter([period2], [spread2[period_idx2]], c="r")
        plt.figure()
        plt.plot(offsets, objective)
        plt.scatter([offset], [objective[offset_idx]], c="r")
    return period2, offset


def detect_trench_anchors(img, t0, theta):
    """Locate one anchor per trench along the line through ``t0`` perpendicular to ``theta``.

    The image is sampled edge to edge along that line. ``detect_periodic_peaks``
    then gives a period and offset, and the comb of positions
    ``offset + k * period`` becomes the anchors. A diagnostic figure is drawn.

    Returns
    -------
    numpy.ndarray of shape (n_trenches, 2)
        Anchor (x, y) pixel coordinates.
    """
    x_lim, y_lim = get_img_limits(img)
    line_start = edge_point(t0, theta - np.pi / 2, x_lim, y_lim)
    line_end = edge_point(t0, theta + np.pi / 2, x_lim, y_lim)
    xs, ys = coords_along(line_start, line_end)
    profile = img[ys, xs]
    period, offset = detect_periodic_peaks(profile)
    comb = np.arange(offset, len(profile), period).astype(np.int_)
    plt.figure(figsize=(16, 8))
    plt.plot(profile)
    plt.scatter(comb, profile[comb], c="r")
    return np.column_stack((xs[comb], ys[comb]))


def _detect_trench_end(img, anchors, theta):
    """Find one end point per trench by walking from each anchor toward the image edge.

    For every anchor, pixels are sampled along the ray at angle ``theta`` from
    the anchor to the image boundary (``edge_point`` + ``coords_along``). The
    profiles are zero-padded to a common length (``stack_jagged``) and reduced
    to their 80th percentile across trenches. The common end index is where
    ``holo_diff(1, ...)`` of that combined profile is smallest. This is
    presumably the steepest intensity drop, i.e. the trench end — confirm
    holo_diff's semantics. Three diagnostic matplotlib figures are drawn.

    Parameters
    ----------
    img : 2-D array
        Image indexed as ``img[ys, xs]``.
    anchors : iterable of (x, y) points
        One anchor per trench.
    theta : float
        Ray angle in radians, in ``edge_point``'s convention.

    Returns
    -------
    numpy.ndarray of shape (n_anchors, 2)
        The (x, y) end point of each trench. This is the sample at the common end
        index, or the last sample when that trench's profile is shorter.
    """
    x_lim, y_lim = get_img_limits(img)
    xss = []
    yss = []
    trench_profiles = []
    for anchor in anchors:
        x_end = edge_point(anchor, theta, x_lim, y_lim)
        xs, ys = coords_along(anchor, x_end)
        xss.append(xs)
        yss.append(ys)
        trench_profiles.append(img[ys, xs])
    plt.figure(figsize=(8, 8))
    for trench_profile in trench_profiles:
        plt.plot(trench_profile)
    # zero padding means short profiles contribute 0 past their end to the percentile
    stacked_profile = np.percentile(stack_jagged(trench_profiles), 80, axis=0)
    # cum_profile = np.cumsum(stacked_profile)
    # cum_profile /= cum_profile[-1]
    # end = np.where(cum_profile > 0.8)[0][0]
    stacked_profile_diff = holo_diff(1, stacked_profile)
    end = stacked_profile_diff.argmin()
    plt.figure(figsize=(8, 8))
    plt.plot(stacked_profile)
    plt.axvline(end, c="r")
    # ax2 = plt.gca().twinx()
    # ax2.plot(stacked_profile_diff, color='g')
    plt.figure(figsize=(8, 8))
    plt.plot(stacked_profile_diff, color="g")
    plt.axvline(end, c="r")
    end_points = []
    for xs, ys in zip(xss, yss):
        idx = end
        if len(xs) <= end:
            # profile shorter than the common end index: fall back to its last sample
            # NOTE(review): an empty profile (len(xs) == 0) would still raise IndexError
            idx = -1
        end_points.append((xs[idx], ys[idx]))
    return np.array(end_points)


def detect_trench_ends(img, bin_img, anchors, theta):
    """Find both ends of every trench: along ``theta`` and along ``theta + pi``.

    Pixels outside a one-step dilation of ``bin_img`` are replaced by the
    image's 5th percentile before the search. A summary figure with the
    anchors (white), ends along ``theta`` (green) and ends along
    ``theta + pi`` (red) is drawn.

    Returns
    -------
    (top_points, bottom_points) : numpy.ndarrays of shape (n_trenches, 2)
    """
    near_trenches = skimage.morphology.binary_dilation(bin_img)
    background = np.percentile(img, 5)
    img_masked = np.where(near_trenches, img, background)
    top_points = _detect_trench_end(img_masked, anchors, theta)
    bottom_points = _detect_trench_end(img_masked, anchors, theta + np.pi)
    plt.figure(figsize=(12, 12))
    plt.imshow(img_masked)
    for points, color in ((anchors, "w"), (top_points, "g"), (bottom_points, "r")):
        plt.scatter(*points.T, s=3, c=color)
    return top_points, bottom_points


def detect_trenches(img, bin_img, theta):
    """Full trench detection pipeline: region -> per-trench anchors -> both ends.

    Returns the ``(top_points, bottom_points)`` pair from ``detect_trench_ends``.
    """
    region_anchor = detect_trench_region(bin_img, theta)
    anchors = detect_trench_anchors(img, region_anchor, theta)
    return detect_trench_ends(img, bin_img, anchors, theta)

# Segmentation

In [None]:
def _label_for_trenches(img_series):
    """Max-project a time series, detrend it, Otsu-threshold it and cluster-label it.

    Detrending subtracts each row's 3rd percentile.

    Returns
    -------
    (img, img_labels)
        The detrended projection, and its label image from ``label_binary_image``.
    """
    projection = img_series.max(axis=0)
    # TODO: need rotation-invariant detrending
    row_floor = np.percentile(projection, 3, axis=1)
    detrended = projection - row_floor[:, np.newaxis]
    foreground = detrended > skimage.filters.threshold_otsu(detrended)
    return detrended, label_binary_image(foreground)


def get_image_series_trenches(img_series, label):
    """Detect the trenches of cluster ``label`` in a time series of images.

    The label image is shown as a figure. Rotation is estimated from the
    selected cluster's mask, and ``detect_trenches`` runs on the detrended
    projection.

    Returns the ``(top_points, bottom_points)`` pair from ``detect_trenches``.
    """
    img, img_labels = _label_for_trenches(img_series)
    plt.figure(figsize=(8, 8))
    plt.imshow(img_labels)
    cluster_mask = img_labels == label
    theta, _ = detect_rotation(cluster_mask)
    return detect_trenches(img, cluster_mask, theta)

In [None]:
# Inspect the array layout; the code below indexes it as frames_z[v, c, t, :, :].
frames_z.shape

In [None]:
%%time
# All timepoints of field of view 0, channel 0.
frame_series = frames_z[0, 0, :, :, :]

In [None]:
%%time
# Label 1 is cluster 0 after _standardize_cluster_labels (label_binary_image adds 1).
trench_points = get_image_series_trenches(frame_series, 1)

In [None]:
def hessian_eigenvalues(img, sigma=1.5):
    """Per-pixel eigenvalues of a smoothed 2x2 Hessian estimate.

    The image is Gaussian-smoothed with ``sigma``. Second derivatives are then
    approximated by applying Sobel filters twice. The two mixed partials are
    averaged so the matrix is exactly symmetric. Otherwise border effects can
    make them differ, give a negative discriminant, and produce NaN, which used
    to zero *both* eigenvalues, including the real trace part.

    Parameters
    ----------
    img : 2-D array
    sigma : float
        Gaussian smoothing scale (default 1.5, as before).

    Returns
    -------
    (k1, k2) : arrays shaped like ``img``
        Larger and smaller eigenvalue; k1 >= k2. NaNs (e.g. from NaN input) are
        set to 0.
    """
    smoothed = skimage.filters.gaussian(img, sigma)
    d_h = skimage.filters.sobel_h(smoothed)
    d_v = skimage.filters.sobel_v(smoothed)
    d_hh = skimage.filters.sobel_h(d_h)
    d_vv = skimage.filters.sobel_v(d_v)
    d_hv = (skimage.filters.sobel_v(d_h) + skimage.filters.sobel_h(d_v)) / 2
    half_trace = (d_hh + d_vv) / 2
    # For a symmetric [[a, b], [b, d]]: (a + d)^2 - 4(ad - b^2) == (a - d)^2 + 4b^2 >= 0
    half_gap = np.sqrt((d_hh - d_vv) ** 2 + 4 * d_hv**2) / 2
    k1 = half_trace + half_gap
    k2 = half_trace - half_gap
    k1[np.isnan(k1)] = 0
    k2[np.isnan(k2)] = 0
    return k1, k2

In [None]:
%%time
# Keep only the larger Hessian eigenvalue (k1) of every frame.
frame_series_k1 = np.zeros_like(frame_series, dtype=np.float32)  # TODO: fixed dtype
for i in range(frame_series.shape[0]):
    frame_series_k1[i, :, :] = hessian_eigenvalues(frame_series[i])[0]

In [None]:
def extract_kymograph(img_series, x0, x1):
    """Sample ``img_series`` along the segment x0 -> x1 at every timepoint.

    Returns
    -------
    numpy.ndarray of float64, shape (n_points, n_timepoints)
        Column t holds frame t's values at the ``coords_along`` points, in
        reversed order (from the x1 end back toward x0).
    """
    xs, ys = coords_along(x0, x1)
    kymo = np.zeros((len(xs), img_series.shape[0]))
    # each row of kymo.T is a writable view of one kymograph column
    for t, column in enumerate(kymo.T):
        column[:] = img_series[t, ys, xs][::-1]
    return kymo


def get_image_series_segmentation(img_series, x0, x1):
    """Placeholder: not implemented yet; currently returns None."""
    pass  # list of trenches, for each trench, a list of cell masks


def map_over_segmentation(img_series, cell_seg, func):
    """Placeholder: not implemented yet; currently returns None."""
    pass

In [None]:
def bounding_box(points):
    """Axis-aligned bounding box of a sequence of (x, y) points.

    Returns
    -------
    (upper_left, lower_right) : numpy arrays
        ``[min x, min y]`` and ``[max x, max y]``.
    """
    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    upper_left = np.array([min(xs), min(ys)])
    lower_right = np.array([max(xs), max(ys)])
    return upper_left, lower_right


def crop_point(x, x_lim, y_lim):
    """Clip an (x, y) point into the box ``x_lim`` x ``y_lim`` (bounds inclusive)."""
    bounds = np.array([[x_lim[0], y_lim[0]], [x_lim[1], y_lim[1]]])
    lower, upper = bounds
    return np.clip(x, lower, upper)


def get_trench_thumbnail(img, trench_points, trench_idx):
    """Crop ``img`` to the bounding box of trench ``trench_idx`` (see ``get_trench_bbox``)."""
    (left, top), (right, bottom) = get_trench_bbox(
        trench_points, trench_idx, *get_img_limits(img)
    )
    return img[top:bottom, left:right]


def get_trench_bbox(trench_points, trench_idx, x_lim, y_lim):
    """Bounding box around trench ``trench_idx`` and both of its neighbours.

    The box covers both end points of the trench and of the previous and next
    trenches. For the first or last trench, the missing neighbour is
    extrapolated by mirroring (``2 * p[i] - p[i +/- 1]``) and clipped into
    ``x_lim`` x ``y_lim``.

    Parameters
    ----------
    trench_points : (top_points, bottom_points)
        Sequences of (x, y) arrays, one per trench, as returned by ``detect_trenches``.
    trench_idx : int
    x_lim, y_lim : (min, max) pairs

    Returns
    -------
    (upper_left, lower_right) from ``bounding_box``.

    Raises
    ------
    ValueError
        If ``trench_idx`` is out of range, or fewer than two trenches exist
        (no neighbour to infer the box width from; this used to be an IndexError).
    """
    top_points, bottom_points = trench_points[0], trench_points[1]
    num_trenches = min(len(top_points), len(bottom_points))
    if not 0 <= trench_idx < num_trenches:
        raise ValueError("trench index out of bounds")
    if num_trenches < 2:
        raise ValueError("need at least two trenches to infer a trench bounding box")
    # Use num_trenches - 1, not [-1]: the two lists may differ in length.
    last = num_trenches - 1
    points = [top_points[trench_idx], bottom_points[trench_idx]]
    for ends in (top_points, bottom_points):
        if trench_idx == 0:
            prev_point = crop_point(2 * ends[0] - ends[1], x_lim, y_lim)
        else:
            prev_point = ends[trench_idx - 1]
        if trench_idx == last:
            next_point = crop_point(2 * ends[last] - ends[last - 1], x_lim, y_lim)
        else:
            next_point = ends[trench_idx + 1]
        points += [prev_point, next_point]
    return bounding_box(points)

In [None]:
def image_sharpness(img):
    """Focus score: 99.9th percentile of the Laplacian of the Gaussian-blurred image (sigma=1)."""
    # FROM: https://stackoverflow.com/questions/7765810/is-there-a-way-to-detect-if-an-image-is-blurry/7767755#7767755
    smoothed = skimage.filters.gaussian(img, 1)
    return np.percentile(skimage.filters.laplace(smoothed), 99.9)

In [None]:
%%time
# Per-frame focus score, shown as `sharpness_plot` under the kymograph.
sharpness = np.array(
    [image_sharpness(frame_series[i]) for i in range(frame_series.shape[0])]
)

In [None]:
# Stream class carrying the frame indices (t = timepoint, v = field of view);
# each viewer creates its own instance.
FrameStream = Stream.define("Frame", t=0, v=0)

In [None]:
def frame_browser(frames, frame_stream):
    """Create t/v sliders that drive ``frame_stream`` for the given ``frames``.

    The timepoint and FOV counts come from ``frames.attrs["metadata"]``. The
    previous version ignored ``frames`` and read the global ``frames_z``.

    Returns
    -------
    ipywidgets.VBox
        The v slider above the t slider.
    """
    metadata = frames.attrs["metadata"]
    num_timepoints = len(metadata["frames"])
    num_fovs = len(metadata["fields_of_view"])
    # IntSlider's `max` is inclusive, so the last valid index is count - 1.
    # Its label keyword is `description`; `label` is not a trait and was ignored.
    t_slider = widgets.IntSlider(
        description="t",
        min=0,
        max=max(num_timepoints - 1, 0),
        step=1,
        value=0,
        continuous_update=False,
    )
    v_slider = widgets.IntSlider(
        description="v",
        min=0,
        max=max(num_fovs - 1, 0),
        step=1,
        value=0,
        continuous_update=False,
    )
    slider_box = widgets.VBox([v_slider, t_slider])
    t_slider.observe(lambda change: frame_stream.event(t=change["new"]), names="value")
    v_slider.observe(lambda change: frame_stream.event(v=change["new"]), names="value")
    return slider_box

In [None]:
# NOTE(review): interactive help lookup left over from development; safe to delete.
display?

In [None]:
# NOTE(review): interactive help lookup left over from development; safe to delete.
hv.output?

In [None]:
def image_viewer(frames, frame_stream=None):
    """Build per-channel controls and a regridded composite-image DynamicMap.

    Parameters
    ----------
    frames : zarr array
        Indexed as ``frames[v, c, t, :, :]``. ``attrs["metadata"]["channels"]``
        must list channel names that are keys of ``channel_colors``.
    frame_stream : Stream with ``t`` and ``v`` parameters, optional
        Selects the displayed frame. Defaults to a new ``FrameStream()``
        (t=0, v=0), so viewers without sliders still work.

    Returns
    -------
    (channels_box, regridded_image)
        The ipywidgets VBox of channel controls and the regridded DynamicMap.
    """
    if frame_stream is None:
        frame_stream = FrameStream()
    channels = frames.attrs["metadata"]["channels"]
    n_channels = len(channels)
    colors = [hex2color(channel_colors[channel]) for channel in channels]
    height, width = frames.shape[-2:]
    channel_widgets = []
    for channel in channels:
        solo_button = widgets.Button(
            description="S", layout=widgets.Layout(width="10%")
        )
        enabled_button = widgets.ToggleButton(description=channel, value=True)
        # remember which toggle this solo button controls (used by update_solo)
        solo_button._button_to_enable = enabled_button
        color_picker = widgets.ColorPicker(concise=True, value=channel_colors[channel])
        channel_box = widgets.HBox([solo_button, enabled_button, color_picker])
        channel_widgets.append([solo_button, enabled_button, color_picker, channel_box])
    # NOTE: color_pickers are displayed but not yet wired to the rendering.
    solo_buttons, enabled_buttons, color_pickers, channel_boxes = zip(*channel_widgets)
    channels_box = widgets.VBox(channel_boxes)
    DisplaySettings = Stream.define(
        "DisplaySettings", channel_enabled=np.array([True] * n_channels)
    )
    display_settings_stream = DisplaySettings()

    def composite_image(t, v, channel_enabled):
        # Screen-blend the enabled channels. Each is normalised by its 99.9th
        # percentile, clipped to [0, 1] and tinted with its colour.
        tinted = []
        for c in range(n_channels):
            if not channel_enabled[c]:
                continue  # hidden channel: skip the zarr read entirely
            channel_img = frames[v, c, t, :, :]
            scale = np.percentile(channel_img, 99.9)
            if not scale > 0:
                # Blank channel: dividing by 0 would fill the image with NaN/inf.
                scale = 1.0
            scaled = np.clip(channel_img[:, :, np.newaxis] / scale, 0, 1)
            tinted.append(scaled * np.array(colors[c]))
        if not tinted:
            tinted = [np.ones((height, width, 3))]  # white placeholder
        img = tinted[0]
        for img2 in tinted[1:]:
            img = 1 - (1 - img) * (1 - img2)
        return hv.RGB(
            img, bounds=(-1, -1, 1, 1)
        )  # .opts(plot={'size': 250}, tools=[''])

    def update_enabled_channels(change):
        channel_enabled = np.array([button.value for button in enabled_buttons])
        display_settings_stream.event(channel_enabled=channel_enabled)

    def update_solo(solo_button):
        # Re-enable everything if this channel is already soloed; otherwise solo it.
        if (
            solo_button._button_to_enable.value
            and sum([b.value for b in enabled_buttons]) == 1
        ):
            for enabled_button in enabled_buttons:
                enabled_button.value = True
        else:
            for enabled_button in enabled_buttons:
                enabled_button.value = enabled_button == solo_button._button_to_enable

    for solo_button in solo_buttons:
        solo_button.on_click(update_solo)

    for enabled_button in enabled_buttons:
        enabled_button.observe(update_enabled_channels, names="value")

    image = hv.DynamicMap(
        composite_image, streams=[frame_stream, display_settings_stream]
    )
    regridded_image = regrid(image)
    return channels_box, regridded_image


def big_image_viewer(frames):
    """Display a 500x500 viewer with channel controls and t/v sliders.

    Both the sliders and the image share one ``FrameStream``. Returns None; the
    widget box is displayed as a side effect.
    """
    frame_stream = FrameStream()
    slider_box = frame_browser(frames, frame_stream)
    channels_box, image = image_viewer(frames, frame_stream)
    sized_image = image.opts(plot={"width": 500, "height": 500})
    output = widgets.Output()
    controls = widgets.HBox([channels_box, slider_box])
    display(widgets.VBox([controls, output]))
    with output:
        display(sized_image)
    return None  # box


def mini_image_viewer(frames, frame_stream=None):
    """Display a compact viewer: channel controls plus the image, no t/v sliders.

    ``image_viewer`` requires a stream providing ``t`` and ``v``. The previous
    version called it with only ``frames`` and raised TypeError. Now a new
    ``FrameStream()`` (t=0, v=0) is used unless the caller passes a stream, for
    example one shared with a ``frame_browser``.

    Returns None; the widget box is displayed as a side effect.
    """
    if frame_stream is None:
        frame_stream = FrameStream()
    channels_box, image = image_viewer(frames, frame_stream)
    output = widgets.Output()
    box = widgets.VBox([channels_box, output])
    display(box)
    with output:
        display(image)
    return None  # box


# Full viewer: channel controls plus t/v sliders.
big_image_viewer(frames_z)

In [None]:
%%output size=250
# Compact viewer: channel controls only, no t/v sliders.
mini_image_viewer(frames_z)

In [None]:
%%output size=250
# NOTE(review): identical to the previous cell; likely a leftover duplicate.
mini_image_viewer(frames_z)

In [None]:
%%opts Image (cmap='viridis')
def scrubber_callback(t, y):
    """Kymograph crosshair: a vertical line at column t, a "t=..." label and a horizontal line at y."""
    t = int(t)
    line_style = {"color": "white", "alpha": 0.3, "line_width": 2}
    label_x = t + kymo.shape[1] / 20
    label_y = kymo.shape[0] / 20
    v_line = hv.VLine(t + 0.5)(style=dict(line_style))
    t_label = hv.Text(label_x, label_y, "t={:}".format(t))(
        style={"color": "white", "text_alpha": 0.8}
    )
    h_line = hv.HLine(y)(style=dict(line_style))
    return v_line * t_label * h_line


def cross_section_callback(t, y):
    """Horizontal guide line at integer row y; ``t`` is accepted for the stream but unused."""
    row = int(y)
    return hv.HLine(row)(style={"color": "white", "alpha": 0.3, "line_width": 2})


def sharpness_scrubber_callback(t):
    """Black vertical marker on the sharpness plot, centred on frame int(t)."""
    marker_x = int(t) + 0.5
    return hv.VLine(marker_x)(style={"color": "black", "alpha": 0.3, "line_width": 2})


def trench_thumbnail_callback(img_series, t):
    """Thumbnail of the globally selected trench (``trench_idx``) in frame int(t)."""
    thumb_pixels = get_trench_thumbnail(img_series[int(t)], trench_points, trench_idx)
    height, width = thumb_pixels.shape[0], thumb_pixels.shape[1]
    # TODO: don't know why calling hv.Image.opts(plot={'invert_axes': True}) doesn't work
    # also switched HLine to VLine in cross_section, above
    thumb = hv.Image(
        thumb_pixels[::-1], bounds=(0, 0, width, height), kdims=["x2", "y"]
    )
    plot_opts = {"invert_axes": False, "invert_yaxis": True, "width": 200}
    return thumb.opts(plot=plot_opts)  # , 'xaxis': None, 'yaxis': None})


# Mouse position over the plots; x is renamed to `t` (timepoint/column).
pointerx = hv.streams.PointerX(x=0).rename(x="t")
pointery = hv.streams.PointerY(y=0)

# HoloViews BOUNDS order: (left, bottom, right, top)
trench_idx = 15  # which detected trench to inspect (hard-coded)
# kymograph is built from the Hessian eigenvalue images, not the raw frames
img_series = frame_series_k1
bbox_ul, bbox_lr = get_trench_bbox(
    trench_points, trench_idx, *get_img_limits(img_series[0])
)
kymo = extract_kymograph(
    img_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
kymo_img = hv.Image(kymo[::-1], bounds=(0, 0, kymo.shape[1], kymo.shape[0])).opts(
    plot={"width": 700, "height": 300, "invert_yaxis": True}
)
# kymo_img = hv.Raster(kymo).opts(plot={'width': 700, 'height': 300, 'yaxis': 'left'})
scrubber_line = hv.DynamicMap(scrubber_callback, streams=[pointerx, pointery])
cross_section_line = hv.DynamicMap(cross_section_callback, streams=[pointerx, pointery])
trench_thumbnail_img = hv.DynamicMap(
    partial(trench_thumbnail_callback, img_series), streams=[pointerx]
)  # .opts(plot={'invert_axes': True})
sharpness_plot = hv.Curve(sharpness, kdims=["x"], vdims=["s"]).opts(
    plot={"width": 700, "height": 100}
)
sharpness_scrubber = hv.DynamicMap(sharpness_scrubber_callback, streams=[pointerx])
# pointery.source = kymo_img
# SEE: http://holoviews.org/reference/elements/bokeh/Distribution.html
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)

# Toggle (and matching stream) for drawing the trench centre line on the thumbnail.
KymoOverlayStream = Stream.define("KymoOverlayStream", overlay_enabled=True)
kymo_overlay_stream = KymoOverlayStream()
kymo_overlay_button = widgets.ToggleButton(description="Overlay", value=True)


def update_kymo_overlay(change):
    """Forward the overlay toggle's new value to ``kymo_overlay_stream``."""
    enabled = change["new"]
    kymo_overlay_stream.event(overlay_enabled=enabled)


# Re-render the thumbnail overlay whenever the toggle changes.
kymo_overlay_button.observe(update_kymo_overlay, names="value")


def get_thumbnail_overlay(t, overlay_enabled):
    """Overlay of the selected trench's centre line, in thumbnail coordinates.

    The line is relative to ``bbox_ul``. It is returned only when
    ``overlay_enabled`` is set; otherwise the overlay is empty. ``t`` is
    accepted for the stream but unused.
    """
    # because of https://github.com/ioam/holoviews/issues/1388, overlay_enabled must be True initially
    line_start = trench_points[0][trench_idx] - bbox_ul
    line_end = trench_points[1][trench_idx] - bbox_ul
    trench_line = hv.Curve([line_start, line_end]).opts(
        style={"color": "white", "alpha": 0.5, "line_width": 1.5},
        plot={"yaxis": None, "shared_axes": True},
    )
    return hv.Overlay([trench_line] if overlay_enabled else [])


thumbnail_overlay = hv.DynamicMap(
    get_thumbnail_overlay, streams=[pointerx, kymo_overlay_stream]
)  # .opts(plot={'invert_yaxis': True})

display(kymo_overlay_button)
# Kymograph with crosshair; the trench thumbnail and the sharpness trace are
# adjoined to it with HoloViews `<<`.
(
    kymo_img * scrubber_line
    << (trench_thumbnail_img * thumbnail_overlay * cross_section_line)
    << (sharpness_plot * sharpness_scrubber)
)
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)
# trench_thumbnail_img * thumbnail_overlay
# trench_thumbnail_img * thumbnail_overlay * cross_section_line

In [None]:
def trench_peaks_callback(t, y):
    """Plot kymograph column ``t`` with markers at its detected cell endpoints.

    Reads the globals ``kymo`` and ``kymo_endpoints``. The latter is computed by
    ``find_kymograph_cell_endpoints`` in a later cell, so run that cell first.
    ``y`` is accepted for the stream but unused.
    """
    # PointerX can report positions outside the kymograph. Clamp to a valid
    # column; previously negative t silently wrapped and large t raised IndexError.
    last_column = min(kymo.shape[1], len(kymo_endpoints)) - 1
    t = min(max(int(t), 0), last_column)
    profile = kymo[:, t]
    overlays = [hv.Curve(profile).opts(plot={"width": 500})]
    overlays.extend(
        hv.Points([(x, profile[int(x)])]).opts(style={"size": 6})
        for x in kymo_endpoints[t]
    )
    return hv.Overlay(overlays)


# NOTE(review): needs `kymo_endpoints`, which is only defined by a later cell;
# this fails on Restart & Run All unless the cells are reordered.
hv.DynamicMap(trench_peaks_callback, streams=[pointerx, pointery])

In [None]:
%%time
def find_kymograph_cell_endpoints(kymograph, thresh=0.2, min_dist=3):
    """Integer peak indices (``peakutils.indexes``) of every kymograph column.

    Returns a list with one array per column (timepoint).
    """
    # interpolated alternative:
    # peakutils.interpolate(np.arange(kymograph.shape[0]), kymograph[:,t], ind=idxs)
    return [
        peakutils.indexes(kymograph[:, t], thres=thresh, min_dist=min_dist)
        for t in range(kymograph.shape[1])
    ]


# Per-timepoint endpoint candidates, used by trench_peaks_callback above.
kymo_endpoints = find_kymograph_cell_endpoints(kymo, thresh=0.2, min_dist=5)

In [None]:
# Overlay the first fifth of the kymograph columns.
# NOTE(review): 220 is an unexplained starting row — document or parameterise.
for t in range(kymo.shape[1] // 5):
    plt.plot(kymo[220:, t])

# Old

In [None]:
# %%opts Image (cmap='viridis')
# NOTE: this legacy cell overwrites the globals `img_series`, `kymo`, `kymo_img`
# and `trench_thumbnail_img` used by the interactive viewer cell above.
img_series = frame_series
kymo = extract_kymograph(
    img_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
kymo_img = hv.Image(kymo, bounds=(0, 0, kymo.shape[1], kymo.shape[0])).opts(
    plot={"width": 700, "height": 300}
)
# `trench_thumbnail` no longer exists (NameError); the callback is now
# `trench_thumbnail_callback`, with the same (img_series, t) signature.
trench_thumbnail_img = hv.DynamicMap(
    partial(trench_thumbnail_callback, img_series), streams=[pointerx]
)
(
    kymo_img.opts(plot={"invert_axes": True}) * scrubber_line
    + trench_thumbnail_img * cross_section_line
)

In [None]:
# handle out-of-bounds/negative t/x values
# focus quality score
# show focus quality as a bar above kymograph
# accurate horizontal line position on thumbnail
# draw cross-section line through thumbnail
# synchronize zoom/pan of multiple kymograph viewers
# 3-up thumbnail viewer with crosshairs/endpoint correspondences (links) synchronized with kymograph viewer
# clicking through a track advances 3-up by one frame
# if you press END it finishes track and rewinds to the first unfinished track
# automatic tracking: always identify bottom-most endpoint?
# can we work out all other correspondences from this? (look at t=14)