# Imports

In [None]:
import numpy as np
import pandas as pd
import zarr
import xarray as xr
import dask
from dask import delayed
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
import streamz
import holoviews as hv
import datashader as ds
from holoviews.operation.datashader import (
    aggregate,
    datashade,
    dynspread,
    shade,
    regrid,
)
from holoviews.operation import decimate
from holoviews.streams import Stream, param
from bokeh.models.tools import HoverTool
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import qgrid
from IPython.display import display, clear_output
import skimage
import skimage.morphology
import skimage.feature
import peakutils
import scipy.stats
import scipy.interpolate
import functools
import operator
from operator import getitem
from functools import partial, reduce
import itertools
from itertools import zip_longest, islice
from more_itertools import rstrip
from collections import Counter, defaultdict
from collections.abc import Mapping, Sequence
from bokeh.models import WheelZoomTool
from bokeh.io import push_notebook, show, output_notebook

# from bokeh.layouts import row
# from bokeh.plotting import figure
from tqdm import tnrange, tqdm, tqdm_notebook
from copy import copy, deepcopy
import random
import warnings

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from processing import *
from trench_detection import *
from trench_segmentation import *
from trench_segmentation.watershed import *
from util import *
from ui import *

In [None]:
%load_ext line_profiler
hv.notebook_extension("bokeh")
# renderer = hv.renderer('bokeh')
%matplotlib inline
tqdm.monitor_interval = 0

# Config

In [None]:
# warnings.filterwarnings('ignore', 'Conversion of the second argument of issubdtype')
warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings('default', category=FutureWarning)

In [None]:
# store = zarr.DirectoryStore('/home/jqs1/scratch/fidelity/171018/171018.zarr')
# store = zarr.DirectoryStore('/home/jqs1/scratch/fidelity/171214/transcriptionerror_timelapse.zarr')
store = zarr.DirectoryStore(
    "/home/jqs1/scratch/fidelity/180227/txnerr_timelapse_bf.zarr"
)
# store = zarr.DirectoryStore('/home/jqs1/scratch/fidelity/180226/txnerr_timelapse.zarr')
root_group = zarr.open_group(store=store)

In [None]:
cluster = SLURMCluster(
    queue="short", walltime="3:00:00", threads=1, processes=1, local_directory="/tmp"
)
# interface='ib0'
client = Client(cluster)

In [None]:
cluster.start_workers(10)

In [None]:
%store -r trench_points_pos
%store -r trench_sharpness
%store -r trench_mean
%store -r trench_max

# Parallel trench detection

In [None]:
root_group.tree()

In [None]:
root_group["raw"]["0"][1, 30, 1].shape

In [None]:
getitem(root_group["raw"]["0"], (1, 30, 1)).shape

In [None]:
# Build one lazy (dask.delayed) trench-detection task per position in the
# "raw" group. Each task loads the frame at index (1, 30) lazily via
# delayed(getitem) and passes it to get_trenches.
# The original line was missing the parenthesis that closes the outer
# delayed(get_trenches)(...) call and so did not parse.
# NOTE(review): keys here are the group's own key names, whereas the serial
# detection loop in this notebook keys trench_points_pos by int -- confirm
# which convention downstream cells expect.
trench_points_pos = {
    pos: delayed(get_trenches)(delayed(getitem)(root_group["raw"][pos], (1, 30)))
    for pos in root_group["raw"].keys()
}

In [None]:
trench_points_pos, diag_pos = map_with_diagnostics(
    get_trenches,
)

# Trench detection

In [None]:
diag_pos = tree()
trench_points_pos = {}

In [None]:
# Run trench detection serially for positions 0..99 with a notebook progress bar.
# For each position the frame at index [1, 30] is read from the "raw" group
# (the group is indexed with the position as a string) and per-position
# diagnostics are written into diag_pos[pos].
for pos in tnrange(100):
    # NOTE(review): fail_silently presumably swallows exceptions raised by the
    # callable and returns a falsy value instead -- confirm in util.
    res = fail_silently(
        lambda: get_trenches(
            root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos]
        )
    )
    # Only truthy results are stored; positions with a falsy result get no entry.
    if res:
        trench_points_pos[pos] = res

In [None]:
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
display_plot_browser(diag)

In [None]:
diag_pos[0]["label_1"]["trench_ends"].keys()

In [None]:
diag_trenches = {}
for pos in diag_pos.keys():
    res = recursive_getattr(
        diag_pos, [pos, "label_1", "trench_ends", "image_with_trenches"]
    )
    if res is not None:
        diag_trenches[pos] = {"trenches": res}

In [None]:
frame_stream = FrameStream()
display_plot_browser(
    lambda t, v: diag_trenches[v] if v in diag_trenches else None, frame_stream
)

In [None]:
pos = 0
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
img = root_group["raw"]["0"][1, 30]

In [None]:
display_plot_browser(diag_pos[0])

In [None]:
df = diagnostics_to_dataframe(diag_pos)

In [None]:
df2 = expand_diagnostics_by_label(df)

In [None]:
df3 = drop_constant_columns(df2)

In [None]:
qg = qshow(df3)

In [None]:
qg

In [None]:
from holoviews.streams import Pipe
from bokeh.models import HoverTool

In [None]:
df3

In [None]:
hv.Points(df3.reset_index(), kdims=[f1, f2], vdims=["pos", "label"])

In [None]:
f1 = "trench_rotation.hough_2.angle"
f2 = "trench_anchors.periodogram_2.period"

hover = HoverTool(
    tooltips=[
        ("(pos,label)", "(@pos, @label)"),
        (f1, "@{{{}}}".format(f1)),
        (f2, "@{{{}}}".format(f2)),
    ]
)

hv.Points(df3.reset_index(), kdims=[f1, f2], vdims=["pos", "label"]).opts(
    plot={"size": 50, "tools": [hover]}, style={"size": 5}
)

In [None]:
def plot_callback(data):
    """Render the scatter of the two selected diagnostics for `data`.

    The original cell contained an empty `def` stub and referenced an
    undefined `plot_callback`; this body mirrors the static hv.Points plot
    used elsewhere in this notebook (same kdims/vdims, read from the
    notebook-level names f1 and f2).
    """
    return hv.Points(data.reset_index(), kdims=[f1, f2], vdims=["pos", "label"])


# Seed the pipe with df3, the frame that the qgrid widget `qg` was built from.
# NOTE(review): the original seeded the pipe with `df`; df3 is used here so the
# columns match what plot_callback reads -- confirm this is the intended frame.
qgrid_stream = Pipe(data=df3)
dmap = hv.DynamicMap(plot_callback, streams=[qgrid_stream])
# `names` is an argument of observe(), not of send(); the original had the
# closing parenthesis in the wrong place (and was unbalanced). The observer
# receives a traitlets change dict, so forward its "new" value to the pipe.
# NOTE(review): assumes the widget's `_df` keeps the (pos, label) index that
# plot_callback resets -- confirm against qgrid's behaviour.
qg.observe(lambda change: qgrid_stream.send(change["new"]), names=["_df"])

# Trench detection debugging

In [None]:
f1 = "trench_rotation.hough_2.angle"
f2 = "trench_anchors.periodogram_2.period"

In [None]:
good_detection = (df3[f1] > -1) & (df3[f2] > 20)
qshow(df3[~good_detection])

In [None]:
hover = HoverTool(
    tooltips=[
        ("(pos,label)", "(@pos, @label)"),
        (f1, "@{{{}}}".format(f1)),
        (f2, "@{{{}}}".format(f2)),
    ]
)

hv.Points(df3.reset_index(), kdims=[f1, f2], vdims=["pos", "label"]).opts(
    plot={"size": 50, "tools": [hover]}, style={"size": 5}
)

In [None]:
%%time
diag = tree()
get_trenches(root_group["raw"][str(7)][0, 30], diagnostics=diag)

In [None]:
display_plot_browser(diag)

In [None]:
frame_stream = FrameStream()
display(frame_browser(root_group["quantized"], frame_stream))
display_plot_browser(lambda t, v: diag_pos[v] if v in diag_pos else None, frame_stream)

In [None]:
big_image_viewer(root_group["raw"])

In [None]:
import cloudpickle

In [None]:
import resource

resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

In [None]:
resource?

In [None]:
resource.getrusage(resource.RUSAGE_SELF)

In [None]:
import objgraph

In [None]:
objgraph.most_common_types()

In [None]:
len(b)

In [None]:
import gc

gc.collect()

In [None]:
from pympler import asizeof

In [None]:
diag_pos[0]["label_1"]["trench_ends"].keys()

In [None]:
c = cloudpickle.dumps(diag_pos[0]["label_1"]["trench_ends"])

In [None]:
with open("foo", "wb") as f:
    cloudpickle.dump(diag_pos, f)

In [None]:
len(c)

In [None]:
print(asizeof.asized(diag_pos[0], detail=5).format())

In [None]:
%store -r

In [None]:
%store trench_points_pos
%store diag_pos

In [None]:
def visualize_trench(img1, img2, img3, traces=None):
    """Lay out three images side by side, optionally followed by 1-D traces.

    Parameters
    ----------
    img1, img2, img3 : 2-D arrays
        Images to display; each is wrapped in RevImage. The plot size is
        derived from ``img1.shape``.
    traces : iterable of 1-D arrays, optional
        When given, each trace is divided by its maximum absolute value and
        all traces are drawn as curves in one extra overlay panel.

    Returns
    -------
    The composed HoloViews layout with plot options applied.
    """
    # Size the panels from the image actually being shown. The original read
    # the notebook-level `thumbs` array here, which tied this function to
    # whatever `thumbs` happened to hold when it was called.
    height, width = img1.shape[0], img1.shape[1]
    plot = RevImage(img1) + RevImage(img2) + RevImage(img3)
    if traces is not None:
        plot += hv.Overlay.from_values(
            [
                # Normalise each trace to unit peak magnitude so traces with
                # different scales share one axis.
                hv.Curve(trace / abs(trace).max(), kdims=["y"], vdims=["i" + str(i)])
                for i, trace in enumerate(traces)
            ]
        )
    plot = plot.opts(
        plot={
            "Image": {"height": height * 2, "width": width * 5},
            "Curve": {
                "invert_axes": True,
                "height": height * 2,
                "width": width * 5,
                "invert_yaxis": True,
            },
            "Layout": {"normalize": False},
        }
    )
    return plot


# visualize_trench(a)#, [a.sum(axis=1)])


# visualize_trench(a)#, [a.sum(axis=1)])

In [None]:
def trench_side_masks(mask, dilation=3):
    """Split the dilated outline of `mask` into a left-side and a right-side mask.

    Returns a tuple ``(left_side, right_side)``: the dilated pixels lying
    strictly left of the first True pixel of each row of `mask`, and those
    lying strictly right of the last True pixel of each row.
    """
    # True from the first True pixel of each row onwards (scanning left to right).
    seen_from_left = np.logical_or.accumulate(mask, axis=1)
    # True up to the last True pixel of each row (scan the mirrored rows, mirror back).
    seen_from_right = np.fliplr(np.logical_or.accumulate(np.fliplr(mask), axis=1))
    # NOTE(review): repeat_apply presumably composes binary_dilation with
    # itself `dilation` times -- confirm in util.
    dilate = repeat_apply(skimage.morphology.binary_dilation, dilation)
    grown = dilate(mask)
    left_side = grown & ~seen_from_left
    right_side = grown & ~seen_from_right
    return (left_side, right_side)

In [None]:
pos = str(0)
n_channels, n_timepoints = root_group["raw"][pos].shape[:2]
channel = 0

In [None]:
diag_td = tree()
trench_points = get_trenches(root_group["raw"][pos][0, 30], diagnostics=diag_td)

In [None]:
# display_plot_browser(diag_td);

In [None]:
sharpness = [
    image_sharpness(root_group["raw"][pos][0, t]) for t in tnrange(n_timepoints)
]

In [None]:
trench_idx = 18
trench_set_idx = 1

In [None]:
x_lim, y_lim = get_img_limits(root_group["raw"][pos].shape[2:])
ul, lr = get_trench_bbox(trench_points[trench_set_idx], trench_idx, x_lim, y_lim)
thumbs = root_group["raw"][pos][0, :, ul[1] : lr[1], ul[0] : lr[0]]

In [None]:
# The notebook's import cell only imports skimage, skimage.morphology and
# skimage.feature; import the filters submodule explicitly so that
# threshold_otsu does not depend on another package importing it first.
import skimage.filters

# Per-frame Hessian eigenvalue images (first and second value returned by
# hessian_eigenvalues) and per-frame Otsu foreground masks.
# NOTE(review): map_ndarray presumably applies the function to each frame along
# the first axis of `thumbs` -- confirm in util.
thumbs_k1 = map_ndarray(lambda img: hessian_eigenvalues(img)[0], thumbs)
thumbs_k2 = map_ndarray(lambda img: hessian_eigenvalues(img)[1], thumbs)
thumbs_masks = map_ndarray(
    lambda img: img > skimage.filters.threshold_otsu(img), thumbs
)

In [None]:
# Per-timepoint row-wise min/max of the k2 image restricted to the left and
# right side masks of the thresholded trench.
traces_min_left, traces_min_right = [], []
traces_max_left, traces_max_right = [], []
for t in tnrange(thumbs.shape[0]):
    left_mask, right_mask = trench_side_masks(thumbs_masks[t])
    # Mask once per side, then reduce each masked image along the columns.
    left_k2 = left_mask * thumbs_k2[t]
    right_k2 = right_mask * thumbs_k2[t]
    traces_min_left.append(left_k2.min(axis=1))
    traces_min_right.append(right_k2.min(axis=1))
    traces_max_left.append(left_k2.max(axis=1))
    traces_max_right.append(right_k2.max(axis=1))

In [None]:
t_max = 128
hv.HoloMap(
    {
        t: visualize_trench(
            thumbs[t],
            trench_side_masks(thumbs_masks[t])[0] * thumbs_k2[t],
            thumbs_k2[t],
            [
                traces_min_left[t],
                traces_min_right[t],
                traces_max_left[t],
                traces_max_right[t],
            ],
        )
        for t in range(t_max)
    }
)

## Prototyping

In [None]:
# a = thumbs_k1[0]
a = thumbs[61]  # 31, 36, 61
a_k1, a_k2 = hessian_eigenvalues(a)

In [None]:
a.shape

In [None]:
threshold = skimage.filters.threshold_otsu(a)
a_thresh = a > threshold
plt.imshow(a_thresh)

In [None]:
plt.plot(a.sum(axis=1))

In [None]:
a_k1d = np.diff(a_k1, axis=0)

In [None]:
visualize_trench(
    a, a_k1, a_k1d, [a_k1[:, 25], a.sum(axis=1), a_k1d[:, 20:30].sum(axis=1)]
)

In [None]:
left_mask, right_mask = trench_side_masks(a_thresh)
left_trace_min = (left_mask * a_k2).min(axis=1)
right_trace_min = (right_mask * a_k2).min(axis=1)
left_trace_max = (left_mask * a_k2).max(axis=1)
right_trace_max = (right_mask * a_k2).max(axis=1)

In [None]:
visualize_trench(
    a,
    trench_side_masks(a_thresh)[0] * a_k2,
    a_k2,
    [left_trace_min, right_trace_min, left_trace_max, right_trace_max],
)

In [None]:
visualize_trench(a, a_k2, a_k1d, [a_k2[:, 22], a_k2[:, 25], a_k2[:, 28]])

In [None]:
plt.figure(figsize=(12, 12))
for i in range(20, 31):
    plt.plot(a_k2[:, i], label=str(i))
plt.plot(a_k2[:, :25].max(axis=1), label="max r", linestyle="--", lw=2)
plt.plot(a_k2[:, 25:].max(axis=1), label="max l", linestyle="--", lw=2)
plt.legend()

In [None]:
plt.plot(a_thresh.sum(axis=1))

In [None]:
plt.plot(a[135])

In [None]:
plt.plot(a[150])

In [None]:
# The notebook's import cell imports scipy.stats and scipy.interpolate but not
# scipy.signal; import it explicitly so this cell does not depend on another
# package having loaded it.
import scipy.signal

# 2-D autocorrelation of the thumbnail, cropped to the input size ("same").
plt.imshow(scipy.signal.correlate(a, a, "same"))
# plt.plot(np.correlate(, axis=0))

In [None]:
trench_points[trench_set_idx][0].shape

In [None]:
x0, x1 = (
    trench_points[trench_set_idx][0][trench_idx],
    trench_points[trench_set_idx][1][trench_idx],
)
xs, ys = coords_along(x0, x1)
profiles = np.array([thumbs_k1[t, ys - ul[1], xs - ul[0]] for t in range(n_timepoints)])

In [None]:
RevImage(thumbs[0]) * hv.Points([x0 - ul, x1 - ul]).opts(
    plot={"size": 30, "color": "green"}
)

In [None]:
hv.help(hv.Curve)

In [None]:
%%opts Layout [normalize=False]
t_max = 2


def plot_trench(t):
    p = (
        RevImage(thumbs[t])
        + RevImage(thumbs_k1[t])
        + hv.Curve(profiles[t], kdims=["y"], vdims=["i"]).opts(plot={"swap_axes": True})
    ).opts(
        plot={
            "Image": {
                "height": thumbs.shape[1] * 2,
                "width": thumbs.shape[2] * 3,
                "aspect": 0.2,
            }
        }
    )
    return p


# m = hv.HoloMap({t: plot_trench(t) for t in range(t_max)})
# m
plot_trench(0)

## Segmentation algorithm testing

In [None]:
from skimage.segmentation import morphological_geodesic_active_contour

In [None]:
morphological_geodesic_active_contour?

In [None]:
f = thumbs[0]
image = f / f.max()

# Initial level set
init_ls = np.zeros(image.shape, dtype=np.int8)
init_ls[10:-10, 10:-10] = 1
ls = morphological_geodesic_active_contour(
    image, 600, smoothing=2, threshold="auto", init_level_set=init_ls
)
plt.imshow(ls)
# List with intermediate results for plotting the evolution
# evolution = []
# callback = store_evolution_in(evolution)
# ls = morphological_geodesic_active_contour(gimage, 230, init_ls,
#                                           smoothing=1, balloon=-1,
#                                           threshold=0.69,
#                                           iter_callback=callback)

In [None]:
morphological_geodesic_active_contour?

## Kymo viewer

In [None]:
%%opts Image (cmap='viridis')
def scrubber_callback(t, y):
    """Build the kymograph cursor: a vertical line at t, a "t=" label and a horizontal line at y."""
    frame = int(t)
    n_rows = kymo.shape[0]
    n_cols = kymo.shape[1]
    # Vertical marker is drawn at frame + 0.5.
    vertical = hv.VLine(frame + 0.5)(
        style={"color": "white", "alpha": 0.3, "line_width": 2}
    )
    # Label is offset from the cursor by 1/20 of the kymograph extent.
    label_text = "t={:}".format(frame)
    label = hv.Text(frame + n_cols / 20, n_rows / 20, label_text)(
        style={"color": "white", "text_alpha": 0.8}
    )
    horizontal = hv.HLine(y)(
        style={"color": "white", "alpha": 0.3, "line_width": 2}
    )
    return vertical * label * horizontal


def cross_section_callback(t, y):
    """Return a translucent white horizontal line at integer row y (t is accepted but unused)."""
    row = int(y)
    line_style = {"color": "white", "alpha": 0.3, "line_width": 2}
    return hv.HLine(row)(style=line_style)


def sharpness_scrubber_callback(t):
    """Return a translucent black vertical line at int(t) + 0.5."""
    frame = int(t)
    line_style = {"color": "black", "alpha": 0.3, "line_width": 2}
    return hv.VLine(frame + 0.5)(style=line_style)


def trench_thumbnail_callback(img_series, t):
    """Return the trench thumbnail of frame int(t) as an hv.Image.

    The thumbnail is cut from ``img_series[int(t)]`` with get_trench_thumbnail,
    using the notebook-level ``trench_points`` and ``trench_idx``.
    """
    frame = img_series[int(t)]
    crop = get_trench_thumbnail(frame, trench_points, trench_idx)
    n_rows, n_cols = crop.shape[0], crop.shape[1]
    # Rows are reversed before display and the y axis is inverted below.
    # TODO: hv.Image.opts(plot={'invert_axes': True}) reportedly does not work
    # here, so axes are left un-swapped.
    image = hv.Image(
        crop[::-1],
        bounds=(0, 0, n_cols, n_rows),
        kdims=["x2", "y"],
    )
    plot_opts = {"invert_axes": False, "invert_yaxis": True, "width": 200}
    return image.opts(plot=plot_opts)


pointerx = hv.streams.PointerX(x=0).rename(x="t")
pointery = hv.streams.PointerY(y=0)

# BOUNDS: (left, bottom, right, top)
trench_idx = 15
img_series = frame_series_k1
bbox_ul, bbox_lr = get_trench_bbox(
    trench_points, trench_idx, *get_img_limits(img_series[0])
)
kymo = extract_kymograph(
    img_series, trench_points[0][trench_idx], trench_points[1][trench_idx]
)
kymo_img = hv.Image(kymo[::-1], bounds=(0, 0, kymo.shape[1], kymo.shape[0])).opts(
    plot={"width": 700, "height": 300, "invert_yaxis": True}
)
# kymo_img = hv.Raster(kymo).opts(plot={'width': 700, 'height': 300, 'yaxis': 'left'})
scrubber_line = hv.DynamicMap(scrubber_callback, streams=[pointerx, pointery])
cross_section_line = hv.DynamicMap(cross_section_callback, streams=[pointerx, pointery])
trench_thumbnail_img = hv.DynamicMap(
    partial(trench_thumbnail_callback, img_series), streams=[pointerx]
)  # .opts(plot={'invert_axes': True})
sharpness_plot = hv.Curve(sharpness, kdims=["x"], vdims=["s"]).opts(
    plot={"width": 700, "height": 100}
)
sharpness_scrubber = hv.DynamicMap(sharpness_scrubber_callback, streams=[pointerx])
# pointery.source = kymo_img
# SEE: http://holoviews.org/reference/elements/bokeh/Distribution.html
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)

KymoOverlayStream = Stream.define("KymoOverlayStream", overlay_enabled=True)
kymo_overlay_stream = KymoOverlayStream()
kymo_overlay_button = widgets.ToggleButton(description="Overlay", value=True)


def update_kymo_overlay(change):
    """Widget observer: push the toggle button's new value into the overlay stream."""
    enabled = change["new"]
    kymo_overlay_stream.event(overlay_enabled=enabled)


kymo_overlay_button.observe(update_kymo_overlay, names="value")


def get_thumbnail_overlay(t, overlay_enabled):
    """Return an overlay with the trench line when enabled, else an empty overlay.

    The line joins the two endpoints of trench ``trench_idx`` shifted by the
    bounding-box corner ``bbox_ul``. `t` is accepted but unused.
    """
    # because of https://github.com/ioam/holoviews/issues/1388, overlay_enabled must be True initially
    start = trench_points[0][trench_idx]
    end = trench_points[1][trench_idx]
    line_style = {"color": "white", "alpha": 0.5, "line_width": 1.5}
    line_plot = {"yaxis": None, "shared_axes": True}
    trench_line = hv.Curve([start - bbox_ul, end - bbox_ul]).opts(
        style=line_style,
        plot=line_plot,
    )
    return hv.Overlay([trench_line] if overlay_enabled else [])


thumbnail_overlay = hv.DynamicMap(
    get_thumbnail_overlay, streams=[pointerx, kymo_overlay_stream]
)  # .opts(plot={'invert_yaxis': True})

display(kymo_overlay_button)
(
    kymo_img * scrubber_line
    << (trench_thumbnail_img * thumbnail_overlay * cross_section_line)
    << (sharpness_plot * sharpness_scrubber)
)
# kymo_img * scrubber_line << (trench_thumbnail_img * cross_section_line) << (sharpness_plot * sharpness_scrubber)
# trench_thumbnail_img * thumbnail_overlay
# trench_thumbnail_img * thumbnail_overlay * cross_section_line

In [None]:
def trench_peaks_callback(t, y):
    """Plot kymograph column int(t) with one point marker per entry of kymo_endpoints at that frame.

    `y` is accepted but unused.
    """
    frame = int(t)
    profile = hv.Curve(kymo[:, frame]).opts(plot={"width": 500})
    markers = [
        hv.Points([(x, kymo[int(x), frame])]).opts(style={"size": 6})
        for x in kymo_endpoints[frame]
    ]
    return hv.Overlay([profile, *markers])


hv.DynamicMap(trench_peaks_callback, streams=[pointerx, pointery])

In [None]:
%%time
def find_kymograph_cell_endpoints(kymograph, thresh=0.2, min_dist=3):
    """Return, for each column of `kymograph`, the peak indices found by peakutils.indexes.

    `thresh` and `min_dist` are passed through as peakutils' ``thres`` and
    ``min_dist``. The result is a list with one index array per column.
    """
    n_columns = kymograph.shape[1]
    return [
        peakutils.indexes(kymograph[:, col], thres=thresh, min_dist=min_dist)
        for col in range(n_columns)
    ]


kymo_endpoints = find_kymograph_cell_endpoints(kymo, thresh=0.2, min_dist=5)

In [None]:
for t in range(kymo.shape[1] // 5):
    plt.plot(kymo[220:, t])

# Trench sharpness

In [None]:
def f(img_stack):
    """Return a Series of image_sharpness values, one per frame along the first axis."""
    n_frames = img_stack.shape[0]
    scores = [image_sharpness(img_stack[idx]) for idx in range(n_frames)]
    return pd.Series(scores)


trench_sharpness = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(10),
)

In [None]:
def f(img_stack):
    """Return a Series of mean values taken over axes 1 and 2, one per entry along axis 0."""
    frame_means = img_stack.mean(axis=(1, 2))
    return pd.Series(frame_means)


trench_mean = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(10),
)


def f(img_stack):
    """Return a Series of maximum values taken over axes 1 and 2, one per entry along axis 0."""
    frame_maxima = img_stack.max(axis=(1, 2))
    return pd.Series(frame_maxima)


trench_max = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(10),
)

In [None]:
%%output size=250
# Overlay up to 50 randomly chosen columns of sharpness divided by max intensity.
dat = trench_sharpness / trench_max
hover = HoverTool(
    tooltips=[
        ("(t,f)", "(@time, @value)"),
        ("label", "@label"),
    ]
)
# DataFrame.iteritems() was removed in pandas 2.0; items() is the equivalent
# that exists in both old and new pandas.
columns = list(dat.items())
# random.sample raises ValueError when asked for more items than exist, so cap
# the sample size at the number of available columns.
sampled = random.sample(columns, min(50, len(columns)))
hv.Overlay.from_values(
    [
        hv.Curve(
            pd.DataFrame(
                # Each column key is unpacked into three fields to build the label.
                {"time": d.index, "value": d.values, "label": "{}.{}.{}".format(*col)}
            ),
            vdims=["value", "label"],
            kdims=["time"],
        ).opts(plot={"tools": [hover]})
        for col, d in sampled
    ]
)

In [None]:
pos = 7
show_trench_movie(
    root_group["raw"][pos],
    trench_points_pos[pos],
    2,
    54,
    channels=[1],
    time_slice=slice(None, None, 3),
)

# Trench segmentation

## Debugging

In [None]:
# pos, trench_set_idx, trench_idx = 6, 2, 54
pos, trench_set_idx, trench_idx = 4, 1, 41
trench_points = trench_points_pos[pos]

In [None]:
x_lim, y_lim = get_img_limits(root_group["raw"][pos].shape[2:])
ul, lr = get_trench_bbox(trench_points[trench_set_idx], trench_idx, x_lim, y_lim)
thumbs = root_group["raw"][pos][1, :, ul[1] : lr[1], ul[0] : lr[0]]

In [None]:
def optimize_segmentation_threshold(segmentations, diagnostics=None):
    """Score each candidate segmentation by the spread of its region areas.

    Parameters
    ----------
    segmentations : mapping
        Maps a threshold value to a label image accepted by
        ``skimage.measure.regionprops``.
    diagnostics : mapping, optional
        When given, receives a "segmentations" HoloMap of the label images
        and a "metric" curve of the score against threshold.

    Returns
    -------
    int
        Always 0.
        NOTE(review): the scores are computed but no threshold is selected yet;
        this looks like an unfinished stub -- confirm the intended criterion.
    """
    # The notebook's import cell does not import skimage.measure; import it
    # here so the function does not rely on another module having loaded it.
    import skimage.measure

    thresholds = []
    metrics = []
    for threshold, segmentation in segmentations.items():
        regionprops = skimage.measure.regionprops(segmentation)
        areas = [region.area for region in regionprops]
        # np.std of an empty list yields nan with a RuntimeWarning; produce the
        # same nan explicitly without the warning.
        metric = np.std(areas) if areas else np.nan
        thresholds.append(threshold)
        metrics.append(metric)
    if diagnostics is not None:
        diagnostics["segmentations"] = hv.HoloMap(
            {t: RevImage(seg) for t, seg in segmentations.items()}
        )
        diagnostics["metric"] = hv.Curve((thresholds, metrics))
    return 0

In [None]:
%%time
diag_seg_pos = tree()
segs = [
    segment_trench(thumbs[t], diagnostics=diag_seg_pos[t])
    for t in range(thumbs.shape[0])
]

In [None]:
t = 71
segment_trench(thumbs[t], diagnostics=diag_seg_pos[t])

In [None]:
hv.HoloMap({t: RevImage(seg.T) for t, seg in enumerate(segs)})

In [None]:
display_plot_browser(diag_seg_pos[1])

## Full analysis

In [None]:
def f(img_stack):
    """Segment every frame of `img_stack` and return the stacked results as a compressed zarr array."""
    frame_masks = []
    for idx in range(img_stack.shape[0]):
        frame_masks.append(segment_trench(img_stack[idx], diagnostics=None))
    # Stack along a new leading axis, one entry per frame.
    stacked = np.stack(frame_masks, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

In [None]:
%store trench_seg_masks

In [None]:
from trench_segmentation.watershed import _trench_img, segment_trench

In [None]:
seg_masks = trench_seg_masks[0][1][30]
hv.HoloMap(
    {
        t: _trench_img(seg_masks[t]).options(height=150, width=600)
        for t in range(seg_masks.shape[0])
    }
)

# Trench all-up kymograph

In [None]:
res = np.concatenate(thumbs, axis=1)
print(thumbs.shape, res.shape)
plt.imshow(res)

In [None]:
def f(img_stack):
    """Join the entries along the first axis of `img_stack` side by side (concatenate along axis 1)."""
    # Compression into a zarr array is intentionally disabled here:
    # zarr.array(ary, compressor=DEFAULT_FRAME_COMPRESSOR)
    return np.concatenate(img_stack, axis=1)


trench_kymographs_rfp = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)
trench_kymographs_yfp = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

# Trench fluorescence

In [None]:
def f(img_stack):
    """Placeholder reducer: ignores `img_stack` and returns a one-element Series holding 0."""
    # Disabled alternative: pd.Series(np.percentile(img_stack, 90, axis=(1,2)))
    placeholder = [0]
    return pd.Series(placeholder)


pos = 0
trench_traces2 = trenchwise_map(
    root_group["raw"][pos],
    trench_points_pos[pos],
    f,
    channel_slice=2,
    time_slice=slice(50),
)

In [None]:
# FROM: http://holoviews.org/reference/elements/bokeh/QuadMesh.html
def hover_image(img, hover):
    """Overlay `img` with an invisible QuadMesh carrying `hover`, so cells highlight on hover."""
    mesh_plot = {"tools": [hover]}
    # Fully transparent fill; only the hovered cell's outline is drawn.
    mesh_style = {"alpha": 0, "hover_line_alpha": 1, "hover_line_color": "black"}
    mesh = hv.QuadMesh(img).opts(plot=mesh_plot, style=mesh_style)
    return img * mesh


def show_trench_movie(
    img_stack,
    trench_points,
    trench_set_idx,
    trench_idx,
    channels=None,
    time_slice=slice(None, None, None),
):
    """Build a HoloMap over time of one trench, with one hoverable image per channel.

    The movie builder is handed to trenchwise_apply together with the trench
    set ``trench_points[trench_set_idx]``, ``trench_idx`` and ``time_slice``.
    When `channels` is None, every index along the first axis of `img_stack`
    is shown.
    """
    if channels is None:
        channels = range(img_stack.shape[0])
    hover = HoverTool(tooltips=[("(x,y)", "(@x{0.0}, @y{0.0})"), ("intensity", "@z")])

    def build_movie(stack):
        # `stack` is indexed as [channel, t] below; its last two axes give the frame size.
        height, width = stack.shape[2:]
        frames = {}
        for t in range(stack.shape[1]):
            panels = [
                hover_image(
                    RevImage(
                        stack[channel, t],
                        kdims=["x", "y"],
                        vdims=["intensity"],
                    ),
                    hover,
                )
                for channel in channels
            ]
            frames[t] = hv.Layout.from_values(panels).opts(
                plot={
                    "Image": {"height": height * 2, "width": width * 5},
                    "Layout": {"normalize": False},
                }
            )
        return hv.HoloMap(frames)

    return trenchwise_apply(
        img_stack,
        trench_points[trench_set_idx],
        trench_idx,
        build_movie,
        time_slice=time_slice,
    )

In [None]:
def f(img_stack):
    """Return a Series of 95th-percentile values taken over axes 1 and 2, one per entry along axis 0."""
    # Disabled alternative: pd.Series(np.max(img_stack, axis=(1,2)))
    bright_level = np.percentile(img_stack, 95, axis=(1, 2))
    return pd.Series(bright_level)


trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(100),
)

In [None]:
%%output size=300
# Overlay the first 100 columns of the selected trace table as hoverable curves.
# NOTE(review): bright_traces is not defined in the visible part of this
# notebook -- it must come from another cell or a stored variable.
traces = bright_traces
hover = HoverTool(
    tooltips=[
        ("(t,f)", "(@time, @fluorescence)"),
        ("label", "@label"),
    ]
)
hv.Overlay.from_values(
    [
        hv.Curve(
            pd.DataFrame(
                {
                    "time": d.index,
                    "fluorescence": d.values,
                    # Each column key is unpacked into three fields to build the label.
                    "label": "{}.{}.{}".format(*col),
                }
            ),
            vdims=["fluorescence", "label"],
            kdims=["time"],
        ).opts(plot={"tools": [hover]})
        # DataFrame.iteritems() was removed in pandas 2.0; items() is the
        # equivalent that exists in both old and new pandas.
        for col, d in islice(traces.items(), 100)
    ]
)

In [None]:
pos = 0
show_trench_movie(
    root_group["raw"][pos],
    trench_points_pos[pos],
    2,
    54,
    channels=[0, 1, 2],
    time_slice=slice(None),
)