In [None]:
import functools
from collections import Counter

import datashader as ds
import holoviews as hv
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import skimage
import skimage.filters
import skimage.measure
import skimage.morphology
import skimage.segmentation
import sklearn
import zarr
from bokeh.io import output_notebook, push_notebook, show
from bokeh.models import WheelZoomTool
from holoviews.operation import decimate
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    regrid,
    shade,
)
from holoviews.streams import Stream, param
from IPython.display import clear_output, display
from ipywidgets import fixed, interact, interact_manual, interactive
from matplotlib.colors import hex2color
from numcodecs import Blosc, Delta
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_blobs

# from sklearn import metrics
# from sklearn.datasets.samples_generator import make_blobs
from sklearn.preprocessing import StandardScaler

# from bokeh.layouts import row
# from bokeh.plotting import figure
from tqdm import tnrange, tqdm_notebook

In [None]:
# Enable %lprun (line-by-line profiling) from the line_profiler extension.
%load_ext line_profiler
# Activate HoloViews' Bokeh backend so hv objects render as Bokeh plots.
hv.notebook_extension("bokeh")
# Bokeh renderer handle. NOTE(review): not referenced by any visible cell.
renderer = hv.renderer("bokeh")
# Render matplotlib figures as static images in the cell output.
%matplotlib inline

In [None]:
frames_z = zarr.open_array("/home/jqs1/scratch/fidelity/test/171018.zarr", mode="r")

In [None]:
frames = nd2reader.ND2Reader(
    "/home/jqs1/scratch/fidelity/171018/20171018_TrxnError_ID.nd2"
)

In [None]:
frames_z.shape

In [None]:
# Hex display colour for each acquisition channel name. Used to tint channels in
# the composite and as the initial value of each channel's colour picker.
channel_colors = dict(
    BF="#ffffff",
    MCHERRY="#e22400",
    GFP="#76ba40",
    CY5="#e292fe",
    BFP="#3a87fd",
    YFP="#f5eb00",
)

In [None]:
%%output size=250

channels = frames_z.attrs["metadata"]["channels"]
n_channels = len(channels)
colors = [hex2color(channel_colors[channel]) for channel in channels]
num_timepoints = len(frames_z.attrs["metadata"]["frames"])
num_fovs = len(frames_z.attrs["metadata"]["fields_of_view"])

channel_boxes = []
channel_widgets = []
for channel in channels:
    solo_button = widgets.Button(description="S", layout=widgets.Layout(width="10%"))
    enabled_button = widgets.ToggleButton(description=channel, value=True)
    solo_button._button_to_enable = enabled_button
    color_picker = widgets.ColorPicker(concise=True, value=channel_colors[channel])
    channel_box = widgets.HBox([solo_button, enabled_button, color_picker])
    channel_widgets.append([solo_button, enabled_button, color_picker, channel_box])
solo_buttons, enabled_buttons, color_pickers, channel_boxes = zip(*channel_widgets)
channels_box = widgets.VBox(channel_boxes)
t_slider = widgets.IntSlider(
    label="t", min=0, max=num_timepoints, step=1, value=0, continuous_update=False
)
v_slider = widgets.IntSlider(
    min=0, max=num_fovs, step=1, value=0, continuous_update=False
)
slider_box = widgets.VBox([v_slider, t_slider])
control_box = widgets.HBox([channels_box, slider_box])
output = widgets.Output()
main_box = widgets.VBox([control_box, output])
display(main_box)

max_val = 2**14

Frame = Stream.define("Frame", t=0, v=0)
frame = Frame()
DisplaySettings = Stream.define(
    "DisplaySettings", channel_enabled=np.array([True] * n_channels)
)
display_settings = DisplaySettings()


def composite_image(t, v, channel_enabled):
    """Render one (FOV, timepoint) as a screen-blended RGB composite of the enabled channels.

    Each enabled channel ``c`` is read from ``frames_z[v, c, t, :, :]``,
    normalised by its 99.9th percentile, clipped to [0, 1], tinted with
    ``colors[c]`` and combined with the "screen" blend ``1 - (1 - a) * (1 - b)``.

    Parameters
    ----------
    t, v : int
        Timepoint and field-of-view indices (supplied by the ``frame`` stream).
    channel_enabled : sequence of bool
        Per-channel visibility flags, indexed like ``range(n_channels)``.

    Returns
    -------
    hv.RGB
        The composite with bounds (-1, -1, 1, 1), or a white image of the frame
        size when no channel is enabled.
    """
    img = None
    for c in range(n_channels):
        if not channel_enabled[c]:
            continue  # skip reading/decompressing channels that are hidden
        channel_img = frames_z[v, c, t, :, :]
        scale = np.percentile(channel_img, 99.9)
        if not scale > 0:
            # Blank (all-zero) frame: dividing by 0 would yield NaN/inf.
            scale = 1
        scaled = np.clip(channel_img[:, :, np.newaxis] / scale, 0, 1)
        colored = scaled * np.array(colors[c])
        img = colored if img is None else 1 - (1 - img) * (1 - colored)
    if img is None:
        img = np.ones((*frames_z.shape[-2:], 3))  # white placeholder
    return hv.RGB(img, bounds=(-1, -1, 1, 1))


# Forward slider moves into the `frame` stream that drives the DynamicMap.
t_slider.observe(handler=lambda change: frame.event(t=change["new"]), names="value")
v_slider.observe(handler=lambda change: frame.event(v=change["new"]), names="value")


def update_enabled_channels(change):
    """Send the current on/off state of all channel toggles to ``display_settings``.

    ``change`` is the traitlets change notification and is not used: the full
    state is re-read from ``enabled_buttons`` on every call.
    """
    flags = np.array([toggle.value for toggle in enabled_buttons])
    display_settings.event(channel_enabled=flags)


def update_solo(solo_button):
    """Click handler for a channel's "S" (solo) button.

    If the paired channel is already the only one enabled, every channel is
    re-enabled; otherwise only the paired channel is left enabled. Each
    assignment to a toggle's ``.value`` fires that toggle's own observers.
    """
    target = solo_button._button_to_enable
    already_solo = target.value and sum(b.value for b in enabled_buttons) == 1
    for toggle in enabled_buttons:
        toggle.value = True if already_solo else toggle == target


# Connect every channel's solo button and enable toggle to their handlers.
for solo_button, enabled_button in zip(solo_buttons, enabled_buttons):
    solo_button.on_click(update_solo)
    enabled_button.observe(update_enabled_channels, names="value")

# composite_image is re-evaluated whenever `frame` or `display_settings` emits an
# event; the map is displayed through HoloViews' datashader `regrid` operation.
image_viewer = hv.DynamicMap(composite_image, streams=[frame, display_settings])
regrid(image_viewer)

In [None]:
a_stack = frames_z[0, :, :, 300:500, :1000]
a = a_stack[0, 0]

In [None]:
plt.figure(figsize=(20, 12))
plt.imshow(a)

In [None]:
%%output size=250
b = skimage.filters.scharr(a)
plt.figure(figsize=(20, 12))
plt.imshow(b)
# hv.Image(b)

In [None]:
b2 = skimage.filters.gaussian(a, sigma=2)

In [None]:
c_h = skimage.filters.sobel_h(b2)
c_v = skimage.filters.sobel_v(b2)
plt.figure(figsize=(20, 12))
plt.imshow(c_h)
plt.figure(figsize=(20, 12))
plt.imshow(c_v)

In [None]:
c2_h = skimage.filters.sobel_h(c_h)
c2_v = skimage.filters.sobel_v(c_v)
plt.figure(figsize=(20, 12))
plt.imshow(c2_h)
plt.figure(figsize=(20, 12))
plt.imshow(c2_v)

In [None]:
# Per-pixel eigenvalues (k1 >= k2) of the Hessian of the Gaussian-smoothed crop.
# The two mixed partials (I_xy, I_yx) agree except for rounding / border
# handling, so they are averaged to make the Hessian exactly symmetric. The
# discriminant (I_xx - I_yy)**2 + 4 * I_mixed**2 is then never negative. The
# previous form could dip below zero and produce NaNs that had to be zeroed.
I = skimage.filters.gaussian(a, 1.5)
I_x = skimage.filters.sobel_h(I)
I_y = skimage.filters.sobel_v(I)
I_xx = skimage.filters.sobel_h(I_x)
I_xy = skimage.filters.sobel_v(I_x)
I_yx = skimage.filters.sobel_h(I_y)
I_yy = skimage.filters.sobel_v(I_y)
I_mixed = (I_xy + I_yx) / 2
kappa_1 = (I_xx + I_yy) / 2
kappa_2 = np.sqrt((I_xx - I_yy) ** 2 + 4 * I_mixed**2) / 2
k1 = kappa_1 + kappa_2
k2 = kappa_1 - kappa_2
plt.figure(figsize=(20, 12))
plt.imshow(k1)
plt.figure(figsize=(20, 12))
plt.imshow(k2)

In [None]:
plt.figure(figsize=(20, 12))
plt.imshow(a_stack[1, :].max(axis=1))

In [None]:
plt.figure(figsize=(20, 12))
plt.imshow(np.percentile(a_stack[0, :], 90, axis=1))

In [None]:
plt.figure(figsize=(20, 12))
plt.imshow(np.percentile(a_stack[0, :], 99, axis=1))

In [None]:
plt.figure(figsize=(20, 12))
plt.imshow(np.max(a_stack[0, :], axis=1))

In [None]:
e0 = a_stack[0, :].max(axis=1).mean(axis=0)
plt.plot(e0)

In [None]:
e0 = a_stack[0, :].max(axis=1).max(axis=0)
plt.plot(e0)

In [None]:
e1 = np.mean(k2, axis=0)
plt.plot(e1)

In [None]:
e2 = np.fft.fft(e1)
plt.plot(np.abs(e2))

In [None]:
plt.plot(np.abs(e2)[:200])

In [None]:
frames_z.shape

In [None]:
%%time
f0 = frames_z[0, 0, :].max(axis=0)

In [None]:
%%time
f01 = frames_z[0, 1, :].max(axis=0)

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(f0)

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(f01)

In [None]:
plt.plot(np.min(f0, axis=1))

In [None]:
plt.plot(np.percentile(f0, 0.5, axis=1) - np.min(f0, axis=1))

In [None]:
plt.plot(np.percentile(f0, 90, axis=1) - np.percentile(f0, 0.5, axis=1))

In [None]:
f1 = f0 - np.percentile(f0, 0.5, axis=1)[:, np.newaxis]
plt.figure(figsize=(20, 20))
plt.imshow(f1)

In [None]:
f2 = f1 > skimage.filters.threshold_otsu(f1)
plt.figure(figsize=(20, 20))
plt.imshow(f2)

In [None]:
# Toy data from scikit-learn's DBSCAN example (three isotropic blobs), kept for
# comparison. `make_blobs` must come from `sklearn.datasets` in the imports cell:
# the old `sklearn.datasets.samples_generator` import was commented out, so this
# cell raised NameError. X is overwritten with real pixel coordinates below.
centers = [[1, 1], [-1, -1], [1, -1]]
X, labels_true = make_blobs(
    n_samples=750, centers=centers, cluster_std=0.4, random_state=0
)

In [None]:
X

In [None]:
X = np.array(np.where(f2)).T

In [None]:
X2 = StandardScaler().fit_transform(X)

In [None]:
db = DBSCAN(eps=0.3, min_samples=10).fit(X2)

In [None]:
X[0, 0]

In [None]:
# Paint each foreground pixel with its DBSCAN label. DBSCAN cluster ids start at
# 0 and noise is -1, so a zero-filled canvas made cluster 0 indistinguishable
# from background; background now gets its own sentinel value.
BACKGROUND_LABEL = -2
label_img = np.full(f1.shape, BACKGROUND_LABEL, dtype=db.labels_.dtype)
# Vectorised scatter: row i of X receives db.labels_[i] (same length, since db was fit on X).
label_img[X[:, 0], X[:, 1]] = db.labels_

In [None]:
# Keep clusters that hold more than 1% of the clustered points. DBSCAN marks
# noise with label -1, which is not a cluster, so it is never kept even when
# the noise fraction exceeds the threshold.
MIN_CLUSTER_FRACTION = 0.01
NOISE_LABEL = -1
counter = Counter(db.labels_)
total = sum(counter.values())
good_labels = []
for label, count in counter.items():
    fraction = count / total
    print(label, fraction)
    if label != NOISE_LABEL and fraction > MIN_CLUSTER_FRACTION:
        good_labels.append(label)

In [None]:
sum(counter.values())

In [None]:
good_labels

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(label_img)

In [None]:
# Heavily smoothed (sigma=30) version of the row-offset-subtracted projection f1.
f3 = skimage.filters.gaussian(f1, 30)
plt.figure(figsize=(20, 20)).gca().imshow(f3)

In [None]:
# Watershed segmentation of the smoothed projection. The original call,
# `skimage.morphology.watershed()`, passed no image (TypeError) and used an
# alias that scikit-image has removed; the function lives in skimage.segmentation.
# With markers omitted, a basin is seeded at each local minimum of the input.
# NOTE(review): flooding -f3 (so bright regions become basins) is a guess at the
# intended input; confirm, and consider passing explicit `markers=` / `mask=f2`.
w = skimage.segmentation.watershed(-f3)
plt.figure(figsize=(20, 12))
plt.imshow(w)

In [None]:
help(skimage.filters.threshold_otsu)  # plain-Python equivalent of IPython's `?obj`

In [None]:
kt = k2 > 0.5 * skimage.filters.threshold_otsu(k2)
plt.figure(figsize=(20, 12))
plt.imshow(kt)

In [None]:
kt = k1 > 0.5 * skimage.filters.threshold_otsu(k1)
plt.figure(figsize=(20, 12))
plt.imshow(kt)

In [None]:
# TODO: rough trench finding with a linear Hough transform on thresholded k2
# (or on thresholded intensity; undecided).
# TODO: elliptical ("oval") Hough transform to find cells.

In [None]:
# Connected-component labels of the pixels outside the thresholded mask
# (kt is boolean, so 1 - kt is 1 where kt is False).
plt.imshow(skimage.measure.label(1 - kt))

In [None]:
# Same label array, shown raw for inspection.
skimage.measure.label(1 - kt)

In [None]:
def hessian_eigenvalues(img, sigma=1.5):
    """Return per-pixel eigenvalues ``(k1, k2)``, ``k1 >= k2``, of the smoothed image Hessian.

    The image is Gaussian-smoothed with ``sigma``, and second derivatives are
    obtained by applying the Sobel filters twice. The two mixed partials agree
    except for rounding and border handling, so they are averaged to make the
    Hessian exactly symmetric. The 2x2 discriminant
    ``(h_hh - h_vv)**2 + 4 * h_hv**2`` is then never negative. The previous
    ``(a + d)**2 - 4 * (a*d - b*c)`` form could dip below zero and yield NaN.

    Parameters
    ----------
    img : 2-D array
    sigma : float, default 1.5
        Smoothing scale passed to ``skimage.filters.gaussian``.

    Returns
    -------
    (k1, k2) : tuple of arrays shaped like ``img``
        Larger and smaller Hessian eigenvalue at each pixel.
    """
    smoothed = skimage.filters.gaussian(img, sigma)
    d_h = skimage.filters.sobel_h(smoothed)
    d_v = skimage.filters.sobel_v(smoothed)
    h_hh = skimage.filters.sobel_h(d_h)
    h_vv = skimage.filters.sobel_v(d_v)
    h_hv = (skimage.filters.sobel_v(d_h) + skimage.filters.sobel_h(d_v)) / 2
    half_trace = (h_hh + h_vv) / 2
    radius = np.sqrt((h_hh - h_vv) ** 2 + 4 * h_hv**2) / 2
    return half_trace + radius, half_trace - radius

In [None]:
# Play/slider controls for stepping through the timepoints of the cropped stack.
num_timesteps = a_stack.shape[1]
# Bounds are inclusive: the last valid t for a_stack[0, t] is num_timesteps - 1.
# The old max=num_timesteps let the player step one past the end (IndexError in f).
play = widgets.Play(
    value=0,
    min=0,
    max=max(num_timesteps - 1, 0),
    step=1,
    description="Press play",
    disabled=False,
)
time_slider = widgets.IntSlider(
    min=0, max=max(num_timesteps - 1, 0), continuous_update=False
)
# Keep the player and the slider in sync in the browser.
widgets.jslink((play, "value"), (time_slider, "value"))
# NOTE(review): rebinds the `output` name previously used by the channel viewer cell.
output = widgets.Output()
box = widgets.VBox([widgets.HBox([play, time_slider]), output])


def f(t):
    """Redraw the k1 Hessian-eigenvalue map of ``a_stack[0, t]`` into the ``output`` widget."""
    with output:
        k1_map, _ = hessian_eigenvalues(a_stack[0, t])
        clear_output(wait=True)
        plt.figure(figsize=(20, 12)).gca().imshow(k1_map)
        plt.show()


# Redraw on every slider change (the Play widget drives the slider through
# jslink) and draw the first frame immediately. The previous
# `interactive(f, t=time_slider)` object was created but never displayed, so its
# initial render (tied to display in some ipywidgets versions) was not guaranteed.
time_slider.observe(lambda change: f(change["new"]), names="value")
f(time_slider.value)
box