# Imports

In [None]:
import numpy as np
import pandas as pd
import zarr
import xarray as xr
import holoviews as hv
import datashader as ds
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    shade,
    regrid,
)
from holoviews.operation import decimate
from holoviews.streams import Stream, param
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display, clear_output
import skimage
import skimage.filters
import skimage.morphology
import peakutils
import scipy.signal
import scipy.stats
import scipy.interpolate
import functools
import operator
from functools import partial, reduce
import itertools
from itertools import zip_longest
from more_itertools import rstrip
from collections import Counter, defaultdict
from collections.abc import Mapping, Sequence
from bokeh.models import WheelZoomTool
from bokeh.io import push_notebook, show, output_notebook

# from bokeh.layouts import row
# from bokeh.plotting import figure
import tqdm
from tqdm import tnrange, tqdm, tqdm_notebook
from copy import copy, deepcopy
import warnings
import warnings

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from processing import *
from trench_detection import *
from trench_detection import _label_for_trenches
from trench_segmentation import *
from util import *
from ui import *

In [None]:
%load_ext line_profiler
hv.notebook_extension("bokeh")
# renderer = hv.renderer('bokeh')
%matplotlib inline
tqdm.monitor_interval = 0

# Config

In [None]:
# warnings.filterwarnings('ignore', 'Conversion of the second argument of issubdtype')
# Suppress every FutureWarning for the rest of the kernel session (the commented
# lines are a narrower filter and the opposite setting, kept for toggling).
warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings('default', category=FutureWarning)

In [None]:
# Zarr store holding this experiment's image arrays ("raw", "quantized", ... groups).
# An earlier experiment's store, kept for reference:
# store = zarr.DirectoryStore('/home/jqs1/scratch/fidelity/171018/171018.zarr')
ZARR_PATH = "/home/jqs1/scratch/fidelity/171214/transcriptionerror_timelapse.zarr"
store = zarr.DirectoryStore(ZARR_PATH)
root_group = zarr.open_group(store=store)

# Trench detection

In [None]:
diag_pos = tree()


def _detect_trenches(pos):
    """Run get_trenches on position `pos` (channel 0, timepoint 30 of its raw stack).

    Diagnostics are written into diag_pos[pos]. The whole call, including the
    zarr read, is wrapped in fail_silently; positions whose result is None are
    collected as none_pos further down.
    """
    return fail_silently(
        lambda: get_trenches(
            root_group["raw"][str(pos)][0, 30], diagnostics=diag_pos[pos]
        )
    )


trench_points_pos = {pos: _detect_trenches(pos) for pos in tnrange(50)}

In [None]:
# Positions for which trench detection produced no result (value is None).
none_pos = [k for k, v in trench_points_pos.items() if v is None]

In [None]:
# Positions whose detection looks plausible: trench rotation within +/-5 degrees
# and periodogram trench period between 21 and 26 (empirical bounds — TODO document).
# NOTE(review): rotation is read from diag["label_1"]["trench_rotation"] here, but some
# later cells read diag["trench_rotation"]; if tree() autovivifies, a wrong path yields an
# empty subtree instead of a KeyError — confirm which path get_trenches writes.
good_pos = [
    k
    for k, v in diag_pos.items()
    if trench_points_pos[k] is not None
    and -5 < v["label_1"]["trench_rotation"]["hough_2"]["angle (deg)"] < 5
    and 21 < v["label_1"]["trench_anchors"]["periodogram_2"]["period"] < 26
]

In [None]:
len(good_pos)

In [None]:
rot = {
    k: v["trench_rotation"]["hough_2"]["angle (deg)"]
    for k, v in diag_pos.items()
    if k not in none_pos
}

In [None]:
plt.hist(rot.values())

In [None]:
spacings1 = {
    k: v["label_1"]["trench_anchors"]["periodogram_2"]["period"]
    for k, v in diag_pos.items()
    if k not in none_pos
}

In [None]:
plt.hist([v for v in spacings1.values() if v < 50], bins=50)

In [None]:
rot_spacing = {
    k: (
        v["trench_rotation"]["hough_2"]["angle (deg)"],
        v["label_1"]["trench_anchors"]["periodogram_2"]["period"],
    )
    for k, v in diag_pos.items()
    if k not in none_pos
}

In [None]:
plt.figure(figsize=(12, 12))
plt.scatter(*zip(*rot_spacing.values()), s=1)

In [None]:
diag_pos[0]

In [None]:
diag_pos[0]["trench_rotation"]["hough_2"]["angle (deg)"]

In [None]:
len(rot1)

In [None]:
none_pos

In [None]:
diag = tree()
get_trenches(root_group["raw"][str(0)][0, 30], diagnostics=diag)

In [None]:
display_plot_browser(diag)

In [None]:
frame_stream = FrameStream()
display(frame_browser(root_group["quantized"], frame_stream))
display_plot_browser(lambda t, v: diag_pos[v] if v in diag_pos else None, frame_stream)

In [None]:
big_image_viewer(root_group["quantized"])

# Memory debugging

In [None]:
import cloudpickle

In [None]:
import resource

# Peak resident set size of this kernel. ru_maxrss is reported in KiB on Linux
# (bytes on macOS), so this is presumably MiB on Linux.
resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

In [None]:
resource?

In [None]:
resource.getrusage(resource.RUSAGE_SELF)

In [None]:
import objgraph

In [None]:
# Most common live object types, to look for what is holding memory.
objgraph.most_common_types()

In [None]:
# NOTE: this cell used to evaluate `len(b)`, but `b` is not assigned in any earlier
# cell, so it raised NameError on a fresh kernel. The pickled-size check is `len(c)`
# further down.

In [None]:
import gc

# Force a full garbage collection (returns the number of unreachable objects found).
gc.collect()

In [None]:
from pympler import asizeof

In [None]:
diag_pos[0]["label_1"]["trench_ends"].keys()

In [None]:
# Serialized size of one position's trench_ends diagnostics (checked with len(c) below).
c = cloudpickle.dumps(diag_pos[0]["label_1"]["trench_ends"])

In [None]:
# Dump all per-position diagnostics to the scratch file ./foo.
with open("foo", "wb") as f:
    cloudpickle.dump(diag_pos, f)

In [None]:
len(c)

In [None]:
# Size breakdown of position 0's diagnostics (pympler asizeof, detail=5).
print(asizeof.asized(diag_pos[0], detail=5).format())

In [None]:
# NOTE: restores every variable previously saved with %store, overwriting any
# same-named variables in this kernel.
%store -r

In [None]:
# Persist detection results so they can be restored after a kernel restart.
%store trench_points_pos
%store diag_pos

# Trench segmentation

In [None]:
pos = str(0)
n_channels, n_timepoints = root_group["raw"][pos].shape[:2]
channel = 0

In [None]:
diag_td = tree()
trench_points = get_trenches(root_group["raw"][pos][0, 30], diagnostics=diag_td)

In [None]:
# display_plot_browser(diag_td);

In [None]:
sharpness = [
    image_sharpness(root_group["raw"][pos][0, t]) for t in tnrange(n_timepoints)
]

In [None]:
trench_idx = 18
trench_set_idx = 1

In [None]:
x_lim, y_lim = get_img_limits(root_group["raw"][pos].shape[2:])
ul, lr = get_trench_bbox(trench_points[trench_set_idx], trench_idx, x_lim, y_lim)
thumbs = root_group["raw"][pos][0, :, ul[1] : lr[1], ul[0] : lr[0]]

In [None]:
thumbs_k1 = map_ndarray(lambda img: hessian_eigenvalues(img)[0], thumbs)
thumbs_k2 = map_ndarray(lambda img: hessian_eigenvalues(img)[1], thumbs)

In [None]:
t_max = 128  # thumbs_k1.shape[0]


def plot_trench(t):
    p = (
        RevImage(thumbs[t])
        + RevImage(thumbs_k1[t])
        + RevImage(thumbs_k2[t])
        # + hv.Curve(profiles[t], kdims=['y'], vdims=['i'])
        + hv.Curve(profiles[t], kdims=["y"], vdims=["j"])
    ).opts(
        plot={
            "Image": {
                "height": thumbs.shape[1] * 2,
                "width": thumbs.shape[2] * 3,
                "aspect": 0.2,
            },
            "Curve": {
                "invert_axes": True,
                "height": thumbs.shape[1] * 2,
                "width": thumbs.shape[2] * 3,
                "invert_yaxis": True,
            },
            "Layout": {"normalize": False},
        }
    )
    return p


hv.HoloMap({t: plot_trench(t) for t in range(t_max)})
# plot_trench(0)

In [None]:
def visualize_trench(img1, img2, img3, traces=None):
    # img_k1, img_k2 = hessian_eigenvalues(img)
    #     pointery = hv.streams.PointerY(y=0)
    #     def scrubber_callback(y):
    #         y = int(y)
    #         h_line = hv.HLine(y + 0.5)(style={'color': 'black', 'alpha': 0.9, 'line_width': 2})
    #         return h_line
    #     scrubber_dmap = hv.DynamicMap(scrubber_callback, streams=[pointery])
    # plot = RevImage(img) * scrubber_dmap + RevImage(img_k1) * scrubber_dmap + RevImage(img_k2) * scrubber_dmap
    plot = RevImage(img1) + RevImage(img2) + RevImage(img3)
    if traces is not None:
        plot += hv.Overlay.from_values(
            [
                hv.Curve(trace / abs(trace).max(), kdims=["y"], vdims=["i" + str(i)])
                for i, trace in enumerate(traces)
            ]
        )
    plot = plot.opts(
        plot={
            "Image": {
                "height": thumbs.shape[1] * 2,
                "width": thumbs.shape[2] * 3,
                "aspect": 0.2,
            },
            "Curve": {
                "invert_axes": True,
                "height": thumbs.shape[1] * 2,
                "width": thumbs.shape[2] * 3,
                "invert_yaxis": True,
            },
            "Layout": {"normalize": False},
        }
    )
    return plot


# visualize_trench(a)#, [a.sum(axis=1)])

In [None]:
# a = thumbs_k1[0]
# One trench thumbnail for the experiments below; the trailing comment lists
# alternative frame indices that were also examined.
a = thumbs[61]  # 31, 36, 61
a_k1, a_k2 = hessian_eigenvalues(a)

In [None]:
a.shape

In [None]:
# Otsu threshold of the thumbnail -> binary mask.
threshold = skimage.filters.threshold_otsu(a)
a_thresh = a > threshold
plt.imshow(a_thresh)

In [None]:
# Sum over axis 1: one value per row of the thumbnail.
plt.plot(a.sum(axis=1))

In [None]:
# First difference of the first-eigenvalue image along axis 0 (one row shorter than a_k1).
a_k1d = np.diff(a_k1, axis=0)

In [None]:
# Traces: column 25 of the first eigenvalue, row sums of the raw thumbnail, and row sums
# of the eigenvalue difference over columns 20-29 (this last trace is one sample shorter).
visualize_trench(
    a, a_k1, a_k1d, [a_k1[:, 25], a.sum(axis=1), a_k1d[:, 20:30].sum(axis=1)]
)

In [None]:
# Second-eigenvalue columns 22, 25 and 28 as traces.
visualize_trench(a, a_k2, a_k1d, [a_k2[:, 22], a_k2[:, 25], a_k2[:, 28]])

In [None]:
plt.figure(figsize=(12, 12))
# Column index splitting the thumbnail into left (< mid_col) and right (>= mid_col) halves.
mid_col = 25
# Individual second-eigenvalue columns around the split: 20..30 inclusive.
for i in range(mid_col - 5, mid_col + 6):
    plt.plot(a_k2[:, i], label=str(i))
# Columns [:mid_col] are the left half of the image, [mid_col:] the right half
# (the labels were previously swapped).
plt.plot(a_k2[:, :mid_col].max(axis=1), label="max l", linestyle="--", lw=2)
plt.plot(a_k2[:, mid_col:].max(axis=1), label="max r", linestyle="--", lw=2)
plt.legend()

In [None]:
# Number of above-threshold pixels in each row.
plt.plot(a_thresh.sum(axis=1))

In [None]:
# Intensity across row 135.
plt.plot(a[135])

In [None]:
# Intensity across row 150.
plt.plot(a[150])

In [None]:
# 2-D autocorrelation of the thumbnail. Cast to float first: for integer inputs
# (raw frames are presumably uint16 — TODO confirm) scipy.signal.correlate returns
# an integer result, so the large correlation sums overflow and wrap around.
a_float = a.astype(float)
plt.imshow(scipy.signal.correlate(a_float, a_float, mode="same"))
# plt.plot(np.correlate(, axis=0))

In [None]:
trench_points[trench_set_idx][0].shape

In [None]:
x0, x1 = (
    trench_points[trench_set_idx][0][trench_idx],
    trench_points[trench_set_idx][1][trench_idx],
)
xs, ys = coords_along(x0, x1)
profiles = np.array([thumbs_k1[t, ys - ul[1], xs - ul[0]] for t in range(n_timepoints)])

In [None]:
# Sanity check: trench endpoints (shifted into thumbnail coordinates) over frame 0.
RevImage(thumbs[0]) * hv.Points([x0 - ul, x1 - ul]).opts(
    plot={"size": 30, "color": "green"}
)

In [None]:
hv.help(hv.Curve)

In [None]:
%%opts Layout [normalize=False]
t_max = 2


def plot_trench(t):
    p = (
        RevImage(thumbs[t])
        + RevImage(thumbs_k1[t])
        + hv.Curve(profiles[t], kdims=["y"], vdims=["i"]).opts(plot={"swap_axes": True})
    ).opts(
        plot={
            "Image": {
                "height": thumbs.shape[1] * 2,
                "width": thumbs.shape[2] * 3,
                "aspect": 0.2,
            }
        }
    )
    return p


# m = hv.HoloMap({t: plot_trench(t) for t in range(t_max)})
# m
plot_trench(0)

## Segmentation algorithm testing

In [None]:
from skimage.segmentation import morphological_geodesic_active_contour

In [None]:
morphological_geodesic_active_contour?

In [None]:
from skimage.segmentation import inverse_gaussian_gradient

f = thumbs[0]
image = f / f.max()
# Morphological GAC evolves on an edge-indicator image (small values at edges),
# not on raw intensities; skimage provides inverse_gaussian_gradient for this.
gimage = inverse_gaussian_gradient(image)

# Initial level set: everything except a 10-pixel border.
init_ls = np.zeros(image.shape, dtype=np.int8)
init_ls[10:-10, 10:-10] = 1
# balloon=-1 shrinks the initial rectangle inward toward edges (as in the example below).
ls = morphological_geodesic_active_contour(
    gimage, 600, smoothing=2, threshold="auto", init_level_set=init_ls, balloon=-1
)
plt.imshow(ls)
# List with intermediate results for plotting the evolution
# evolution = []
# callback = store_evolution_in(evolution)
# ls = morphological_geodesic_active_contour(gimage, 230, init_ls,
#                                           smoothing=1, balloon=-1,
#                                           threshold=0.69,
#                                           iter_callback=callback)

In [None]:
morphological_geodesic_active_contour?

## Kymo viewer

In [None]:
%%opts Image (cmap='viridis')
def scrubber_callback(t, y):
    """Crosshair for the kymograph: vertical line at timepoint t, a "t=" label
    near the lower-left (offset by 1/20 of the kymograph size), and a horizontal
    line at y."""
    t = int(t)
    vertical = hv.VLine(t + 0.5)(
        style={"color": "white", "alpha": 0.3, "line_width": 2}
    )
    label_x = t + kymo.shape[1] / 20
    label_y = kymo.shape[0] / 20
    label = hv.Text(label_x, label_y, "t={:}".format(t))(
        style={"color": "white", "text_alpha": 0.8}
    )
    horizontal = hv.HLine(y)(style={"color": "white", "alpha": 0.3, "line_width": 2})
    return vertical * label * horizontal


def cross_section_callback(t, y):
    """Horizontal guide line at integer row y; t comes from the shared stream and is unused."""
    row = int(y)
    return hv.HLine(row)(style={"color": "white", "alpha": 0.3, "line_width": 2})


def sharpness_scrubber_callback(t):
    """Vertical marker on the sharpness curve, centred on integer timepoint t."""
    position = int(t) + 0.5
    return hv.VLine(position)(style={"color": "black", "alpha": 0.3, "line_width": 2})


def trench_thumbnail_callback(img_series, t):
    """Thumbnail image of trench `trench_idx` at timepoint t, y-axis inverted.

    Reads the notebook globals trench_points and trench_idx.
    """
    # The pointer can sit outside the kymograph: negative t would silently wrap to
    # the last frames and t past the end would raise IndexError, so clamp it.
    t = min(max(int(t), 0), len(img_series) - 1)
    img = get_trench_thumbnail(img_series[t], trench_points, trench_idx)
    # TODO: don't know why calling hv.Image.opts(plot={'invert_axes': True}) doesn't work
    # also switched HLine to VLine in cross_section, above
    thumb = hv.Image(
        img[::-1], bounds=(0, 0, img.shape[1], img.shape[0]), kdims=["x2", "y"]
    )
    # return thumb
    # return thumb.opts(plot={'invert_axes': False, 'width': max_thumbnail_width})#, 'xaxis': None, 'yaxis': None})
    return thumb.opts(
        plot={"invert_axes": False, "invert_yaxis": True, "width": 200}
    )  # , 'xaxis': None, 'yaxis': None})
    # return thumb.opts(plot={'invert_axes': False, 'invert_yaxis': True})


# Shared pointer streams; PointerX's x is renamed to t so callbacks receive (t, y).
pointerx = hv.streams.PointerX(x=0).rename(x="t")
pointery = hv.streams.PointerY(y=0)

# BOUNDS: (left, bottom, right, top) — the HoloViews Image bounds order.
# NOTE(review): this overrides trench_idx = 18 set in the trench segmentation section.
trench_idx = 15
# NOTE(review): frame_series_k1 is not assigned in any visible cell; presumably restored
# by `%store -r` or provided by a star import — confirm.
img_series = frame_series_k1
# NOTE(review): trench_points is indexed here as (first endpoints, second endpoints),
# whereas the trench segmentation cells use trench_points[trench_set_idx][0/1] and pass
# a shape to get_img_limits — confirm which layout/API these helpers currently expect.
bbox_ul, bbox_lr = get_trench_bbox(
    trench_points, trench_idx, *get_img_limits(img_series[0])
)
kymo = extract_kymograph(
    img_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
# Rows are flipped so that row 0 is drawn at the top together with invert_yaxis.
kymo_img = hv.Image(kymo[::-1], bounds=(0, 0, kymo.shape[1], kymo.shape[0])).opts(
    plot={"width": 700, "height": 300, "invert_yaxis": True}
)
# kymo_img = hv.Raster(kymo).opts(plot={'width': 700, 'height': 300, 'yaxis': 'left'})
scrubber_line = hv.DynamicMap(scrubber_callback, streams=[pointerx, pointery])
cross_section_line = hv.DynamicMap(cross_section_callback, streams=[pointerx, pointery])
trench_thumbnail_img = hv.DynamicMap(
    partial(trench_thumbnail_callback, img_series), streams=[pointerx]
)  # .opts(plot={'invert_axes': True})
sharpness_plot = hv.Curve(sharpness, kdims=["x"], vdims=["s"]).opts(
    plot={"width": 700, "height": 100}
)
sharpness_scrubber = hv.DynamicMap(sharpness_scrubber_callback, streams=[pointerx])
# pointery.source = kymo_img
# SEE: http://holoviews.org/reference/elements/bokeh/Distribution.html
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)

# Stream carrying the overlay on/off flag, driven by a toggle button.
KymoOverlayStream = Stream.define("KymoOverlayStream", overlay_enabled=True)
kymo_overlay_stream = KymoOverlayStream()
kymo_overlay_button = widgets.ToggleButton(description="Overlay", value=True)


def update_kymo_overlay(change):
    """Forward the toggle button's new value to kymo_overlay_stream."""
    enabled = change["new"]
    kymo_overlay_stream.event(overlay_enabled=enabled)


kymo_overlay_button.observe(update_kymo_overlay, names="value")


def get_thumbnail_overlay(t, overlay_enabled):
    """Overlay for the trench thumbnail: the trench line when enabled, else an empty overlay.

    t comes from the shared PointerX stream and is unused. Endpoints are shifted
    by bbox_ul into thumbnail coordinates.
    """
    # because of https://github.com/ioam/holoviews/issues/1388, overlay_enabled must be True initially
    start = trench_points[0][trench_idx]
    end = trench_points[1][trench_idx]
    line = hv.Curve([start - bbox_ul, end - bbox_ul]).opts(
        style={"color": "white", "alpha": 0.5, "line_width": 1.5},
        plot={"yaxis": None, "shared_axes": True},
    )
    return hv.Overlay([line] if overlay_enabled else [])


thumbnail_overlay = hv.DynamicMap(
    get_thumbnail_overlay, streams=[pointerx, kymo_overlay_stream]
)  # .opts(plot={'invert_yaxis': True})

# Toggle button, then the kymograph with the thumbnail and sharpness plots adjoined (`<<`).
display(kymo_overlay_button)
(
    kymo_img * scrubber_line
    << (trench_thumbnail_img * thumbnail_overlay * cross_section_line)
    << (sharpness_plot * sharpness_scrubber)
)
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)
# trench_thumbnail_img * thumbnail_overlay
# trench_thumbnail_img * thumbnail_overlay * cross_section_line

In [None]:
def trench_peaks_callback(t, y):
    """Kymograph column at timepoint t with its detected peaks marked as points.

    y comes from the shared PointerY stream and is unused. Reads the notebook
    globals kymo and kymo_endpoints (computed by find_kymograph_cell_endpoints
    in the next cell — run it first).
    """
    # Clamp the pointer into the valid column range: negative t would silently
    # wrap to the last columns and t past the end would raise IndexError.
    t = min(max(int(t), 0), kymo.shape[1] - 1)
    column = kymo[:, t]
    overlays = [hv.Curve(column).opts(plot={"width": 500})]
    overlays.extend(
        hv.Points([(x, column[int(x)])]).opts(style={"size": 6})
        for x in kymo_endpoints[t]
    )
    return hv.Overlay(overlays)


hv.DynamicMap(trench_peaks_callback, streams=[pointerx, pointery])

In [None]:
%%time
def find_kymograph_cell_endpoints(kymograph, thresh=0.2, min_dist=3):
    """Detect peak indices in every column of a 2-D kymograph.

    Args:
        kymograph: 2-D array; each column kymograph[:, t] is searched independently.
        thresh: passed to peakutils.indexes as ``thres``.
        min_dist: minimum spacing between peaks, passed to peakutils.indexes.

    Returns:
        A list with one array of integer peak indices per column.
    """
    # Sub-pixel alternative:
    # xs = peakutils.interpolate(np.arange(kymograph.shape[0]), kymograph[:,t], ind=idxs)
    return [
        peakutils.indexes(kymograph[:, column], thres=thresh, min_dist=min_dist)
        for column in range(kymograph.shape[1])
    ]


# Peak indices for every kymograph column (min_dist=5, stricter than the function default of 3).
kymo_endpoints = find_kymograph_cell_endpoints(kymo, thresh=0.2, min_dist=5)

In [None]:
# Overlay column profiles (rows 220 onward) for the first fifth of the kymograph's columns.
for t in range(kymo.shape[1] // 5):
    plt.plot(kymo[220:, t])

# Image viewers

In [None]:
# big_image_viewer(root_group['quantized'])

In [None]:
def overlay_image_viewer(frames):
    """Display an interactive multichannel composite viewer for ``frames``.

    ``frames`` is indexed frames[v, c, t, :, :] (position, channel, timepoint,
    then the two image axes). Channel selectors and frame sliders are shown
    above a regridded RGB composite of the enabled channels. Returns None.
    """
    frame_stream = FrameStream()
    slider_box = frame_browser(frames, frame_stream)
    channels_box, display_settings_stream = multichannel_selector(frames)

    def image_callback(t, v, channel_enabled, channel_colors):
        # Composite only the channels whose checkbox is enabled.
        enabled_channels = [
            c for c in range(frames.shape[1]) if channel_enabled[c]
        ]
        channel_imgs = [frames[v, c, t, :, :] for c in enabled_channels]
        img = composite_channels(channel_imgs, select(channel_colors, channel_enabled))
        height, width = img.shape[0], img.shape[1]
        # Rows are flipped so row 0 ends up at the top of the plot.
        return hv.RGB(img[::-1], bounds=(0, 0, width, height))

    dynamic_image = hv.DynamicMap(
        image_callback, streams=[frame_stream, display_settings_stream]
    )
    image = regrid(dynamic_image).opts(plot={"width": 500, "height": 500})
    output = widgets.Output()
    controls = widgets.HBox([channels_box, slider_box])
    display(widgets.VBox([controls, output]))
    with output:
        display(image)


# NOTE(review): frames_z is not assigned in any visible cell; presumably restored by
# `%store -r` or a star import — confirm.
overlay_image_viewer(frames_z)

# Old

In [None]:
# %%opts Image (cmap='viridis')
img_series = frame_series
kymo = extract_kymograph(
    img_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
kymo_img = hv.Image(kymo, bounds=(0, 0, kymo.shape[1], kymo.shape[0])).opts(
    plot={"width": 700, "height": 300}
)
trench_thumbnail_img = hv.DynamicMap(
    partial(trench_thumbnail, img_series), streams=[pointerx]
)
(
    kymo_img.opts(plot={"invert_axes": True}) * scrubber_line
    + trench_thumbnail_img * cross_section_line
)

In [None]:
# handle out-of-bounds/negative t/x values
# focus quality score
# show focus quality as a bar above kymograph
# accurate horizontal line position on thumbnail
# draw cross-section line through thumbnail
# synchronize zoom/pan of multiple kymograph viewers
# 3-up thumbnail viewer with crosshairs/endpoint correspondences (links) synchronized with kymograph viewer
# clicking through a track advances 3-up by one frame
# if you press END it finishes track and rewinds to the first unfinished track
# automatic tracking: always identify bottom-most endpoint?
# can we work out all other correspondences from this? (look at t=14)