In [None]:
import functools
from collections import Counter
from itertools import zip_longest

import datashader as ds
import holoviews as hv
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import peakutils
import scipy.interpolate
import scipy.stats
import skimage
import skimage.morphology
import sklearn
import zarr
from bokeh.io import output_notebook, push_notebook, show
from bokeh.models import WheelZoomTool
from holoborodko_diff import holo_diff
from holoviews.operation import decimate
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    regrid,
    shade,
)
from holoviews.streams import Stream, param
from IPython.display import clear_output, display
from ipywidgets import fixed, interact, interact_manual, interactive
from matplotlib.colors import hex2color
from numcodecs import Blosc, Delta
from sklearn.cluster import DBSCAN

# from sklearn import metrics
# from sklearn.datasets.samples_generator import make_blobs
from sklearn.preprocessing import StandardScaler

# from bokeh.layouts import row
# from bokeh.plotting import figure
from tqdm import tnrange, tqdm_notebook

In [None]:
%load_ext line_profiler
# Activate the HoloViews notebook integration using the Bokeh backend.
hv.notebook_extension("bokeh")
# Keep a handle on the Bokeh renderer object.
renderer = hv.renderer("bokeh")
%matplotlib inline

In [None]:
# Open the image stack read-only. Other cells index it as
# frames_z[v, c, t, :, :] (field of view, channel, timepoint, then the two
# image axes).
# NOTE(review): absolute, user-specific path; consider a configurable data dir.
frames_z = zarr.open_array("/home/jqs1/scratch/fidelity/test/171018.zarr", mode="r")

In [None]:
# frames = nd2reader.ND2Reader('/home/jqs1/scratch/fidelity/171018/20171018_TrxnError_ID.nd2')

In [None]:
# Display color (hex RGB string) for each channel name that may appear in the
# array metadata.
channel_colors = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
    YFP="#f5eb00",
)

In [None]:
%%output size=250

# Acquisition metadata stored with the zarr array drives the viewer layout.
channels = frames_z.attrs["metadata"]["channels"]
n_channels = len(channels)
colors = [hex2color(channel_colors[channel]) for channel in channels]
num_timepoints = len(frames_z.attrs["metadata"]["frames"])
num_fovs = len(frames_z.attrs["metadata"]["fields_of_view"])

# One row of controls per channel: solo button, enable toggle, color picker.
channel_widgets = []
for channel in channels:
    solo_button = widgets.Button(description="S", layout=widgets.Layout(width="10%"))
    enabled_button = widgets.ToggleButton(description=channel, value=True)
    # Remember which toggle this solo button controls (read in update_solo).
    solo_button._button_to_enable = enabled_button
    color_picker = widgets.ColorPicker(concise=True, value=channel_colors[channel])
    channel_box = widgets.HBox([solo_button, enabled_button, color_picker])
    channel_widgets.append([solo_button, enabled_button, color_picker, channel_box])
solo_buttons, enabled_buttons, color_pickers, channel_boxes = zip(*channel_widgets)
channels_box = widgets.VBox(channel_boxes)

# IntSlider bounds are inclusive, so the largest selectable value must be the
# last valid index (count - 1); otherwise indexing frames_z goes out of range.
# The visible caption of a slider is its `description` trait.
t_slider = widgets.IntSlider(
    description="t",
    min=0,
    max=max(num_timepoints - 1, 0),
    step=1,
    value=0,
    continuous_update=False,
)
v_slider = widgets.IntSlider(
    description="v",
    min=0,
    max=max(num_fovs - 1, 0),
    step=1,
    value=0,
    continuous_update=False,
)
slider_box = widgets.VBox([v_slider, t_slider])
control_box = widgets.HBox([channels_box, slider_box])
output = widgets.Output()
main_box = widgets.VBox([control_box, output])
display(main_box)

max_val = 2**14

# HoloViews streams carrying the viewer state: the current (t, v) frame and
# the per-channel enabled mask.
Frame = Stream.define("Frame", t=0, v=0)
frame = Frame()
DisplaySettings = Stream.define(
    "DisplaySettings", channel_enabled=np.array([True] * n_channels)
)
display_settings = DisplaySettings()


def composite_image(t, v, channel_enabled):
    """Build a color composite of the enabled channels for one frame.

    Each enabled channel is read from ``frames_z[v, c, t]``, scaled by its
    99.9th percentile, clipped to [0, 1], tinted with the matching entry of
    ``colors`` and combined with the other channels by screen blending
    (``1 - (1 - a) * (1 - b)``).

    Parameters
    ----------
    t, v : int
        Timepoint and field-of-view indices into ``frames_z``.
    channel_enabled : sequence of bool
        One flag per channel; disabled channels are not read at all.

    Returns
    -------
    hv.RGB
        Composite with bounds (-1, -1, 1, 1). An all-white image is returned
        when no channel is enabled.

    Notes
    -----
    Tints come from the module-level ``colors`` list computed at setup time;
    the color-picker widgets are not consulted here.
    """
    colored_imgs = []
    for c in range(n_channels):
        if not channel_enabled[c]:
            continue  # avoid reading and scaling channels that are hidden
        channel_img = frames_z[v, c, t, :, :]
        # Guard against a zero percentile (blank channel), which would give
        # 0/0 = NaN pixels. With a tiny positive scale, positive pixels still
        # saturate to 1 and zero pixels stay 0.
        scale = max(np.percentile(channel_img, 99.9), np.finfo(float).eps)
        scaled_img = np.clip(channel_img[:, :, np.newaxis] / scale, 0, 1)
        colored_imgs.append(scaled_img * np.array(colors[c]))
    if not colored_imgs:
        # white placeholder with the same (rows, cols, 3) shape as a composite
        colored_imgs = [np.ones(tuple(frames_z.shape[-2:]) + (3,))]
    img = colored_imgs[0]
    for img2 in colored_imgs[1:]:
        img = 1 - (1 - img) * (1 - img2)
    return hv.RGB(img, bounds=(-1, -1, 1, 1))


# Push slider changes into the Frame stream; the DynamicMap built from that
# stream then re-renders with the new t / v.
t_slider.observe(lambda change: frame.event(t=change["new"]), names="value")
v_slider.observe(lambda change: frame.event(v=change["new"]), names="value")


def update_enabled_channels(change):
    """Publish the current state of every channel toggle to the display stream.

    ``change`` is the widget change notification; it is not inspected because
    the full mask is rebuilt from ``enabled_buttons`` each time.
    """
    toggle_states = np.array([toggle.value for toggle in enabled_buttons])
    display_settings.event(channel_enabled=toggle_states)


def update_solo(solo_button):
    """Click handler for a channel's solo button.

    If the clicked channel is already the only enabled one, every channel is
    re-enabled; otherwise only the clicked channel stays enabled.
    """
    target = solo_button._button_to_enable
    # True only when the target is on and it is the single enabled toggle.
    already_soloed = target.value and sum([b.value for b in enabled_buttons]) == 1
    for toggle in enabled_buttons:
        toggle.value = True if already_soloed else toggle == target
    # update_enabled_channels(None)


# Hook each solo button up to the solo handler.
for solo_button in solo_buttons:
    solo_button.on_click(update_solo)

# Any toggle change republishes the enabled-channel mask.
for enabled_button in enabled_buttons:
    enabled_button.observe(update_enabled_channels, names="value")
# for color_picker in color_pickers:
#    color_picker.observe(update_image, names='value')

# hv.DynamicMap(composite_image, kdims=['t', 'v', 'channel_enabled']).select(t=0,v=0,channel_enabled=np.array([True,False,False,False,False]))
# The composite is recomputed whenever either stream emits an event; the
# regrid(...) expression is the cell's displayed output.
image_viewer = hv.DynamicMap(composite_image, streams=[frame, display_settings])
regrid(image_viewer)

# Trench detection

In [None]:
def get_img_limits(img):
    """Return the (min, max) extents of an image along x and y.

    The rest of this module samples images as ``img[ys, xs]``, i.e. x is the
    column index (axis 1) and y is the row index (axis 0). ``img.shape`` is
    ``(rows, cols)``, so the x extent comes from ``shape[1]`` and the y extent
    from ``shape[0]``. For square images the two are equal.

    Parameters
    ----------
    img : array with at least two dimensions
        Only the first two axes are considered.

    Returns
    -------
    (x_lim, y_lim) : tuple of (int, int) pairs
        ``x_lim == (0, n_cols)`` and ``y_lim == (0, n_rows)``.
    """
    num_rows, num_cols = img.shape[:2]
    x_lim = (0, num_cols)
    y_lim = (0, num_rows)
    return x_lim, y_lim


def cluster_binary_image(bin_img, random_state=None):
    """Split the foreground pixels of a binary image into two spatial clusters.

    Parameters
    ----------
    bin_img : 2-D boolean (or truthy) array
        Pixels that evaluate true are clustered by their standardized
        coordinates.
    random_state : int, RandomState instance or None, optional
        Passed to ``MiniBatchKMeans``. Set it to make the clustering (and the
        arbitrary 0/1 label assignment) reproducible across runs. The default
        ``None`` keeps the previous unseeded behavior.

    Returns
    -------
    X : ndarray of shape (n_foreground, 2)
        Integer index of each foreground pixel, one row per pixel, in the
        axis order of ``bin_img``.
    fit : sklearn.cluster.MiniBatchKMeans
        The fitted estimator; ``fit.labels_[i]`` is the cluster of ``X[i]``.
    """
    X = np.argwhere(bin_img)
    # Standardize so both image axes contribute equally to the distance.
    X2 = StandardScaler().fit_transform(X.astype(np.float32))
    fit = sklearn.cluster.MiniBatchKMeans(
        init="k-means++",
        n_clusters=2,
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=random_state,
    )
    fit.fit(X2)
    return X, fit


def label_binary_image(bin_img):
    """Label the foreground of a binary image by its two-cluster assignment.

    Returns an ``int8`` image with the shape of ``bin_img``: 0 for
    background, and ``cluster label + 1`` (1 or 2) for foreground pixels.
    Which cluster is numbered 1 is decided by the clustering itself.
    """
    X, fit = cluster_binary_image(bin_img)
    label_img = np.zeros_like(bin_img, dtype=np.int8)  # TODO: fixing dtype
    # Single vectorized assignment instead of a Python loop over every pixel.
    label_img[X[:, 0], X[:, 1]] = fit.labels_ + 1
    return label_img


def drop_rare_labels(labels, min_fraction=0.01):
    """Return the distinct labels that make up more than ``min_fraction`` of ``labels``.

    Parameters
    ----------
    labels : iterable of hashable
        Label sequence to summarize.
    min_fraction : float, optional
        A label is kept only if its share of all entries is strictly greater
        than this value. Defaults to 0.01.

    Returns
    -------
    list
        Kept labels in order of first appearance; empty for empty input.
    """
    counter = Counter(labels)
    # Total number of entries: sum of the counts, not of the label keys.
    total = sum(counter.values())
    return [
        label for label, count in counter.items() if count / total > min_fraction
    ]


def detect_rotation(bin_img):
    """Estimate the dominant line orientation in a binary image via Hough transform.

    A first Hough transform over the default angle grid picks the angle whose
    accumulator column differences have the largest variance. A second
    transform over 200 angles between 0.9x and 1.1x that angle refines it.
    Peaks of the accumulator profile at the refined angle are then located.

    Parameters
    ----------
    bin_img : 2-D boolean array

    Returns
    -------
    angle : float
        ``pi / 2`` minus the refined Hough angle, in radians.
    peaks : ndarray
        Hough distances at which the accumulator profile peaks.

    Notes
    -----
    NOTE(review): the refinement window is relative to the coarse angle, so it
    collapses to a single value when the coarse angle is 0; confirm whether an
    absolute window was intended.
    """
    h, theta, d = skimage.transform.hough_line(bin_img)
    abs_diff_h = np.diff(h.astype(np.int32), axis=1).var(axis=0)
    angle1 = theta[abs_diff_h.argmax()]
    h2, theta2, d2 = skimage.transform.hough_line(
        bin_img, theta=np.linspace(0.9 * angle1, 1.1 * angle1, 200)
    )
    abs_diff_h2 = np.diff(h2.astype(np.int32), axis=1).var(axis=0)
    theta_idx2 = abs_diff_h2.argmax()
    angle2 = theta2[theta_idx2]
    d_profile = h2[:, theta_idx2].astype(np.int32)
    peak_idxs = peakutils.indexes(d_profile, thres=0.4, min_dist=5)
    peaks = d2[peak_idxs]
    # The unused FFT and modal-spacing computations were dropped: the latter
    # indexed `scipy.stats.mode(...).mode[0]`, which raises when fewer than
    # two peaks are found or when SciPy returns a scalar mode.
    return np.pi / 2 - angle2, peaks


def get_rough_spacing(dists):
    """Return the most common integer gap between consecutive entries of ``dists``.

    Gaps are truncated to integers before taking the mode.

    Parameters
    ----------
    dists : 1-D array-like of numbers

    Returns
    -------
    integer scalar
        The modal truncated gap.
    """
    gaps = np.diff(dists).astype(int)
    modal_gap = scipy.stats.mode(gaps).mode
    # Depending on the SciPy version, `.mode` is a length-1 array or a scalar;
    # `atleast_1d` makes the indexing valid for both.
    return np.atleast_1d(modal_gap)[0]


def point_linspace(anchor0, anchor1, num_points):
    """Yield the interior points of an even subdivision of a segment.

    The segment from ``anchor0`` to ``anchor1`` is divided using
    ``num_points`` evenly spaced fractions in [0, 1]; the two end points
    themselves are skipped, so ``num_points - 2`` points are produced.
    """
    interior_fractions = np.linspace(0, 1, num_points)[1:-1]
    yield from ((1 - s) * anchor0 + s * anchor1 for s in interior_fractions)


def coords_along(x0, x1):
    """Integer pixel coordinates strictly between two points.

    The number of samples is the truncated Euclidean distance between ``x0``
    and ``x1``. Samples are evenly spaced, truncated to integers, and the
    first and last sample (the end points) are discarded.

    Returns
    -------
    (xs, ys) : pair of integer arrays
        Coordinates taken from component 0 and component 1 respectively.
    """
    n_samples = int(np.sqrt(np.sum((x1 - x0) ** 2)))
    xs, ys = (
        np.linspace(x0[axis], x1[axis], n_samples).astype(np.int_)[1:-1]
        for axis in (0, 1)
    )
    return xs, ys


def edge_point(x0, theta, x_lim, y_lim):
    """Return a point on the boundary of the rectangle ``x_lim`` x ``y_lim``.

    NOTE(review): this looks like the point where a ray leaving ``x0`` at
    angle ``theta`` exits the rectangle; the sign conventions (for
    ``0 <= theta < pi/2`` the relevant corner is ``(x_min, y_max)``) should be
    confirmed against the callers.

    Parameters
    ----------
    x0 : sequence of two numbers
        Start point, read as ``x0[0]`` (x) and ``x0[1]`` (y).
    theta : float
        Angle in radians; reduced modulo ``2 * pi`` before use.
    x_lim, y_lim : (min, max) pairs
        Extents of the rectangle along x and y.

    Returns
    -------
    numpy.ndarray
        Two-element array ``[x, y]`` where either ``y`` equals one of the
        ``y_lim`` bounds ("top/bottom" case) or ``x`` equals one of the
        ``x_lim`` bounds ("left/right" case).
    """
    x_min, x_max = x_lim
    y_min, y_max = y_lim
    theta = theta % (2 * np.pi)
    # Select the rectangle corner associated with the quadrant of theta.
    if 0 <= theta < np.pi / 2:
        corner_x, corner_y = x_min, y_max
    elif np.pi / 2 <= theta < np.pi:
        corner_x, corner_y = x_max, y_max
    elif np.pi <= theta < 3 / 2 * np.pi:
        corner_x, corner_y = x_max, y_min
    elif 3 / 2 * np.pi <= theta <= 2 * np.pi:
        corner_x, corner_y = x_min, y_min
    # Angle from x0 to that corner, wrapped into [0, 2*pi) for comparison
    # with theta.
    angle_to_corner = np.arctan2(corner_y - x0[1], x0[0] - corner_x) % (2 * np.pi)
    # Comparing theta with the corner angle decides which pair of sides is
    # hit; the direction of the comparison alternates between quadrants.
    if (
        (theta >= angle_to_corner and 0 <= theta < np.pi / 2)
        or (theta < angle_to_corner and np.pi / 2 <= theta < np.pi)
        or (theta >= angle_to_corner and np.pi <= theta < 3 / 2 * np.pi)
        or (theta < angle_to_corner and 3 / 2 * np.pi <= theta < 2 * np.pi)
    ):
        # top/bottom
        # y is pinned to the corner's y; x follows from tan(theta).
        x1 = np.array([x0[0] - (corner_y - x0[1]) / np.tan(theta), corner_y])
    else:
        # left/right
        # x is pinned to the corner's x; y follows from tan(theta).
        x1 = np.array([corner_x, x0[1] - (corner_x - x0[0]) * np.tan(theta)])
    return x1

In [None]:
def line_array(
    anchors, theta, x_lim, y_lim, start=None, stop=None, bidirectional=False
):
    """Generate line segments leaving each anchor at angle ``theta``.

    Each segment runs from the anchor toward the rectangle edge found by
    ``edge_point``. ``start`` and ``stop`` are distances along that ray,
    capped at the ray's full length; a falsy ``start`` means "begin at the
    anchor" and a falsy ``stop`` means "run to the edge". Segments whose two
    end points coincide are skipped.

    Yields
    ------
    (y0, y1)
        Segment end points, when ``bidirectional`` is false.
    (y0, y1, y1_opposite)
        When ``bidirectional`` is true: the start point and far end of the
        segment at ``theta`` plus the far end of the matching segment at
        ``theta + pi``. The two directions are paired positionally.

    Raises
    ------
    ValueError
        If not ``stop >= start >= 0`` (after ``None`` is replaced by 0).
    """
    if bidirectional:
        forward, backward = (
            line_array(
                anchors,
                angle,
                x_lim,
                y_lim,
                start=start,
                stop=stop,
                bidirectional=False,
            )
            for angle in (theta, theta + np.pi)
        )
        for (origin, forward_end), (_, backward_end) in zip(forward, backward):
            yield origin, forward_end, backward_end
        return
    start = 0 if start is None else start
    stop = 0 if stop is None else stop
    if not stop >= start >= 0:
        raise ValueError("need stop >= start >= 0")
    theta = theta % (2 * np.pi)
    for x0 in anchors:
        x1 = edge_point(x0, theta, x_lim, y_lim)
        max_length = np.sqrt(((x1 - x0) ** 2).sum())
        # Default to the full ray, then trim each end if a distance was given.
        seg_start = min(start / max_length, 1) * (x1 - x0) + x0 if start else x0
        seg_stop = min(stop / max_length, 1) * (x1 - x0) + x0 if stop else x1
        if np.array_equal(seg_start, seg_stop):
            continue
        yield seg_start, seg_stop


def get_anchors(theta, x_lim, y_lim):
    """Return the two edge points reached from the rectangle center along ``theta``.

    The center of the rectangle ``x_lim`` x ``y_lim`` is projected to the
    boundary with ``edge_point`` at ``theta`` and at ``theta + pi``.

    Returns
    -------
    (anchor0, anchor1) : tuple of two arrays
    """
    lower = np.array([x_lim[0], y_lim[0]])
    upper = np.array([x_lim[1], y_lim[1]])
    center = lower + (upper - lower) / 2
    return tuple(
        edge_point(center, angle, x_lim, y_lim) for angle in (theta, theta + np.pi)
    )

In [None]:
def detect_trench_region(bin_img, theta, num_anchors=40, margin=3):
    """Find the anchor whose perpendicular cross-section varies the most.

    Candidate anchors are spread evenly along the line through the image
    center at angle ``theta``. For each one, ``bin_img`` is sampled along the
    full line through it at ``theta + pi/2``, and the anchor with the largest
    sample variance is returned.

    Parameters
    ----------
    bin_img : 2-D array
        Image to sample (indexed as ``bin_img[ys, xs]``).
    theta : float
        Angle in radians.
    num_anchors : int, optional
        Number of evenly spaced fractions along the center line, end points
        included; the two end points are never used as anchors. Defaults to 40.
    margin : int, optional
        Number of candidate anchors discarded at each end. Defaults to 3.

    Returns
    -------
    numpy.ndarray
        The selected anchor point.

    Raises
    ------
    ValueError
        If no cross-section line could be constructed.
    """
    x_lim, y_lim = get_img_limits(bin_img)
    anchor0, anchor1 = get_anchors(theta, x_lim, y_lim)
    anchors = list(point_linspace(anchor0, anchor1, num_anchors))
    # Explicit upper bound so that margin=0 keeps every anchor ([0:-0] is empty).
    anchors = anchors[margin : len(anchors) - margin]
    lines = list(
        line_array(anchors, np.pi / 2 + theta, x_lim, y_lim, bidirectional=True)
    )
    if not lines:
        raise ValueError("no cross-section lines could be constructed")
    cross_section_vars = []
    for _, x1, x2 in lines:
        xs, ys = coords_along(x1, x2)
        cross_section_vars.append(bin_img[ys, xs].var())
    idx = np.array(cross_section_vars).argmax()
    # Each line carries its own anchor as its first element. Returning that
    # keeps the result tied to the winning cross-section even if line_array
    # skipped a degenerate anchor.
    return lines[idx][0]

In [None]:
# FROM: https://stackoverflow.com/questions/23815327/numpy-one-liner-for-combining-unequal-length-np-array-to-a-matrixor-2d-array
def stack_jagged(arys, fill=0):
    """Stack sequences of unequal length into a 2-D array, one sequence per row.

    Shorter sequences are padded on the right with ``fill``.
    """
    padded_columns = [column for column in zip_longest(*arys, fillvalue=fill)]
    return np.transpose(np.array(padded_columns))


def detect_periodic_peaks(signal):
    """Estimate the period and phase offset of roughly evenly spaced peaks.

    Peak indices are found with ``peakutils.indexes`` and passed through
    ``peakutils.interpolate`` (presumably sub-sample refinement; confirm
    against the peakutils docs). Candidate periods are scored by the
    interquartile range of the peak positions modulo the period, divided by
    the period; the lowest score wins. The search runs on a coarse grid first
    and then on a grid within +/-2% of the coarse winner. Finally the offset
    is chosen to maximize the sum of ``signal`` sampled once per period.

    Side effect: opens three matplotlib figures (coarse score, refined score,
    offset objective), each with the selected point marked in red.

    Parameters
    ----------
    signal : 1-D numpy array
        Must support indexing with an integer index array.

    Returns
    -------
    (period2, offset) : tuple of float
        Refined period and the chosen offset, both in units of samples.
    """
    idxs = peakutils.indexes(signal, thres=0.1, min_dist=5)
    xs = peakutils.interpolate(np.arange(len(signal)), signal, ind=idxs)
    dxs = np.diff(xs)
    # Coarse search range: 10th percentile to maximum of the peak-to-peak gaps.
    period_min = np.percentile(dxs, 10)
    period_max = dxs.max()
    num_periods = 100
    periods = np.linspace(period_min, period_max, num_periods)
    # std = ((dxs + periods[:,np.newaxis]/2) % periods[:,np.newaxis]).std(axis=1)
    # One row per candidate period: spread of peak phases, normalized by period.
    std = scipy.stats.iqr(((xs) % periods[:, np.newaxis]), axis=1) / periods
    period_idx = std.argmin()
    period = periods[period_idx]
    plt.figure()
    plt.plot(std)
    plt.scatter([period_idx], [std[period_idx]], c="r")
    # Refined search within +/-2% of the coarse winner, same score.
    periods2 = np.linspace(period * 0.98, period * 1.02, num_periods)
    std2 = scipy.stats.iqr(((xs) % periods2[:, np.newaxis]), axis=1) / periods2
    period_idx2 = std2.argmin()
    period2 = periods2[period_idx2]
    plt.figure()
    plt.plot(std2)
    plt.scatter([period_idx2], [std2[period_idx2]], c="r")
    # Candidate offsets span one full period; each row of offset_idxs is a
    # comb of sample positions spaced by period2 and shifted by one offset.
    offsets = np.linspace(0, period2, num_periods)
    offset_idxs = (
        np.arange(0, len(signal) - period2, period2) + offsets[:, np.newaxis]
    ).astype(np.int_)
    objective = signal[offset_idxs].sum(axis=1)
    offset_idx = objective.argmax()
    offset = offsets[offset_idx]
    plt.figure()
    plt.plot(objective)
    plt.scatter([offset_idx], [objective[offset_idx]], c="r")
    return period2, offset


def detect_trench_anchors(img, t0, theta):
    """Locate periodically spaced points along the line through ``t0``.

    The image is sampled along the line through ``t0`` at angles
    ``theta - pi/2`` and ``theta + pi/2`` (edge to edge). The period and
    offset of the resulting intensity profile are estimated with
    ``detect_periodic_peaks``, and one point per period is returned.

    Side effect: plots the profile with the chosen samples marked in red (in
    addition to the figures produced by ``detect_periodic_peaks``).

    Returns
    -------
    ndarray of shape (n_points, 2)
        ``[x, y]`` pixel coordinates, one row per detected point.
    """
    x_lim, y_lim = get_img_limits(img)
    quarter_turn = np.pi / 2
    line_start = edge_point(t0, theta - quarter_turn, x_lim, y_lim)
    line_stop = edge_point(t0, theta + quarter_turn, x_lim, y_lim)
    xs, ys = coords_along(line_start, line_stop)
    profile = img[ys, xs]
    period, offset = detect_periodic_peaks(profile)
    # One sample index per period, starting at the detected offset.
    picked = np.arange(offset, len(profile), period).astype(np.int_)
    plt.figure(figsize=(16, 8))
    plt.plot(profile)
    plt.scatter(picked, profile[picked], c="r")
    return np.vstack((xs[picked], ys[picked])).T


def _detect_trench_end(img, anchors, theta):
    """Find a common end position along the rays leaving each anchor at ``theta``.

    The image is sampled from every anchor to the image edge. The per-ray
    profiles are zero-padded to a common length and reduced to their 80th
    percentile at each position. The end index is where the smoothed
    derivative (``holo_diff``) of that summary profile is most negative.

    Side effect: plots all ray profiles in one figure, and the summary profile
    with its derivative and the chosen end index in another.

    Returns
    -------
    ndarray of shape (n_anchors, 2)
        ``[x, y]`` of the end point on each ray; the last sampled point is
        used for rays shorter than the end index.
    """
    x_lim, y_lim = get_img_limits(img)
    ray_coords = [
        coords_along(anchor, edge_point(anchor, theta, x_lim, y_lim))
        for anchor in anchors
    ]
    trench_profiles = [img[ys, xs] for xs, ys in ray_coords]
    plt.figure(figsize=(8, 8))
    for trench_profile in trench_profiles:
        plt.plot(trench_profile)
    stacked_profile = np.percentile(stack_jagged(trench_profiles), 80, axis=0)
    # cum_profile = np.cumsum(stacked_profile)
    # cum_profile /= cum_profile[-1]
    # end = np.where(cum_profile > 0.8)[0][0]
    stacked_profile_diff = holo_diff(1, stacked_profile)
    end = stacked_profile_diff.argmin()
    plt.figure(figsize=(8, 8))
    plt.plot(stacked_profile)
    plt.gca().twinx().plot(stacked_profile_diff, color="g")
    plt.axvline(end, c="r")
    end_points = []
    for xs, ys in ray_coords:
        # Fall back to the last sample when this ray is too short.
        idx = end if len(xs) > end else -1
        end_points.append((xs[idx], ys[idx]))
    return np.array(end_points)


def detect_trench_ends(img, bin_img, anchors, theta):
    """Detect both ends of the trenches passing through ``anchors``.

    Pixels outside the dilated ``bin_img`` mask are replaced with the image's
    5th percentile, then ``_detect_trench_end`` is run along ``theta`` and
    along ``theta + pi``.

    Side effect: shows the masked image with the anchors (white), the
    ``theta`` ends (green) and the ``theta + pi`` ends (red) overlaid.

    Returns
    -------
    (top_points, bottom_points) : pair of (n_anchors, 2) arrays
    """
    background = np.percentile(img, 5)
    dilated_mask = skimage.morphology.binary_dilation(bin_img)
    img_masked = np.where(dilated_mask, img, background)
    top_points, bottom_points = (
        _detect_trench_end(img_masked, anchors, angle)
        for angle in (theta, theta + np.pi)
    )
    plt.figure(figsize=(12, 12))
    plt.imshow(img_masked)
    overlays = ((anchors, "w"), (top_points, "g"), (bottom_points, "r"))
    for points, color in overlays:
        plt.scatter(*points.T, s=3, c=color)
    return top_points, bottom_points


def detect_trenches(img, bin_img, theta):
    """Run the full trench-detection pipeline.

    Steps: pick the trench row (``detect_trench_region``), find one anchor per
    trench along it (``detect_trench_anchors``), then find both ends of each
    trench (``detect_trench_ends``).

    Returns
    -------
    (top_points, bottom_points)
        As returned by ``detect_trench_ends``.
    """
    region_anchor = detect_trench_region(bin_img, theta)
    per_trench_anchors = detect_trench_anchors(img, region_anchor, theta)
    return detect_trench_ends(img, bin_img, per_trench_anchors, theta)

# Segmentation

In [None]:
def hessian_eigenvalues(img):
    """Per-pixel eigenvalues of a Sobel-based Hessian of the smoothed image.

    The image is Gaussian-smoothed (sigma 1.5); first and second derivatives
    are approximated by applying the horizontal and vertical Sobel filters
    once and then again. The eigenvalues of the 2x2 matrix of second
    derivatives are returned, with any NaN (negative discriminant) set to 0.

    Returns
    -------
    (k1, k2) : pair of arrays
        ``k1`` is the larger eigenvalue, ``k2`` the smaller.
    """
    smoothed = skimage.filters.gaussian(img, 1.5)
    d_h = skimage.filters.sobel_h(smoothed)
    d_v = skimage.filters.sobel_v(smoothed)
    d_hh = skimage.filters.sobel_h(d_h)
    d_hv = skimage.filters.sobel_v(d_h)
    d_vh = skimage.filters.sobel_h(d_v)
    d_vv = skimage.filters.sobel_v(d_v)
    trace = d_hh + d_vv
    determinant = d_hh * d_vv - d_hv * d_vh
    # Closed-form eigenvalues of a 2x2 matrix: trace/2 +/- sqrt(trace^2 - 4 det)/2.
    midpoint = trace / 2
    half_gap = (np.sqrt(trace**2 - 4 * determinant)) / 2
    eigenvalues = (midpoint + half_gap, midpoint - half_gap)
    for eigenvalue in eigenvalues:
        eigenvalue[np.isnan(eigenvalue)] = 0
    return eigenvalues

In [None]:
def get_image_series_trenches(img_series):
    """Detect trench end points from a time series of images.

    The series is max-projected over its first axis, each row has its 3rd
    percentile subtracted, and the result is Otsu-thresholded. The foreground
    is split into two clusters, and the cluster labeled 1 is used as the
    trench mask for rotation and trench detection.

    Side effect: shows the label image (plus the figures of the detection
    helpers).

    NOTE(review): which cluster receives label 1 is decided by the clustering
    and is not guaranteed to be stable between runs.

    Returns
    -------
    (top_points, bottom_points)
        As returned by ``detect_trenches``.
    """
    projection = img_series.max(axis=0)
    # TODO: need rotation-invariant detrending
    row_baseline = np.percentile(projection, 3, axis=1)[:, np.newaxis]
    img = projection - row_baseline
    otsu_level = skimage.filters.threshold_otsu(img)
    img_labels = label_binary_image(img > otsu_level)
    plt.figure(figsize=(8, 8))
    plt.imshow(img_labels)
    theta, _ = detect_rotation(img_labels == 1)
    return detect_trenches(img, img_labels == 1, theta)

In [None]:
# Array dimensions; the viewer indexes this array as frames_z[v, c, t, :, :].
frames_z.shape

In [None]:
%%time
# Load every timepoint at index 0 of the first two axes (v=0, c=0 in the
# viewer's indexing) into memory.
frame_series = frames_z[0, 0, :, :, :]

In [None]:
%%time
# Result is a (top_points, bottom_points) pair, one row per detected trench.
trench_points = get_image_series_trenches(frame_series)

In [None]:
def extract_kymograph(img_series, x0, x1):
    """Sample every frame along the segment between ``x0`` and ``x1``.

    Returns
    -------
    ndarray of shape (n_samples, n_timepoints), float
        Column ``t`` holds the pixel values of frame ``t`` at the integer
        coordinates produced by ``coords_along(x0, x1)``.
    """
    xs, ys = coords_along(x0, x1)
    n_frames = img_series.shape[0]
    kymograph = np.zeros((len(xs), n_frames))
    # Each row of the transposed view is one time column of the kymograph.
    for t, time_column in enumerate(kymograph.T):
        time_column[...] = img_series[t, ys, xs]
    return kymograph


def get_image_series_segmentation(img_series, x0, x1):
    """Placeholder, not implemented yet; currently returns ``None``.

    Per the inline note, the intended result is a list of trenches, each
    holding a list of cell masks.
    """
    pass  # list of trenches, for each trench, a list of cell masks


def map_over_segmentation(img_series, cell_seg, func):
    """Placeholder, not implemented yet; currently returns ``None``.

    NOTE(review): presumably meant to apply ``func`` to each segmented cell
    in ``cell_seg``; confirm the intended contract before implementing.
    """
    pass


# Kymograph of a single detected trench, sampled between its two end points
# (trench_points is a (top_points, bottom_points) pair).
# Uses `frame_series`, loaded in an earlier cell, so this cell runs on a
# fresh kernel without relying on a variable assigned further down.
trench_idx = 12
kymo = extract_kymograph(
    frame_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
plt.figure()
plt.imshow(kymo)

In [None]:
# NOTE(review): `f1` is not assigned in any cell visible here; this cell
# fails on a fresh kernel unless it is defined elsewhere. Confirm its source.
f1_b = skimage.filters.gaussian(f1, 3)
# hessian_eigenvalues applies its own Gaussian (sigma 1.5) on top of the
# sigma-3 blur above.
f1_k1, f1_k2 = hessian_eigenvalues(f1_b)

In [None]:
# img = skimage.transform.rotate(f1, 15, cval=0)

In [None]:
# Show the larger Hessian eigenvalue map at high resolution.
plt.figure(figsize=(20, 20))
plt.imshow(f1_k1)

In [None]:
# Re-display the kymograph computed in an earlier cell.
plt.figure()
plt.imshow(kymo)

In [None]:
# Same slice as `frame_series` (index 0 on the first two axes, all timepoints).
f_series = frames_z[0, 0]

In [None]:
# Show the first timepoint of the series.
plt.figure(figsize=(8, 8))
plt.imshow(f_series[0])