In [None]:
import functools
from collections import Counter

import datashader as ds
import holoviews as hv
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import peakutils
import scipy.interpolate
import scipy.stats
import skimage
import skimage.morphology
import sklearn
import zarr
from bokeh.io import output_notebook, push_notebook, show
from bokeh.models import WheelZoomTool
from holoviews.operation import decimate
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    regrid,
    shade,
)
from holoviews.streams import Stream, param
from IPython.display import clear_output, display
from ipywidgets import fixed, interact, interact_manual, interactive
from matplotlib.colors import hex2color
from numcodecs import Blosc, Delta
from sklearn.cluster import DBSCAN

# from sklearn import metrics
# from sklearn.datasets.samples_generator import make_blobs
from sklearn.preprocessing import StandardScaler

# from bokeh.layouts import row
# from bokeh.plotting import figure
from tqdm import tnrange, tqdm_notebook

In [None]:
# Notebook setup: enable %lprun (line_profiler), use Bokeh as the HoloViews backend,
# and render matplotlib figures inline.
%load_ext line_profiler
hv.notebook_extension("bokeh")
# Bokeh renderer handle (not referenced elsewhere in the cells shown here).
renderer = hv.renderer("bokeh")
%matplotlib inline

In [None]:
# Zarr copy of the acquisition, opened read-only; indexed below as frames_z[v, c, t, y, x].
# NOTE(review): hard-coded absolute path; consider a configurable data directory.
frames_z = zarr.open_array("/home/jqs1/scratch/fidelity/test/171018.zarr", mode="r")

In [None]:
# Reader for the raw ND2 file. Only referenced by commented-out code in composite_image;
# pixel data is read from `frames_z` instead.
frames = nd2reader.ND2Reader(
    "/home/jqs1/scratch/fidelity/171018/20171018_TrxnError_ID.nd2"
)

In [None]:
# Array shape; the cells below index it as (v, c, t, y, x).
frames_z.shape

In [None]:
# Display color (hex string) for each acquisition channel name.
channel_colors = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
    YFP="#f5eb00",
)

In [None]:
%%output size=250

# Metadata from the zarr attributes drives the viewer controls.
channels = frames_z.attrs["metadata"]["channels"]
n_channels = len(channels)
colors = [hex2color(channel_colors[channel]) for channel in channels]
num_timepoints = len(frames_z.attrs["metadata"]["frames"])
num_fovs = len(frames_z.attrs["metadata"]["fields_of_view"])

# One row per channel: a solo ("S") button, an on/off toggle and a color picker.
channel_widgets = []
for channel in channels:
    solo_button = widgets.Button(description="S", layout=widgets.Layout(width="10%"))
    enabled_button = widgets.ToggleButton(description=channel, value=True)
    # Remember which toggle this solo button controls (read by update_solo).
    solo_button._button_to_enable = enabled_button
    color_picker = widgets.ColorPicker(concise=True, value=channel_colors[channel])
    channel_box = widgets.HBox([solo_button, enabled_button, color_picker])
    channel_widgets.append([solo_button, enabled_button, color_picker, channel_box])
solo_buttons, enabled_buttons, color_pickers, channel_boxes = zip(*channel_widgets)
channels_box = widgets.VBox(channel_boxes)
# IntSlider's max is inclusive and the values index frames_z directly, so the last
# valid position is count - 1. `description` (not `label`) is the widget's caption trait.
t_slider = widgets.IntSlider(
    description="t",
    min=0,
    max=max(num_timepoints - 1, 0),
    step=1,
    value=0,
    continuous_update=False,
)
v_slider = widgets.IntSlider(
    description="v",
    min=0,
    max=max(num_fovs - 1, 0),
    step=1,
    value=0,
    continuous_update=False,
)
slider_box = widgets.VBox([v_slider, t_slider])
control_box = widgets.HBox([channels_box, slider_box])
output = widgets.Output()
main_box = widgets.VBox([control_box, output])
display(main_box)

# Not referenced in the cells shown; presumably the 14-bit camera range — confirm.
max_val = 1 << 14

# HoloViews streams: current frame position (t, v) and per-channel on/off mask.
Frame = Stream.define("Frame", t=0, v=0)
frame = Frame()
DisplaySettings = Stream.define(
    "DisplaySettings", channel_enabled=np.array(n_channels * [True])
)
display_settings = DisplaySettings()


def composite_image(t, v, channel_enabled):
    """Render the enabled channels of frame (v, t) as one screen-blended RGB image.

    Each enabled channel is divided by its 99.9th percentile, clipped to [0, 1],
    tinted with its entry in ``colors`` and combined with the screen blend
    ``1 - (1 - a) * (1 - b)``. Disabled channels are not read from ``frames_z``.
    If no channel is enabled, a white placeholder is returned.

    Parameters
    ----------
    t, v : int
        Time and position indices, used as ``frames_z[v, c, t, :, :]``.
    channel_enabled : sequence of bool
        One flag per channel (length ``n_channels``).

    Returns
    -------
    hv.RGB
        Composite image with bounds (-1, -1, 1, 1).
    """
    imgs_to_combine = []
    for c in range(n_channels):
        if not channel_enabled[c]:
            continue  # skip the zarr read for hidden channels
        channel_img = frames_z[v, c, t, :, :]
        scale = np.percentile(channel_img, 99.9)
        if scale <= 0:
            # Nearly empty channel: avoid dividing by zero (which produced NaN/inf).
            scale = channel_img.max()
        if scale <= 0:
            scale = 1  # all-zero channel renders black
        scaled_img = channel_img[:, :, np.newaxis] / scale
        np.clip(scaled_img, 0, 1, out=scaled_img)  # clip in place
        imgs_to_combine.append(scaled_img * np.array(colors[c]))
    if not imgs_to_combine:
        # White placeholder shaped like one colored frame: (height, width, 3).
        imgs_to_combine = [np.ones(tuple(frames_z.shape[-2:]) + (3,))]
    img = imgs_to_combine[0]
    for img2 in imgs_to_combine[1:]:
        img = 1 - (1 - img) * (1 - img2)
    return hv.RGB(img, bounds=(-1, -1, 1, 1))


# Forward each slider's new value into the Frame stream under its own key ("t" or "v").
for _slider, _key in ((t_slider, "t"), (v_slider, "v")):
    _slider.observe(
        lambda change, _key=_key: frame.event(**{_key: change["new"]}), names="value"
    )


def update_enabled_channels(change):
    """Publish the current on/off state of every channel toggle to ``display_settings``.

    The traitlets ``change`` payload is ignored; the state is re-read from all toggles.
    """
    states = [toggle.value for toggle in enabled_buttons]
    display_settings.event(channel_enabled=np.array(states))


def update_solo(solo_button):
    """Solo the channel tied to ``solo_button``, or un-solo it if it is already soloed.

    If that channel is the only enabled one, every channel is re-enabled; otherwise
    only that channel is left enabled. Each toggle change fires its own observer.
    """
    target = solo_button._button_to_enable
    already_soloed = target.value and sum(b.value for b in enabled_buttons) == 1
    for enabled_button in enabled_buttons:
        enabled_button.value = True if already_soloed else (enabled_button == target)


# Wire each channel row: the solo button re-toggles channels, and any toggle change
# pushes the new channel mask to the display stream.
for solo_button, enabled_button in zip(solo_buttons, enabled_buttons):
    solo_button.on_click(update_solo)
    enabled_button.observe(update_enabled_channels, names="value")

# Color pickers are not wired up: changing a color does not affect the image.

# Re-render on Frame / DisplaySettings events; regrid is HoloViews' datashader
# resampling operation.
image_viewer = hv.DynamicMap(composite_image, streams=[frame, display_settings])
regrid(image_viewer)

In [None]:
# Crop of position v=0 (all channels and timepoints) to rows 300:500 and columns :1000
# of the last two axes, for quick experiments. Resulting axes: (c, t, y, x).
a_stack = frames_z[0, :, :, 300:500, :1000]
# Channel 0, timepoint 0 of the crop.
a = a_stack[0, 0]

In [None]:
%%time
# Maximum projection over time of position v=0, channel 0.
f0 = frames_z[0, 0, :].max(axis=0)

In [None]:
%%time
# Maximum projection over time of position v=0, channel 1.
f01 = frames_z[0, 1, :].max(axis=0)

In [None]:
# Crude per-row background removal: subtract each row's 0.5th-percentile value.
f1 = f0 - np.percentile(f0, 0.5, axis=1, keepdims=True)

In [None]:
# Gaussian-smooth (sigma=3) the background-subtracted image; f1_b feeds the Hessian cell.
f1_b = skimage.filters.gaussian(f1, 3)
# First and second Sobel derivatives along axis 1 (sobel_v); not referenced later here.
f1_v = skimage.filters.sobel_v(f1_b)
f1_v2 = skimage.filters.sobel_v(f1_v)

In [None]:
def hessian_eigenvalues(img, sigma=1.5):
    """Per-pixel eigenvalues of a Sobel-approximated Hessian of a smoothed image.

    The image is Gaussian-smoothed with ``sigma``. Second derivatives are then
    estimated by applying Sobel filters twice. For the 2x2 matrix
    [[I_xx, I_xy], [I_yx, I_yy]] the eigenvalues are

        (I_xx + I_yy) / 2  +/-  sqrt((I_xx - I_yy)**2 + 4 * I_xy * I_yx) / 2

    The discriminant is computed in this form rather than as trace**2 - 4*det,
    which avoids cancellation. It is clipped at 0. Where it is negative, both
    outputs equal the real part (I_xx + I_yy) / 2 instead of becoming NaN and
    then 0.

    Parameters
    ----------
    img : 2-D array
    sigma : float, optional
        Gaussian smoothing width in pixels (default 1.5).

    Returns
    -------
    k1, k2 : arrays shaped like ``img``
        Larger and smaller eigenvalue. Remaining NaNs (e.g. from NaN input)
        are replaced by 0.
    """
    I = skimage.filters.gaussian(img, sigma)
    # sobel_h differentiates along axis 0, sobel_v along axis 1.
    I_x = skimage.filters.sobel_h(I)
    I_y = skimage.filters.sobel_v(I)
    I_xx = skimage.filters.sobel_h(I_x)
    I_xy = skimage.filters.sobel_v(I_x)
    I_yx = skimage.filters.sobel_h(I_y)
    I_yy = skimage.filters.sobel_v(I_y)
    half_trace = (I_xx + I_yy) / 2
    # np.maximum propagates NaN, so NaN input still ends up zeroed below.
    discriminant = np.maximum((I_xx - I_yy) ** 2 + 4 * I_xy * I_yx, 0)
    half_gap = np.sqrt(discriminant) / 2
    k1 = half_trace + half_gap
    k2 = half_trace - half_gap
    k1[np.isnan(k1)] = 0
    k2[np.isnan(k2)] = 0
    return k1, k2

In [None]:
# Hessian eigenvalues of the smoothed image (hessian_eigenvalues smooths once more internally).
f1_k1, f1_k2 = hessian_eigenvalues(f1_b)

In [None]:
# Background-subtracted max projection.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(f1)

In [None]:
# Foreground mask: pixels brighter than the global Otsu threshold.
otsu_level = skimage.filters.threshold_otsu(f1)
f2 = f1 > otsu_level
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(f2)

In [None]:
# (row, col) coordinates of every foreground pixel, standardized per axis for k-means.
X = np.argwhere(f2)
X2 = StandardScaler().fit_transform(X.astype(np.float32))
# Two spatial clusters of foreground pixels. Seeded so re-running the notebook
# reproduces the same labels (MiniBatchKMeans is stochastic).
fit = sklearn.cluster.MiniBatchKMeans(
    init="k-means++",
    n_clusters=2,
    n_init=10,
    max_no_improvement=10,
    verbose=0,
    random_state=0,
)

In [None]:
%%time
# Fit the mini-batch k-means model on the standardized foreground coordinates.
fit.fit(X2)

In [None]:
# Paint cluster labels back onto the image grid: 0 = background, cluster + 1 otherwise.
# One vectorized fancy-index assignment replaces the per-pixel Python loop.
label_img = np.zeros_like(f1)
label_img[X[:, 0], X[:, 1]] = fit.labels_ + 1

In [None]:
# Keep clusters that hold more than 1% of the foreground pixels, printing each share.
counter = Counter(fit.labels_)
total = sum(counter.values())
good_labels = []
for label, count in counter.items():
    fraction = count / total
    print(fraction)
    if fraction > 0.01:
        good_labels.append(label)

In [None]:
# Cluster labels covering more than 1% of foreground pixels.
good_labels

# Trench detection

In [None]:
# Rotate the label image 15 degrees counter-clockwise (outside pixels filled with 0).
label_img_rot = skimage.transform.rotate(label_img, 15, cval=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(label_img_rot)

In [None]:
def detect_trenches2(thresholded_img, thres=0.4):
    """Detect a family of parallel lines with a two-stage Hough transform and plot them.

    Stage 1 runs ``skimage.transform.hough_line`` on its default angle grid. It picks
    the angle whose adjacent-angle accumulator difference (diff along axis 1) has
    the largest variance over ``d``. Stage 2 re-runs the transform on 200 angles
    between 0.9x and 1.1x that angle and repeats the selection. Peaks of the refined
    accumulator column (``peakutils.indexes`` with relative threshold ``thres``)
    are drawn as red lines over the image.

    NOTE(review): the stage-2 window collapses to one angle if the stage-1 angle is 0.

    Parameters
    ----------
    thresholded_img : 2-D array
        Binary (thresholded) image.
    thres : float, optional
        Relative peak threshold for ``peakutils.indexes`` (default 0.4).

    Returns
    -------
    angle : float
        Refined Hough angle in radians.
    peaks : ndarray
        Hough distances of the detected lines.
    """
    h, theta, d = skimage.transform.hough_line(thresholded_img)
    coarse_score = np.diff(h.astype(np.int32), axis=1).var(axis=0)
    angle1 = theta[coarse_score.argmax()]
    h2, theta2, d2 = skimage.transform.hough_line(
        thresholded_img, theta=np.linspace(0.9 * angle1, 1.1 * angle1, 200)
    )
    fine_score = np.diff(h2.astype(np.int32), axis=1).var(axis=0)
    theta_idx2 = fine_score.argmax()
    angle = theta2[theta_idx2]
    d_profile = h2[:, theta_idx2].astype(np.int32)
    peaks = d2[peakutils.indexes(d_profile, thres=thres)]
    # The old unused spacing estimate (scipy.stats.mode(...).mode[0]) was removed:
    # it raised on SciPy >= 1.11 and whenever fewer than two peaks were found.

    plt.figure(figsize=(20, 12))
    plt.imshow(thresholded_img)
    x_max = thresholded_img.shape[1]  # horizontal extent (columns)
    for dist in peaks:
        # Hough line x*cos(angle) + y*sin(angle) = dist, evaluated at x = 0 and x = x_max.
        y_left = dist / np.sin(angle)
        y_right = (dist - x_max * np.cos(angle)) / np.sin(angle)
        plt.plot((0, x_max), (y_left, y_right), "-r")
    plt.xlim((0, thresholded_img.shape[1]))
    plt.ylim((thresholded_img.shape[0], 0))
    return angle, peaks


# detect_trenches2(img_labels == 1)

In [None]:
def cluster_binary_image(bin_img, n_clusters=2, random_state=None):
    """Cluster the foreground pixel coordinates of a binary image with mini-batch k-means.

    Parameters
    ----------
    bin_img : 2-D array
        Nonzero pixels are clustered.
    n_clusters : int, optional
        Number of clusters (default 2).
    random_state : int or None, optional
        Seed forwarded to ``MiniBatchKMeans``; pass an int for reproducible labels.

    Returns
    -------
    X : (N, 2) int ndarray
        Row/column coordinates of the foreground pixels.
    fit : sklearn.cluster.MiniBatchKMeans
        Fitted model; ``fit.labels_[i]`` is the cluster of ``X[i]``.
    """
    X = np.argwhere(bin_img)
    # Standardize each axis so rows and columns contribute comparably to distances.
    X2 = StandardScaler().fit_transform(X.astype(np.float32))
    fit = sklearn.cluster.MiniBatchKMeans(
        init="k-means++",
        n_clusters=n_clusters,
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=random_state,
    )
    fit.fit(X2)
    return X, fit


def label_binary_image(bin_img):
    """Label each foreground pixel of ``bin_img`` with its k-means cluster.

    Uses ``cluster_binary_image``. Returns an integer array shaped like ``bin_img``:
    0 for background and ``cluster + 1`` (1, 2, ...) for foreground pixels.
    """
    X, fit = cluster_binary_image(bin_img)
    # Integer canvas: np.zeros_like(bool_mask) would be boolean, and every label
    # would collapse to True (== 1), discarding the cluster assignment.
    label_img = np.zeros(np.shape(bin_img), dtype=np.int_)
    label_img[X[:, 0], X[:, 1]] = fit.labels_ + 1
    return label_img


def drop_rare_labels(labels, min_fraction=0.01):
    """Return the label values that make up more than ``min_fraction`` of ``labels``.

    Parameters
    ----------
    labels : array-like
        Label values of any shape (flattened), e.g. ``fit.labels_``.
    min_fraction : float, optional
        Labels whose share of all entries is <= this value are dropped (default 0.01).

    Returns
    -------
    list
        Surviving label values in ascending order; empty when ``labels`` is empty.
    """
    counts = Counter(np.ravel(labels).tolist())
    total = sum(counts.values())
    if total == 0:
        return []
    return sorted(label for label, count in counts.items() if count / total > min_fraction)


def detect_trenches(bin_img, thres=0.4, min_dist=5):
    """Estimate the dominant Hough angle and the distances of parallel lines in ``bin_img``.

    Stage 1 runs ``skimage.transform.hough_line`` on its default angle grid. It picks
    the angle whose adjacent-angle accumulator difference (diff along axis 1) has
    the largest variance over ``d``. Stage 2 re-runs the transform on 200 angles
    between 0.9x and 1.1x that angle and repeats the selection. Line distances are
    the peaks of the refined accumulator column.

    NOTE(review): the stage-2 window collapses to one angle if the stage-1 angle is 0.

    Parameters
    ----------
    bin_img : 2-D array
        Binary image of the line-like structures.
    thres : float, optional
        Relative peak threshold for ``peakutils.indexes`` (default 0.4).
    min_dist : int, optional
        Minimum index separation between peaks (default 5).

    Returns
    -------
    angle : float
        Refined Hough angle in radians.
    peaks : ndarray
        Hough distances ``d`` of the detected lines.
    """
    h, theta, d = skimage.transform.hough_line(bin_img)
    coarse_score = np.diff(h.astype(np.int32), axis=1).var(axis=0)
    angle1 = theta[coarse_score.argmax()]
    h2, theta2, d2 = skimage.transform.hough_line(
        bin_img, theta=np.linspace(0.9 * angle1, 1.1 * angle1, 200)
    )
    fine_score = np.diff(h2.astype(np.int32), axis=1).var(axis=0)
    theta_idx2 = fine_score.argmax()
    angle2 = theta2[theta_idx2]
    d_profile = h2[:, theta_idx2].astype(np.int32)
    peak_idxs = peakutils.indexes(d_profile, thres=thres, min_dist=min_dist)
    # The unused spacing estimate (scipy.stats.mode(...).mode[0]) was dropped: it raised
    # on SciPy >= 1.11 and on < 2 peaks. Use get_rough_spacing(peaks) when needed.
    return angle2, d2[peak_idxs]


def get_rough_spacing(dists):
    """Most common integer gap between consecutive ``dists`` (a rough line spacing).

    Gaps are truncated to ``int``. Ties resolve to the smallest gap, as with
    ``scipy.stats.mode``. ``np.unique`` is used because ``scipy.stats.mode(...).mode``
    is a scalar on SciPy >= 1.11, so indexing it with ``[0]`` fails.

    Raises
    ------
    ValueError
        If fewer than two distances are given.
    """
    gaps = np.diff(np.asarray(dists)).astype(int)
    if gaps.size == 0:
        raise ValueError("need at least two distances to estimate a spacing")
    values, counts = np.unique(gaps, return_counts=True)
    return values[counts.argmax()]


def point_linspace(anchor0, anchor1, num_points):
    """Yield the interior points of ``num_points`` evenly spaced points from anchor0 to anchor1.

    The endpoints themselves are not yielded, so ``num_points - 2`` points come
    out (none when ``num_points < 3``).
    """
    fractions = np.linspace(0, 1, num_points)
    yield from ((1 - frac) * anchor0 + frac * anchor1 for frac in fractions[1:-1])


def get_anchors(theta, x_lim, y_lim):
    """End points of a segment through the image center along the Hough normal direction.

    A Hough angle ``theta`` describes lines ``x*cos(theta) + y*sin(theta) = d``.
    The returned segment points along (cos(theta), sin(theta)), i.e. across that
    family of lines.

    When ``theta mod pi`` is within pi/4 of 0 (near-vertical lines), the anchors lie
    on the left/right edges ``x_min``/``x_max``. Otherwise they lie on the top/bottom
    edges ``y_min``/``y_max``; previously this case raised UnboundLocalError.

    Parameters
    ----------
    theta : float
        Hough angle in radians.
    x_lim, y_lim : (min, max) pairs
        Horizontal (column) and vertical (row) extent.

    Returns
    -------
    anchor0, anchor1 : ndarray of shape (2,)
        (x, y) points; they may fall outside the limits for angles near pi/4.
    """
    x_min, x_max = x_lim
    y_min, y_max = y_lim
    x_center = (x_max - x_min) / 2 + x_min
    y_center = (y_max - y_min) / 2 + y_min
    phase = theta % np.pi  # in [0, pi]; may round up to exactly pi for tiny negative theta
    if phase < np.pi / 4 or phase >= 3 / 4 * np.pi:
        dy = (x_max - x_min) / 2 * np.tan(theta)
        anchor0 = np.array([x_min, y_center - dy])
        anchor1 = np.array([x_max, y_center + dy])
    else:
        dx = (y_max - y_min) / 2 * np.cos(theta) / np.sin(theta)
        anchor0 = np.array([x_center - dx, y_min])
        anchor1 = np.array([x_center + dx, y_max])
    return anchor0, anchor1


def coords_along(x0, x1):
    """Integer coordinates sampled along the segment from ``x0`` to ``x1``.

    Takes ``int(segment length)`` evenly spaced samples (endpoints included),
    drops the first and last, and truncates with ``astype(np.int_)``.

    Returns
    -------
    xs, ys : int ndarrays
        First and second coordinates of the interior samples.
    """
    n_samples = int(np.sqrt(np.sum((x1 - x0) ** 2)))

    def _sample(axis):
        return np.linspace(x0[axis], x1[axis], n_samples).astype(np.int_)[1:-1]

    return _sample(0), _sample(1)

In [None]:
def line_array(anchors, theta, x_lim, y_lim, start=None, stop=None):
    """Yield segments through each anchor, running to the image border, for Hough angle ``theta``.

    For each anchor, the segment runs from the anchor to a border of the region
    given by ``x_lim``/``y_lim``. The edge is chosen by comparing ``theta`` with the
    angle to the relevant corner. The segment is then optionally trimmed to the
    distance range [start, stop] measured from the anchor.

    Parameters
    ----------
    anchors : iterable of (x, y) ndarrays
    theta : float
        Hough angle in radians.
    x_lim, y_lim : (min, max) pairs
        Region extent. These are now honored; module-level globals were read before.
    start : float, optional
        Distance from the anchor where the segment begins (default 0).
    stop : float, optional
        Distance from the anchor where it ends (default: at the border; ``0`` also
        means "at the border").

    Yields
    ------
    (p0, p1) : pair of ndarray
        Segment end points; zero-length segments are skipped.

    Raises
    ------
    ValueError
        If ``start < 0``, or ``stop`` is given and smaller than ``start``.
    """
    x_min, x_max = x_lim
    y_min, y_max = y_lim
    if start is None:
        start = 0
    if start < 0 or (stop is not None and stop < start):
        raise ValueError("need stop >= start >= 0")
    for anchor in anchors:
        x0 = anchor
        if -3 / 4 * np.pi < theta < np.pi / 4:
            x, y = x_min, y_min
            angle_to_corner = np.arctan2(x_min - anchor[0], anchor[1] - y_min)
        else:
            x, y = x_max, y_max
            angle_to_corner = -np.arctan2(anchor[0] - x_max, -(y_max - anchor[1]))
        if angle_to_corner < theta:
            # endpoint on the top/bottom edge
            x1 = np.array([anchor[0] + (anchor[1] - y) * np.tan(theta), y])
        else:
            # endpoint on the left/right edge
            x1 = np.array([x, anchor[1] + (anchor[0] - x) / np.tan(theta)])
        max_length = np.sqrt(((x1 - x0) ** 2).sum())
        if max_length == 0:
            continue  # degenerate segment; would also divide by zero below
        y0, y1 = x0, x1
        if start:
            y0 = min(start / max_length, 1) * (x1 - x0) + x0
        if stop:
            y1 = min(stop / max_length, 1) * (x1 - x0) + x0
        if not np.array_equal(y0, y1):
            yield y0, y1


def point_linspace_fine(anchor0, anchor1, spacing, offset):
    """Yield points on the segment anchor0 -> anchor1 at a fixed pixel spacing.

    Points lie at distances ``offset, offset + spacing, offset + 2*spacing, ...``
    from ``anchor0`` and stop before reaching ``anchor1``. Nothing is yielded for
    a zero-length segment.

    Raises
    ------
    ValueError
        If ``spacing`` is not positive.
    """
    if spacing <= 0:
        raise ValueError("spacing must be positive")
    start = np.asarray(anchor0, dtype=float)
    delta = np.asarray(anchor1, dtype=float) - start
    length = np.sqrt(np.sum(delta ** 2))
    if length == 0:
        return
    unit = delta / length
    for dist in np.arange(offset, length, spacing):
        yield start + dist * unit


def refine_trenches(img, theta, spacing):
    """Grid-search the line spacing and offset that maximize intensity along trench lines.

    Anchors are placed every ``spacing_`` pixels, starting ``offset`` pixels from
    ``anchor0``, along the segment from ``get_anchors``. Lines through them come
    from ``line_array``. The objective is the summed ``img`` intensity of the
    in-bounds pixels on those lines. The search covers 20 spacings in
    [0.8, 1.2] * ``spacing`` and 20 offsets in [0, 1.2 * ``spacing``].

    NOTE(review): a raw sum favors smaller spacings (more lines); consider
    normalizing per line or per pixel.

    Returns
    -------
    (best_spacing, best_offset) : tuple of float
    """
    # x is the column (horizontal) coordinate and y the row, matching coords_along's
    # use in img[ys, xs]; the shape therefore unpacks as (rows, cols) = (y_max, x_max).
    y_max, x_max = img.shape[:2]
    x_lim = (0, x_max)
    y_lim = (0, y_max)
    anchor0, anchor1 = get_anchors(theta, x_lim, y_lim)

    def objective(spacing_, offset):
        s = 0
        for x0, x1 in line_array(
            point_linspace_fine(anchor0, anchor1, spacing_, offset), theta, x_lim, y_lim
        ):
            xs, ys = coords_along(x0, x1)
            # Drop samples outside the image so negative indices do not wrap around.
            inside = (xs >= 0) & (xs < x_max) & (ys >= 0) & (ys < y_max)
            s += img[ys[inside], xs[inside]].sum()
        return s

    spacings = np.linspace(0.8 * spacing, 1.2 * spacing, 20)
    offsets = np.linspace(0, 1.2 * spacing, 20)
    scores = np.array([[objective(sp, off) for off in offsets] for sp in spacings])
    best_i, best_j = np.unravel_index(np.argmax(scores), scores.shape)
    return spacings[best_i], offsets[best_j]


# Rough trench detection on the background-subtracted projection, rotated by 15 degrees.
img = skimage.transform.rotate(f1, 15, cval=0)
img_thresh = img > skimage.filters.threshold_otsu(img)
img_labels = label_binary_image(img_thresh)
theta, dists = detect_trenches(img_labels == 1)
spacing = get_rough_spacing(dists)

# Extent in plotting/indexing coordinates: x = column, y = row. img.shape is
# (rows, cols), so it unpacks as (y_max, x_max), not (x_max, y_max).
x_min = y_min = 0
y_max, x_max = img.shape
x_lim = (x_min, x_max)
y_lim = (y_min, y_max)
anchor0, anchor1 = get_anchors(theta, x_lim, y_lim)

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(img)
ax.add_artist(plt.Circle(anchor0, 50, color="g"))
ax.add_artist(plt.Circle(anchor1, 50, color="gray"))
trench_profiles = []
# Sample the stretch between 500 and 700 px from each anchor along every trench line.
for x0, x1 in line_array(
    point_linspace(anchor0, anchor1, int((anchor1[0] - anchor0[0]) // spacing)),
    theta,
    x_lim,
    y_lim,
    start=500,
    stop=700,
):
    xs, ys = coords_along(x0, x1)
    trench_profiles.append(img[ys, xs])
    ax.plot(*np.vstack((x0, x1)).T, color="w")
    ax.add_artist(plt.Circle(x0, 10, color="r", zorder=2))
    ax.add_artist(plt.Circle(x1, 10, color="r", zorder=2))

fig, ax = plt.subplots(figsize=(12, 12))
for trench_profile in trench_profiles:
    ax.plot(trench_profile)

In [None]:
# Histogram of all pairwise differences between detected line distances, taken modulo
# 24.001. NOTE(review): 24.001 looks like a hard-coded spacing guess; consider `spacing`.
x = ((dists - dists[:, np.newaxis]) % 24.001).flat
plt.hist(x, bins=50)

In [None]:
# Distribution of gaps between consecutive detected line distances.
plt.hist(np.diff(dists), bins=50)

# Old: earlier Hessian-eigenvalue thresholding experiments (kept for reference)

In [None]:
# Threshold the smaller Hessian eigenvalue at half its Otsu level. The earlier cell
# names it f1_k2; the bare `k2` used before is undefined on a fresh kernel.
kt = f1_k2 > 0.5 * skimage.filters.threshold_otsu(f1_k2)
plt.figure(figsize=(20, 12))
plt.imshow(kt)

In [None]:
# Threshold the larger Hessian eigenvalue at half its Otsu level. The earlier cell
# names it f1_k1; the bare `k1` used before is undefined on a fresh kernel.
kt = f1_k1 > 0.5 * skimage.filters.threshold_otsu(f1_k1)
plt.figure(figsize=(20, 12))
plt.imshow(kt)

In [None]:
# Ideas: find trenches roughly with a linear Hough transform on either the thresholded
# k2 image or the thresholded intensity (undecided).
# Use an elliptical ("oval") Hough transform to find cells.

In [None]:
# Connected components of the complement of the eigenvalue mask.
plt.imshow(skimage.measure.label(~kt))

In [None]:
# Label array of the connected components of the complement of the mask.
skimage.measure.label(~kt)

In [None]:
# Step through the timepoints of the cropped stack, showing the larger Hessian eigenvalue
# (k1) of channel 0 at each step.
num_timesteps = a_stack.shape[1]
# Play/IntSlider max is inclusive and t indexes a_stack directly, so the last valid value
# is num_timesteps - 1 (num_timesteps itself raised an IndexError at the end).
last_t = max(num_timesteps - 1, 0)
play = widgets.Play(
    value=0,
    min=0,
    max=last_t,
    step=1,
    description="Press play",
    disabled=False,
)
time_slider = widgets.IntSlider(min=0, max=last_t, continuous_update=False)
widgets.jslink((play, "value"), (time_slider, "value"))
output = widgets.Output()
box = widgets.VBox([widgets.HBox([play, time_slider]), output])


def f(t):
    """Redraw the k1 image of channel 0 at timepoint ``t`` into ``output``."""
    with output:
        z = hessian_eigenvalues(a_stack[0, t])[0]
        clear_output(wait=True)
        plt.figure(figsize=(20, 12))
        plt.imshow(z)
        plt.show()


interactive(f, t=time_slider)
box