# Imports

In [None]:
import asyncio
import os
import traceback
import warnings
from collections import defaultdict, namedtuple
from collections.abc import Mapping, Sequence
from functools import partial
from glob import glob
from importlib import reload
from numbers import Number
from operator import getitem

import cachetools
import dask
import distributed
import holoviews as hv
import hvplot.pandas
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import param
import parambokeh
import pyarrow as pa
import pyarrow.feather as feather
import pyarrow.parquet as pq
import qgrid
import scipy
import skimage.morphology
import streamz
import streamz.dataframe as sdf
import zarr
from bokeh.models.tools import HoverTool, TapTool
from cytoolz import *
from dask import delayed
from dask_jobqueue import SLURMCluster
from deepmerge import merge_or_raise
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from holoviews.streams import Selection1D, Stream, param
from IPython.display import Video
from tqdm import tnrange, tqdm, tqdm_notebook
from traitlets import All

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
import common
import data_io
import diagnostics
import geometry
import image
import metadata
import processing
import trench_detection
import trench_detection.core
import trench_detection.hough
import trench_segmentation.watershed
import ui
import util
import workflow

In [None]:
%load_ext line_profiler
%load_ext snakeviz
hv.extension("bokeh", "matplotlib")
%matplotlib inline
tqdm.monitor_interval = 0
asyncio.get_event_loop().set_debug(False)
import logging

logging.basicConfig(level=logging.DEBUG)
import warnings

warnings.simplefilter("ignore")

# Loading data

In [None]:
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2']#, '/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2',
#                 '/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2', '/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC001.nd2', '/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC002.nd2']
# nd2_filenames = glob('/n/scratch2/jqs1/fidelity/all/180405*.nd2') + glob('/n/scratch2/jqs1/fidelity/all/TrErr*.nd2')
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr_loweronly.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr_loweronly_fast.nd2'] + glob('/n/scratch2/jqs1/fidelity/all/TrErr*.nd2')
nd2_filenames = ["/n/scratch2/jqs1/fidelity/all/TrErr.nd2"]
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180928_txnerr_bigsnake.nd2']

In [None]:
# Build the frame index plus raw and parsed metadata for the selected ND2 files.
# NOTE(review): binding the name `metadata` here rebinds the `metadata` module
# imported in the project-imports cell, so that module is no longer reachable
# under this name afterwards.
all_frames, metadata, parsed_metadata = workflow.get_nd2_frame_list(nd2_filenames)
# Looked up per filename later on; the submission loop unpacks
# `image_limits[filename]` as `(x_lim, y_lim)`.
image_limits = workflow.get_filename_image_limits(metadata)

# Config

In [None]:
# Replace the worker memory-management settings in this process's dask config.
# Per the dask.distributed worker-memory documentation these values are
# fractions of the worker memory limit.
# NOTE(review): "spill" and "pause" are set to None, presumably to disable
# those mechanisms; the documented way to disable them is False -- confirm
# that None behaves the same on the installed distributed version.
dask.config.config["distributed"]["worker"]["memory"] = {
    "target": 0.9,
    "spill": None,
    "pause": None,
    "terminate": 0.95,
}

In [None]:
# dask.config.config['distributed']['worker']['profile'] = {'interval': '10s', 'cycle': '10s'}
# {'interval': '10ms', 'cycle': '1000ms'}

In [None]:
# Define the SLURM-backed dask cluster: each job is one single-core worker
# process submitted to the "short" queue with a two-hour walltime.
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # job_extra=['--exclude=compute-e-16-181,compute-e-16-186'],
    # interface='ib0',
    memory="4GB",  # TODO: revisit the per-worker memory request
    local_directory="/tmp",
    cores=1,
    processes=1,
    # diagnostics_port=('127.0.0.1', 8787),
    # Shell lines run in the job script before the worker starts: put the
    # project on PYTHONPATH and preload jemalloc.
    env_extra=[
        'export PYTHONPATH="/home/jqs1/projects/matriarch"',
        #'export PYTHONTRACEMALLOC=25',
        #'export MALLOC_CONF=prof:true,prof_leak:true,lg_prof_interval:31,prof_final:true',
        'export LD_PRELOAD="/home/jqs1/lib/libjemalloc.so.2"',
    ],
)
# Connect a client to the cluster's scheduler; no workers exist until
# `cluster.scale(...)` is called.
client = Client(cluster)  # , direct_to_workers=True)

In [None]:
cluster.scale(40)

In [None]:
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
cluster

In [None]:
cluster.stop_jobs(cluster.running_jobs.keys())

In [None]:
cluster.scheduler.stop_services()
cluster.scheduler.stop()

In [None]:
client.restart()

# Reload

In [None]:
def do_reload():
    """Reload the project's ``workflow`` module in the calling process.

    Everything is imported inside the function so that it is self-contained
    when it is shipped to and executed on remote workers. Only ``workflow``
    is currently reloaded; the other modules are imported so that the
    commented-out ``reload`` calls can be re-enabled without further edits.
    """
    from importlib import reload

    import diagnostics
    import image
    import trench_detection
    import util
    import workflow

    # reload(util)
    # reload(trench_detection.hough)
    # reload(diagnostics)
    reload(workflow)
    # reload(image)


# Reload on every worker first, then in the notebook kernel itself, so both
# sides run the same version of the code.
client.run(do_reload)
do_reload()

# Trench detection

In [None]:
FrameStream = ui.MultiIndexStream.define("FrameStream", all_frames.index)
frame_stream = FrameStream()
box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
%%time
frame = workflow.get_nd2_frame(**frame_stream.contents)
find_trenches_diag = diagnostics.wrap_diagnostics(
    trench_detection.find_trenches, ignore_exceptions=False, pandas=True
)
trench_points, trench_diag, trench_err = find_trenches_diag(frame)

# Data reduction

In [None]:
selected_frames = all_frames.loc[IDX["/n/scratch2/jqs1/fidelity/all/TrErr.nd2", :1], :]

## New trench detection+segmentation+analysis

In [None]:
# Reductions applied to pixel intensities; the keys name the resulting
# measurement.
# NOTE(review): the tuple key presumably names the five values returned by the
# single np.percentile call -- confirm against workflow.map_frame*.
pixelwise_funcs = {
    "mean": np.mean,
    "min": np.min,
    "max": np.max,
    ("p0.3", "p0.5", "p0.7", "p0.9", "p0.95"): partial(
        np.percentile, q=(30, 50, 70, 90, 95)
    ),
}
# Whole-crop measurements: every pixelwise reduction plus a sharpness score.
trenchwise_funcs = {"sharpness": image.sharpness, **pixelwise_funcs}


def measurement_func(label_image, intensity_image):
    """Measure one trench crop against its segmentation labels.

    The label image is used both as-is ("noerode") and with its foreground
    eroded twice ("erode2").

    With ``intensity_image=None`` only the mask is measured and the result is
    ``{"mask_labelwise": df}``, where ``df`` holds per-label pixel counts.
    Otherwise the result is ``{"trenchwise": ..., "labelwise": ...}``, built
    from ``trenchwise_funcs`` over the whole crop and ``pixelwise_funcs`` over
    each labelled region.
    """
    erode_twice = util.repeat_apply(skimage.morphology.binary_erosion, 2)
    # Erode the binary foreground, then restore the label ids on what is left.
    eroded_label_image = erode_twice(label_image != 0) * label_image

    if intensity_image is not None:
        label_variants = {"noerode": label_image, "erode2": eroded_label_image}
        labelwise = {
            variant: workflow.map_frame_over_labels(
                pixelwise_funcs, labels, intensity_image
            )
            for variant, labels in label_variants.items()
        }
        trenchwise_df = workflow.map_frame(trenchwise_funcs, intensity_image)
        labelwise_df = pd.concat(labelwise, axis=1)
        return {"trenchwise": trenchwise_df, "labelwise": labelwise_df}

    # Mask-only measurement: count the pixels carrying each label id.
    minlength = label_image.max() + 1
    noerode_sizes = np.bincount(label_image.flat, minlength=minlength)
    eroded_sizes = np.bincount(eroded_label_image.flat, minlength=minlength)
    mask_labelwise_df = pd.DataFrame(
        {("noerode", "size"): noerode_sizes, ("erode2", "size"): eroded_sizes}
    )
    mask_labelwise_df.index.name = "label"
    return {"mask_labelwise": mask_labelwise_df}

In [None]:
def _measure(
    trenches,
    frames,
    measurement_func,
    segmentation_channel="MCHERRY",
    measure_channels=None,
    segmentation_func=trench_segmentation.watershed.segment_trench,
    include_frame=True,
    frame_bits=8,
    frame_downsample=4,
    filename=None,
    position=None,
):
    """Crop, segment and measure the trenches of one group of frames.

    Crops come from ``processing._get_trench_crops``. Each trench/timepoint is
    segmented once, from its ``segmentation_channel`` crop, and the resulting
    mask is reused for every measured channel.

    Parameters
    ----------
    trenches, frames
        Passed straight through to ``processing._get_trench_crops``.
    measurement_func
        Called as ``measurement_func(mask, None)`` once per new mask and as
        ``measurement_func(mask, crop)`` for each measured crop.
    segmentation_channel
        Channel whose crop is handed to ``segmentation_func``.
    measure_channels
        If not None, crops from channels outside this collection are skipped.
    segmentation_func
        Maps a crop to a segmentation mask.
    include_frame, filename, position
        Forwarded to ``processing._get_trench_crops``.
    frame_bits, frame_downsample
        Parameters of the frame transformation (quantization bits and
        downsampling factor).

    Returns
    -------
    dict
        ``{"measurements": {name: DataFrame}, "images": {"raw": ...,
        "segmentation": ...}}``.
    """
    # toolz-style compose applies right-to-left: downsample, quantize, zarrify.
    frame_transformation = compose(
        processing.zarrify,
        partial(image.quantize, bits=frame_bits),
        partial(image.downsample, factor=frame_downsample),
    )
    trench_crops = processing._get_trench_crops(
        trenches,
        frames,
        include_frame=include_frame,
        frame_transformation=frame_transformation,
        filename=filename,
        position=position,
    )
    # Flatten the nested crops, drop the "_frame" entries and splice the first
    # key component (itself a tuple) into the key; the loop below unpacks the
    # result as (trench_set, trench_idx, channel, t).
    flattened_crops = {
        (*k[0], *k[1:]): v
        for k, v in util.flatten_dict(trench_crops).items()
        if k[0] != "_frame"
    }
    segmentation_masks = {}
    measurements = {}
    # segment
    for (trench_set, trench_idx, channel, t), crop in flattened_crops.items():
        if measure_channels is not None and channel not in measure_channels:
            continue
        segmentation_key = (trench_set, trench_idx, segmentation_channel, t)
        segmentation_mask = segmentation_masks.get(segmentation_key, None)
        # Segment lazily: only the first crop seen for a trench/timepoint
        # triggers segmentation and the mask-only measurement.
        if segmentation_mask is None:
            segmentation_masks[segmentation_key] = segmentation_func(
                flattened_crops[segmentation_key]
            )
            # measure mask
            measurements[("mask", (trench_set, trench_idx, t))] = measurement_func(
                segmentation_masks[segmentation_key], None
            )
        # measure
        measurements[(channel, (trench_set, trench_idx, t))] = measurement_func(
            segmentation_masks[segmentation_key], crop
        )
    # Reorder the key levels by swapping the first two.
    # NOTE(review): presumably this puts the measurement name outermost --
    # confirm against util.map_dict_levels.
    measurement_dfs = util.map_dict_levels(lambda k: (k[1], k[0], *k[2:]), measurements)
    # Only existing keys are reassigned below, so the dict does not change size
    # while it is being iterated.
    for name, dfs in measurement_dfs.items():
        dfs = util.unflatten_dict(dfs)
        if isinstance(util.get_one(dfs, level=2), pd.Series):
            # Series results: stack as columns, then transpose into rows.
            df = pd.concat(
                {
                    channel: pd.concat(channel_dfs, axis=1).T
                    for channel, channel_dfs in dfs.items()
                },
                axis=1,
            )
        else:
            # DataFrame results: stack rows, channels side by side as columns.
            df = pd.concat(
                {
                    channel: pd.concat(channel_dfs, axis=0)
                    for channel, channel_dfs in dfs.items()
                },
                axis=1,
            )
        # Name the first three index levels; keep any further level names.
        df.index.names = ["trench_set", "trench", "t", *df.index.names[3:]]
        measurement_dfs[name] = df
    images = dict(
        raw=trench_crops, segmentation=util.unflatten_dict(segmentation_masks)
    )
    return dict(measurements=measurement_dfs, images=images)


measure = processing.iterate_over_groupby(["filename", "position"])(_measure)

In [None]:
def filename_func(extension=None, kind=None, name=None, filename=None, position=None):
    """Build the output path ``<filename>.<kind>/pos<position>[.<name>][.<extension>]``.

    ``name`` and ``extension`` are each appended with a leading dot only when
    they are not None. ``position`` must be an integer.
    """
    directory = f"{filename}.{kind}"
    present = [part for part in (name, extension) if part is not None]
    # The leading empty string makes the joined suffix start with a dot
    # whenever at least one part is present.
    suffix = ".".join(["", *present])
    return os.path.join(directory, f"pos{position:d}" + suffix)

In [None]:
client = Client(n_workers=1)

In [None]:
frame = workflow.get_nd2_frame(*all_frames.index[0])

In [None]:
trench_info = find_trenches_diag(frame)

In [None]:
# trench_info_to_dataframe requires the image limits as well; take one
# (x_lim, y_lim) entry from image_limits, as the bounding-box cell below does.
df = trench_info_to_dataframe(trench_info, *util.get_one(image_limits))

In [None]:
df2 = workflow._get_trench_bboxes_dataframe(df, *util.get_one(image_limits))

In [None]:
def trench_diag_to_dataframe(trench_diag, sep="."):
    """Expand one frame's trench diagnostics Series into a per-label DataFrame.

    The Series is turned into a one-row frame and passed through
    ``diagnostics.expand_diagnostics_by_label``. The outermost index level of
    the result is dropped and the innermost one is renamed to
    ``"trench_set"``. ``sep`` is accepted but currently unused.
    """
    expanded = diagnostics.expand_diagnostics_by_label(trench_diag.to_frame().T)
    expanded.index = expanded.index.droplevel(0)
    level_names = list(expanded.index.names)
    level_names[-1] = "trench_set"
    expanded.index.names = level_names
    return expanded


def trench_info_to_dataframe(trench_info, x_lim, y_lim):
    """Combine trench points, diagnostics and bounding boxes into one frame.

    ``trench_info`` is the ``(trench_points, trench_diag, trench_err)`` triple
    produced by the diagnostics-wrapped trench finder; the error element is
    not used here. The diagnostics go under a top-level ``"diag"`` column
    group and are joined onto the trench index. Bounding boxes computed with
    ``x_lim`` / ``y_lim`` are appended as further columns.
    """
    trench_points, raw_diag, _trench_err = trench_info
    # Prepend a "diag" level to the diagnostic columns.
    # FROM: https://stackoverflow.com/questions/14744068/prepend-a-level-to-a-pandas-multiindex
    diag = pd.concat([trench_diag_to_dataframe(raw_diag)], axis=1, keys=["diag"])
    diag_per_trench = util.multi_join(trench_points.index, diag)
    trenches = pd.concat([trench_points, diag_per_trench], axis=1)
    bboxes = workflow._get_trench_bboxes_dataframe(trenches, x_lim, y_lim)
    return pd.concat([trenches, bboxes], axis=1)


# Diagnostics-wrapped trench finder. Its result is unpacked elsewhere in this
# notebook as (trench_points, trench_diag, trench_err).
# NOTE(review): ignore_exceptions=True presumably reports failures through the
# error element instead of raising -- confirm in diagnostics.wrap_diagnostics.
find_trenches_diag = diagnostics.wrap_diagnostics(
    trench_detection.find_trenches, ignore_exceptions=True, pandas=True
)

# Pipeline applied right-to-left: run `measure`, then hand its result to
# processing.write_images_and_measurements, which gets `filename_func` for
# output naming and "parquet" as the dataframe format.
_measure_and_write = compose(
    partial(
        processing.write_images_and_measurements,
        filename_func=filename_func,
        dataframe_format="parquet",
    ),
    measure,
)

# def filter_positions(trench_diag):
# return (trench_diag['find_trench_lines.hough_2.peak_func.pitch'] - 24).abs() > 1
# bad_pitch = ((trench_diag['find_trench_lines.hough_2.peak_func.pitch'] - 24).abs() > 1)
# trench_points_good = trench_points[~util.multi_join(trench_points.index, bad_pitch)]
# return trench_points # TODO: can just filter based on trench_diag series, no need for df


def filter_trenches(trenches, min_hough_value=90):
    """Keep only the trenches whose Hough peak value exceeds a threshold.

    Parameters
    ----------
    trenches : pandas.DataFrame
        Trench table with an ``("info", "hough_value")`` column.
    min_hough_value : number, optional
        Exclusive lower bound on ``("info", "hough_value")``; defaults to 90.

    Returns
    -------
    pandas.DataFrame
        The rows of ``trenches`` whose Hough value is strictly greater than
        ``min_hough_value``.
    """
    # A leftover first `return` used to compute a pitch mask from the global
    # `trench_diag`, ignoring the argument and leaving this filter unreachable.
    return trenches[trenches[("info", "hough_value")] > min_hough_value]

In [None]:
# Futures are keyed by the (filename, position) group they belong to.
trenches_futures = {}
trench_errs = {}  # TODO: collect trench-detection errors
analysis_futures = {}

for key, frames in selected_frames.groupby(["filename", "position"]):
    filename, position = key
    x_lim, y_lim = image_limits[filename]
    # Trenches are detected once per position, on the MCHERRY frame at t=0.
    frame_to_segment = frames.loc[IDX[:, :, ["MCHERRY"], 0], :]
    frame_future = client.submit(workflow.get_nd2_frame, *frame_to_segment.index[0])
    trenches_future = client.submit(
        compose(
            partial(trench_info_to_dataframe, x_lim=x_lim, y_lim=y_lim),
            find_trenches_diag,
        ),
        frame_future,
    )
    trenches_futures[key] = trenches_future
    # TODO: filter trenches, send only good trenches to _measure_and_write
    # TODO: write trenches to disk in per-filename parquet
    # Fire-and-forget: util.return_none keeps the (large) result off the wire.
    # `analysis_futures` is a dict, so store by key (the previous `.append`
    # raised AttributeError).
    # NOTE(review): `frames` and `measurement_func` are passed to match how
    # the measure pipeline is invoked elsewhere in this notebook
    # (trenches, frames, measurement_func) -- confirm.
    analysis_futures[key] = client.submit(
        compose(util.return_none, _measure_and_write),
        trenches_future,
        frames,
        measurement_func,
    )

# save trench_diags
# save trenches

## Streaming gather

In [None]:
%%time
# Streaming gather pipeline: futures are fed into an as_completed iterator;
# finished ones are batched, gathered, cancelled on the cluster and written to
# Arrow sinks, and new tasks are submitted to keep a bounded number in flight.
# NOTE(review): `analysis_futures_iter` is not defined in the visible cells;
# it must exist before this cell runs.

# Starts empty; futures are added through new_futures_stream below.
ac = distributed.as_completed([], with_results=False, loop=client.loop)

new_futures_stream = streamz.Stream()
finished_futures_stream = streamz.Stream(asynchronous=True, loop=client.loop)

# Shared mutable state handed to workflow.sink_to_arrow.
stream_sinks = {}
stream_writers = {}
output_filename = "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_{}.arrow"

# Every newly emitted future is tracked by the as_completed iterator.
new_futures_stream.sink(lambda x: ac.add(x))

# Collect failed futures for later inspection.
errored_futures = set()
finished_futures_stream.filter(lambda x: x.status == "error").sink(
    lambda x: errored_futures.add(x)
)


def timeout_func(futures):
    """Put ``futures`` back into the as_completed iterator ``ac``.

    Passed as ``timeout_func`` to ``gather_and_cancel`` below.
    """
    ac.update(futures)


successful_futures_stream = finished_futures_stream.filter(
    lambda x: x.status == "finished"
)
# batched_futures_stream = successful_futures_stream.rate_limit(0.0004).timed_window(1)
# gathered_futures_stream = streamz.buffer(batched_futures_stream, 10).gather_and_cancel(client=client, cancel=True)
# batched_futures_stream = successful_futures_stream.timed_window(1)
# Throttle the successful futures, then group them into 5-second batches.
batched_futures_stream = successful_futures_stream.rate_limit(0.01).timed_window(5)
buffered_futures_stream = (
    batched_futures_stream  # streamz.buffer(batched_futures_stream, 10)
)

# NOTE(review): gather_and_cancel is not a stock streamz node as far as I
# know; presumably it is registered by a project module and gathers each batch
# and then cancels its futures -- confirm. Futures reported through
# success_func are recorded in cancelled_futures.
cancelled_futures = set()
gathered_futures_stream = buffered_futures_stream.gather_and_cancel(
    client=client,
    gather=True,
    cancel=True,
    timeout=4,
    timeout_func=timeout_func,
    success_func=cancelled_futures.update,
)
# gathered_futures_stream.flatten().sink(partial(workflow.sink_to_arrow, sinks=stream_sinks, writers=stream_writers, output_func=lambda i: pa.OSFile(output_filename.format(i), 'w')))
write_failures = []
# Split each gathered batch back into individual results.
flattened_futures_stream = gathered_futures_stream.flatten()
writer_stream = (
    flattened_futures_stream  # .timed_window(10).map(lambda x: list(zip(*x)))
)
# Writer: output_func opens one file per sink index from the template above.
sink_func = partial(
    workflow.sink_to_arrow,
    sinks=stream_sinks,
    writers=stream_writers,
    output_func=lambda i: pa.OSFile(output_filename.format(i), "w"),
)
# sink_func = partial(client.loop.run_in_executor, None, sink_func)
# writer_stream.with_timeout(timeout=3, retries=2, failure_func=write_failures.append).sink(sink_func)
writer_stream.sink(sink_func)
# stored_data = writer_stream.sink_to_list()

# finished_futures_stream.sink(excepts(StopIteration, lambda x: new_futures_stream.emit(next(analysis_futures_iter)) if should_add_task()))
# new_futures_stream.sink_to_list()
# Bookkeeping sets used to decide how many tasks are still in flight.
all_futures = set()
finished_futures = set()
new_futures_stream.sink(lambda x: all_futures.add(x))
successful_futures_stream.sink(lambda x: finished_futures.add(x))

# Target number of futures that are neither cancelled nor errored.
TASK_BUFFER_SIZE = 10000

# def should_add_task():
#     #return len([f for f in all_futures if f.status == 'pending'])
#     #return TASK_BUFFER_SIZE > len(all_futures - finished_futures)
#     return TASK_BUFFER_SIZE > len(all_futures - cancelled_futures)
#     #print('>',len(all_futures - finished_futures))
#     #return True

# def readd_task(x):
#     if should_add_task():
#         return new_futures_stream.emit(next(analysis_futures_iter))


def readd_task(x):
    """Top the in-flight set back up to ``TASK_BUFFER_SIZE``.

    Called for every finished future; ``x`` (that future) is not used. Emits
    as many new futures from ``analysis_futures_iter`` as are needed, or
    fewer if the iterator runs out.
    """
    num_tasks_needed = TASK_BUFFER_SIZE - len(
        all_futures - cancelled_futures - errored_futures
    )
    if num_tasks_needed > 0:
        for future in take(num_tasks_needed, analysis_futures_iter):
            new_futures_stream.emit(future)


# `excepts` swallows a StopIteration raised by readd_task.
finished_futures_stream.sink(excepts(StopIteration, readd_task))

# ac.update(take(3000, analysis_futures_iter))
# Prime the pipeline with the first batch of tasks.
for future in take(TASK_BUFFER_SIZE, analysis_futures_iter):
    new_futures_stream.emit(future)

# Schedule the gather on the client's event loop.
# NOTE(review): presumably gather_stream pumps futures completing in `ac` into
# finished_futures_stream -- confirm in workflow.gather_stream.
gather_func = workflow.gather_stream(finished_futures_stream, ac)
gather_task = client.loop.asyncio_loop.create_task(gather_func)

# Old trench detection

In [None]:
%%time
# Diagnostics-wrapped trench finder (same configuration as in the newer
# pipeline cells).
find_trenches_diag = diagnostics.wrap_diagnostics(
    trench_detection.find_trenches, ignore_exceptions=True, pandas=True
)
# One detection future per frame index entry: the inner submit loads the
# frame, the outer one runs trench detection on it. `row` is not used.
# NOTE(review): `frames_to_process` is not defined in the visible cells.
trench_info_futures = {
    idx: client.submit(
        find_trenches_diag, client.submit(workflow.get_nd2_frame, **idx._asdict())
    )
    for idx, row in util.iter_index(frames_to_process)
}

In [None]:
progress(trench_info_futures)

In [None]:
client.cancel(trench_info_futures)

In [None]:
%%time
trench_info = util.apply_map_futures(
    client.gather, trench_info_futures, predicate=lambda x: x.status == "finished"
)

In [None]:
%%time
trench_points, trench_diag, trench_err = workflow.unzip_trench_info(trench_info)

In [None]:
len(trench_points)

# Segmentation

In [None]:
bad_angle = trench_diag["find_trench_lines.hough_2.angle"].abs() > 2
bad_angle.sum()

In [None]:
bad_pitch = (trench_diag["find_trench_lines.hough_2.peak_func.pitch"] - 24).abs() > 1
bad_pitch.sum()

In [None]:
selected = trench_diag[bad_pitch]  # trench_diag[bad_angle | bad_period]

In [None]:
frame_stream.event(_df=selected.index.to_frame(index=False))

In [None]:
%%time
trench_points_good = trench_points[~util.multi_join(trench_points.index, bad_pitch)]

In [None]:
(len(trench_points_good), len(trench_points_good) / len(trench_points))

In [None]:
%%time
trench_bbox_futures = []
for _, trenches in trench_points_good.groupby(["filename", "position", "t"]):
    trench_bbox_futures.append(
        client.submit(workflow.get_trench_bboxes, trenches, image_limits)
    )

In [None]:
%%time
trench_bbox_results = util.apply_map_futures(
    client.gather, trench_bbox_futures, predicate=lambda x: x.status == "finished"
)
trench_bboxes = pd.concat(
    [trench_points_good, pd.concat(trench_bbox_results, axis=0)], axis=1
)

In [None]:
%%time
%store trench_bboxes

In [None]:
%store -r trench_bboxes

In [None]:
trench_bboxes_t0 = util.get_one(trench_bboxes.groupby("t"))[1]
# trench_bboxes_t0.index = trench_points_good_t0.index.droplevel('t')

In [None]:
selected_trenches_segmentation = trench_bboxes_t0[
    trench_bboxes_t0[("info", "hough_value")] > 90
].loc[IDX[:, :, ["MCHERRY"], 0, :, :], :]

In [None]:
selected_trenches_segmentation.index = selected_trenches_segmentation.index.droplevel(
    "channel"
)

In [None]:
(len(trench_bboxes_t0), len(selected_trenches_segmentation) / len(trench_bboxes_t0))

In [None]:
# frames_to_analyze = all_frames.loc[IDX[:,:1,['MCHERRY','YFP'],1:5],:]
frames_to_analyze = all_frames.loc[IDX[:, :, ["MCHERRY", "YFP"], :], :]

In [None]:
(
    len(frames_to_analyze),
    len(all_frames.loc[IDX[:, :, ["MCHERRY", "YFP"], :], :]) / len(frames_to_analyze),
)

## New analysis

In [None]:
pixelwise_funcs = {
    "mean": np.mean,
    "min": np.min,
    "max": np.max,
    ("p0.3", "p0.5", "p0.7", "p0.9", "p0.95"): partial(
        np.percentile, q=(30, 50, 70, 90, 95)
    ),
}
trenchwise_funcs = {"sharpness": image.sharpness, **pixelwise_funcs}


def measurement_func(label_image, intensity_image, n_erosions=2):
    """Measure one trench crop against its segmentation labels.

    The label image is used both as-is ("noerode") and with its foreground
    eroded ``n_erosions`` times ("erode<n_erosions>"). The eroded key is
    derived from the erosion count so the two cannot drift apart; this block
    previously labelled a two-step erosion "erode3".

    Parameters
    ----------
    label_image
        Integer label image; 0 is treated as background.
    intensity_image
        Image to measure, or None to measure only the mask.
    n_erosions : int, optional
        Number of binary erosions applied to the foreground; defaults to 2.

    Returns
    -------
    dict
        ``{"mask_labelwise": df}`` with per-label pixel counts when
        ``intensity_image`` is None, otherwise
        ``{"trenchwise": df, "labelwise": df}``.
    """
    erode_key = f"erode{n_erosions}"
    erode = util.repeat_apply(skimage.morphology.binary_erosion, n_erosions)
    # Erode the binary foreground, then restore the label ids on what is left.
    eroded_label_image = erode(label_image != 0) * label_image
    if intensity_image is None:
        # Mask-only measurement: count the pixels carrying each label id.
        minlength = label_image.max() + 1
        mask_labelwise_df = pd.DataFrame(
            {
                ("noerode", "size"): np.bincount(label_image.flat, minlength=minlength),
                (erode_key, "size"): np.bincount(
                    eroded_label_image.flat, minlength=minlength
                ),
            }
        )
        mask_labelwise_df.index.name = "label"
        return dict(mask_labelwise=mask_labelwise_df)
    labelwise = {
        "noerode": workflow.map_frame_over_labels(
            pixelwise_funcs, label_image, intensity_image
        ),
        erode_key: workflow.map_frame_over_labels(
            pixelwise_funcs, eroded_label_image, intensity_image
        ),
    }
    trenchwise_df = workflow.map_frame(trenchwise_funcs, intensity_image)
    labelwise_df = pd.concat(labelwise, axis=1)
    return dict(trenchwise=trenchwise_df, labelwise=labelwise_df)


def _measure(
    trenches,
    frames,
    measurement_func,
    segmentation_channel="MCHERRY",
    measure_channels=None,
    segmentation_func=trench_segmentation.watershed.segment_trench,
    include_frame=True,
    frame_bits=8,
    frame_downsample=4,
    filename=None,
    position=None,
):
    """Crop, segment and measure the trenches of one group of frames.

    Each trench/timepoint is segmented once from its ``segmentation_channel``
    crop. ``measurement_func(mask, None)`` is recorded under the pseudo-channel
    ``"mask"`` and ``measurement_func(mask, crop)`` under each measured
    channel. If ``measure_channels`` is given, other channels are skipped.

    Returns ``{"measurements": {name: DataFrame}, "images": {"raw": crops,
    "segmentation": masks}}``.
    """
    downsample_frame = partial(image.downsample, factor=frame_downsample)
    quantize_frame = partial(image.quantize, bits=frame_bits)
    # compose runs right-to-left: downsample, then quantize, then zarrify.
    frame_transformation = compose(processing.zarrify, quantize_frame, downsample_frame)
    trench_crops = processing._get_trench_crops(
        trenches,
        frames,
        include_frame=include_frame,
        frame_transformation=frame_transformation,
        filename=filename,
        position=position,
    )

    # Flat view of the crops without the "_frame" entries; the leading key
    # component is a tuple and gets spliced into the flat key.
    flattened_crops = {}
    for crop_key, crop_value in util.flatten_dict(trench_crops).items():
        if crop_key[0] != "_frame":
            flattened_crops[(*crop_key[0], *crop_key[1:])] = crop_value

    segmentation_masks = {}
    measurements = {}
    for (trench_set, trench_idx, channel, t), crop in flattened_crops.items():
        if not (measure_channels is None or channel in measure_channels):
            continue
        trench_key = (trench_set, trench_idx, t)
        segmentation_key = (trench_set, trench_idx, segmentation_channel, t)
        if segmentation_masks.get(segmentation_key) is None:
            # First crop seen for this trench/timepoint: segment it and record
            # the mask-only measurement.
            new_mask = segmentation_func(flattened_crops[segmentation_key])
            segmentation_masks[segmentation_key] = new_mask
            measurements[("mask", trench_key)] = measurement_func(new_mask, None)
        measurements[(channel, trench_key)] = measurement_func(
            segmentation_masks[segmentation_key], crop
        )

    # Swap the first two key levels before regrouping the results.
    measurement_dfs = util.map_dict_levels(lambda k: (k[1], k[0], *k[2:]), measurements)
    for name, dfs in measurement_dfs.items():
        dfs = util.unflatten_dict(dfs)
        if isinstance(util.get_one(dfs, level=2), pd.Series):
            # Series results become one row each.
            per_channel = {
                channel: pd.concat(channel_dfs, axis=1).T
                for channel, channel_dfs in dfs.items()
            }
        else:
            # DataFrame results are stacked row-wise.
            per_channel = {
                channel: pd.concat(channel_dfs, axis=0)
                for channel, channel_dfs in dfs.items()
            }
        df = pd.concat(per_channel, axis=1)
        trailing_names = df.index.names[3:]
        df.index.names = ["trench_set", "trench", "t", *trailing_names]
        measurement_dfs[name] = df

    segmentation_images = util.unflatten_dict(segmentation_masks)
    return {
        "measurements": measurement_dfs,
        "images": {"raw": trench_crops, "segmentation": segmentation_images},
    }


measure = processing.iterate_over_groupby(["filename", "position"])(_measure)

In [None]:
def filename_func(extension=None, kind=None, name=None, filename=None, position=None):
    """Return ``<filename>.<kind>/pos<position>`` plus optional dotted parts.

    ``.<name>`` and ``.<extension>`` are appended, in that order, for each one
    that is not None. ``position`` is formatted as an integer.
    """
    optional_parts = [part for part in (name, extension) if part is not None]
    # Joining with a leading empty string puts a dot in front of every part.
    dotted_suffix = ".".join(["", *optional_parts])
    output_dir = f"{filename}.{kind}"
    return os.path.join(output_dir, f"pos{position:d}" + dotted_suffix)

In [None]:
# Restrict trenches and frames to a single ND2 file; `pos` selects the
# positions (all of them by default).
pos = slice(None)  # slice(0, 2)
selected_trenches_segmentation2 = selected_trenches_segmentation.loc[
    IDX[["/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2"], pos, :, :, :], :
]
frames_to_analyze2 = frames_to_analyze.loc[
    IDX[["/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2"], pos], :
]
# Measure, then write images and measurements (dataframes as parquet);
# compose applies right-to-left.
_measure_and_write = compose(
    partial(
        processing.write_images_and_measurements,
        filename_func=filename_func,
        dataframe_format="parquet",
    ),
    measure,
)
# measure_and_write = processing.iterate_over_groupby(['filename', 'position'])(_measure_and_write)
# Submit to the cluster via client.submit instead of running locally;
# util.return_none wraps the task so its result is not shipped back.
# NOTE(review): presumably one task is submitted per (filename, position)
# group -- confirm in processing.iterate_over_groupby.
measure_and_write = processing.iterate_over_groupby(["filename", "position"])(
    partial(client.submit, compose(util.return_none, _measure_and_write))
)
futures = measure_and_write(
    selected_trenches_segmentation2, frames_to_analyze2, measurement_func
)

In [None]:
errors = list(
    util.apply_map_futures(
        list, futures, predicate=lambda x: x.status == "error"
    ).values()
)

In [None]:
errors[0].result()

In [None]:
%%time
d = pq.read_pandas(
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2.measurements/pos0.labelwise.parquet"
).to_pandas()

In [None]:
d

In [None]:
d.index.get_level_values("trench")

In [None]:
%%time
d = pa.open_file(
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2.measurements/pos1.labelwise.arrow"
).read_pandas()

In [None]:
d

## Streaming gather

In [None]:
%%time
ac = distributed.as_completed([], with_results=False, loop=client.loop)

new_futures_stream = streamz.Stream()
finished_futures_stream = streamz.Stream(asynchronous=True, loop=client.loop)

stream_sinks = {}
stream_writers = {}
output_filename = "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_{}.arrow"

new_futures_stream.sink(lambda x: ac.add(x))

errored_futures = set()
finished_futures_stream.filter(lambda x: x.status == "error").sink(
    lambda x: errored_futures.add(x)
)


def timeout_func(futures):
    """Put ``futures`` back into the as_completed iterator ``ac``.

    Passed as ``timeout_func`` to ``gather_and_cancel`` in this cell.
    """
    ac.update(futures)


successful_futures_stream = finished_futures_stream.filter(
    lambda x: x.status == "finished"
)
# batched_futures_stream = successful_futures_stream.rate_limit(0.0004).timed_window(1)
# gathered_futures_stream = streamz.buffer(batched_futures_stream, 10).gather_and_cancel(client=client, cancel=True)
# batched_futures_stream = successful_futures_stream.timed_window(1)
batched_futures_stream = successful_futures_stream.rate_limit(0.01).timed_window(5)
buffered_futures_stream = (
    batched_futures_stream  # streamz.buffer(batched_futures_stream, 10)
)

cancelled_futures = set()
gathered_futures_stream = buffered_futures_stream.gather_and_cancel(
    client=client,
    gather=True,
    cancel=True,
    timeout=4,
    timeout_func=timeout_func,
    success_func=cancelled_futures.update,
)
# gathered_futures_stream.flatten().sink(partial(workflow.sink_to_arrow, sinks=stream_sinks, writers=stream_writers, output_func=lambda i: pa.OSFile(output_filename.format(i), 'w')))
write_failures = []
flattened_futures_stream = gathered_futures_stream.flatten()
writer_stream = (
    flattened_futures_stream  # .timed_window(10).map(lambda x: list(zip(*x)))
)
sink_func = partial(
    workflow.sink_to_arrow,
    sinks=stream_sinks,
    writers=stream_writers,
    output_func=lambda i: pa.OSFile(output_filename.format(i), "w"),
)
# sink_func = partial(client.loop.run_in_executor, None, sink_func)
# writer_stream.with_timeout(timeout=3, retries=2, failure_func=write_failures.append).sink(sink_func)
writer_stream.sink(sink_func)
# stored_data = writer_stream.sink_to_list()

# finished_futures_stream.sink(excepts(StopIteration, lambda x: new_futures_stream.emit(next(analysis_futures_iter)) if should_add_task()))
# new_futures_stream.sink_to_list()
all_futures = set()
finished_futures = set()
new_futures_stream.sink(lambda x: all_futures.add(x))
successful_futures_stream.sink(lambda x: finished_futures.add(x))

TASK_BUFFER_SIZE = 10000

# def should_add_task():
#     #return len([f for f in all_futures if f.status == 'pending'])
#     #return TASK_BUFFER_SIZE > len(all_futures - finished_futures)
#     return TASK_BUFFER_SIZE > len(all_futures - cancelled_futures)
#     #print('>',len(all_futures - finished_futures))
#     #return True

# def readd_task(x):
#     if should_add_task():
#         return new_futures_stream.emit(next(analysis_futures_iter))


def readd_task(x):
    """Top the pool of in-flight futures back up to ``TASK_BUFFER_SIZE``.

    Meant to be used as a stream sink callback: the emitted element ``x`` is
    ignored and only triggers the refill. Futures counted as in flight are
    those in ``all_futures`` that appear in neither ``cancelled_futures`` nor
    ``errored_futures``. If fewer than ``TASK_BUFFER_SIZE`` remain, up to the
    missing number of futures is drawn from ``analysis_futures_iter`` and
    emitted on ``new_futures_stream``.
    """
    # set.difference accepts arbitrary iterables, whereas the ``-`` operator
    # raises TypeError unless both operands are sets. errored_futures is
    # indexed with [-1] elsewhere in this notebook, i.e. it is a sequence.
    in_flight = all_futures.difference(cancelled_futures, errored_futures)
    num_tasks_needed = TASK_BUFFER_SIZE - len(in_flight)
    if num_tasks_needed > 0:
        for future in take(num_tasks_needed, analysis_futures_iter):
            new_futures_stream.emit(future)


# Refill the task pool every time a future finishes. excepts() swallows a
# StopIteration escaping from readd_task instead of letting it propagate into
# the stream machinery.
finished_futures_stream.sink(excepts(StopIteration, readd_task))

# ac.update(take(3000, analysis_futures_iter))
# Prime the pipeline with the first TASK_BUFFER_SIZE futures; this must happen
# after all sinks above are registered so nothing emitted here is missed.
for future in take(TASK_BUFFER_SIZE, analysis_futures_iter):
    new_futures_stream.emit(future)

# Schedule the gathering coroutine on the client's asyncio loop; gather_task
# is kept so the run can be inspected or cancelled from later cells.
gather_func = workflow.gather_stream(finished_futures_stream, ac)
gather_task = client.loop.asyncio_loop.create_task(gather_func)

In [None]:
# Progress check on `ac`, the object handed to workflow.gather_stream.
ac.count()

In [None]:
# Display the asyncio task that drives workflow.gather_stream (shows whether
# it is pending, finished or failed).
gather_task

In [None]:
# Manually push another 5000 futures into the pipeline and start a fresh
# gather task.
# NOTE(review): gather_task is rebound here without cancelling the previous
# task -- confirm the earlier one has already finished or been cancelled.
for future in take(5000, analysis_futures_iter):
    new_futures_stream.emit(future)
gather_func = workflow.gather_stream(finished_futures_stream, ac)
gather_task = client.loop.asyncio_loop.create_task(gather_func)

In [None]:
# Abort the run: stop the gather task first, then cancel every future that was
# emitted into the pipeline.
gather_task.cancel()
client.cancel(all_futures)

In [None]:
# Resize the cluster to a target of 200 (the unit -- workers or jobs -- depends
# on the cluster class's scale() implementation; confirm).
cluster.scale(200)

In [None]:
# Number of distinct futures emitted on new_futures_stream so far.
len(all_futures)

In [None]:
# Inspect the last entry of errored_futures; calling result() on a failed
# future re-raises its exception, which surfaces the traceback here.
errored_futures[-1].result()

In [None]:
%%time
# Read back what has been written to the third sink so far as a DataFrame.
# NOTE(review): the top-level pa.open_stream alias is deprecated in newer
# pyarrow releases in favour of pa.ipc.open_stream -- confirm against the
# installed version.
pa.open_stream(stream_sinks[2].r()).read_pandas()

In [None]:
# Apply client.gather to the futures in analysis_futures whose status is
# "error" (presumably leaving the others untouched -- see
# util.apply_map_futures); gathering an errored future raises its exception.
util.apply_map_futures(
    client.gather, analysis_futures, predicate=lambda x: x.status == "error"
)

# Analysis

## Load data

In [None]:
%%time
# Load output stream 0 of the analysis run (sorted parquet) into pandas.
# NOTE(review): absolute cluster path is hard-coded.
framewise_df = data_io.read_parquet(
    "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_0.sorted.parquet4"
).to_pandas()

In [None]:
%%time
# Load output stream 1 of the analysis run (sorted parquet) into pandas.
# NOTE(review): absolute cluster path is hard-coded.
trenchwise_df = data_io.read_parquet(
    "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_1.sorted.parquet4"
).to_pandas()

In [None]:
# Flatten each column key into a single "a/b/c" string. This assumes every
# key is a tuple of strings; a plain string key would get "/" inserted between
# its characters.
trenchwise_df.columns = [
    "/".join(key_parts).strip() for key_parts in trenchwise_df.columns.values
]

In [None]:
# Columns to load from the labelwise parquet file: first the plain key fields.
cols = ["filename", "position", "channel", "t", "trench_set", "trench", "label"]
# The measurement columns are named after the repr of a tuple key, so the
# names are derived with str() rather than being quoted by hand.
cols += [
    str(column_key)
    for column_key in (
        ("YFP", "labelwise", "p0.9"),
        ("MCHERRY", "labelwise", "p0.9"),
        ("YFP", "regionprops", "area"),
    )
]

In [None]:
%%time
# Load output stream 2 of the analysis run, restricted to the columns listed
# in `cols`, into pandas.
# NOTE(review): absolute cluster path is hard-coded.
labelwise_df = data_io.read_parquet(
    "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_2.sorted3.parquet4",
    columns=cols,
).to_pandas()

In [None]:
# TODO: otherwise computing is_unique is costly when we want to get_loc with full key
# Pre-seed pandas' private per-index cache so that is_unique and lexsort_depth
# are never computed on this large index. This replaces the whole cache dict
# and asserts, without checking, that the index is unique and lexsorted to
# depth 6 -- if that is not true, lookups may misbehave.
# NOTE(review): relies on pandas internals (Index._cache); verify after any
# pandas upgrade.
labelwise_df.index.__dict__["_cache"] = {"lexsort_depth": 6, "is_unique": True}
# Gentler alternative that keeps existing cache entries:
# labelwise_df.index.lexsort_depth # prime the cache
# if '_cache' not in labelwise_df.index.__dict__:
#     labelwise_df.index.__dict__['_cache'] = {}
# labelwise_df.index._cache['is_unique'] = True

In [None]:
# Flatten each column key into a single "a/b/c" string. This assumes every
# key is a tuple of strings; a plain string key would get "/" inserted between
# its characters.
labelwise_df.columns = [
    "/".join(key_parts).strip() for key_parts in labelwise_df.columns.values
]

## Burst detection

In [None]:
# Flattened names of the measurement columns used throughout the analysis.
yfp, mcherry, area = (
    "YFP/labelwise/p0.9",
    "MCHERRY/labelwise/p0.9",
    "YFP/regionprops/area",
)
# Index levels identifying one trench, and one trench at one timepoint.
trench_key = ["filename", "position", "trench_set", "trench"]
trench_t_key = trench_key + ["t"]

In [None]:
%%time
# Restrict the analysis to the rows of a single input file (first index
# level); the commented lines select the other available files instead.
# labelwise_selected = labelwise_df.loc[IDX['/n/scratch2/jqs1/fidelity/all/TrErr002_noBF.nd2',:],:]
labelwise_selected = labelwise_df.loc[
    IDX["/n/scratch2/jqs1/fidelity/all/180405_txnerr_loweronly_fast.nd2", :], :
]
# labelwise_selected = labelwise_df.loc[IDX['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2.nd2',:],:]

In [None]:
# How many times larger the full table is than the selected file's rows.
len(labelwise_df) / len(labelwise_selected)

In [None]:
# col = 'YFP/p0.3'
# trenchwise_yfp_bg = trenchwise_df[col].rename(col+'_trenchwise')

In [None]:
%%time
# Rows whose sixth index level ("label") equals 0, restricted to the two
# intensity columns (presumably label 0 is the background region -- confirm).
background = labelwise_selected.loc[IDX[:, :, :, :, :, 0], [yfp, mcherry]]
# The label level is constant now, so drop it to allow joining on the
# remaining levels.
background.index = background.index.droplevel("label")
# Mark the columns as background values.
background.columns = [name + "_bg" for name in background.columns]

In [None]:
%%time
# Keep labels whose area lies in [100, 200] (between() is inclusive on both
# ends by default) ...
cell_sized = labelwise_selected[labelwise_selected[area].between(100, 200)]
# ... and, of those, only labels >= 1, i.e. everything except label 0.
cell_sized = cell_sized.loc[IDX[:, :, :, :, :, 1:], :]

In [None]:
%%time
# Median of every column over the cell-sized labels of one trench at one
# timepoint.
trench_t_median = cell_sized.groupby(trench_t_key).median()
# Suffix the columns so they stay distinguishable once joined back.
trench_t_median.columns = [
    name + "_trench_t_median" for name in trench_t_median.columns
]

In [None]:
%%time
# Median of every column over all cell-sized labels of one trench, across all
# timepoints.
trench_median = cell_sized.groupby(trench_key).median()
# Suffix the columns so they stay distinguishable once joined back.
trench_median.columns = [
    name + "_trench_median" for name in trench_median.columns
]

In [None]:
%%time
# Attach the "*_bg" columns to every cell-sized label row (util.multi_join
# presumably joins on the index levels the two frames share -- confirm).
with_bg = util.multi_join(cell_sized, background)

In [None]:
%%time
# Add the "*_trench_t_median" columns.
with_bg = util.multi_join(with_bg, trench_t_median)

In [None]:
%%time
# Add the "*_trench_median" columns.
with_bg = util.multi_join(with_bg, trench_median)

In [None]:
%%time
# Excess of each row's YFP value over the median of its trench at the same
# timepoint. Computed once up front because it does not depend on the
# threshold (previously it was recomputed for every threshold).
yfp_over_trench_t_median = with_bg[yfp] - with_bg[yfp + "_trench_t_median"]
# One column per threshold: the number of rows per trench whose excess reaches
# the threshold (summing the boolean mask counts its True entries).
bright_ts_median_t = pd.DataFrame(
    {
        "bright_ts_median_t_{}".format(thresh): (yfp_over_trench_t_median >= thresh)
        .groupby(trench_key)
        .sum()
        for thresh in (5, 8, 10, 20, 30, 50)
    }
)

In [None]:
%%time
# Excess of each row's YFP value over the all-time median of its trench.
# Computed once up front because it does not depend on the threshold
# (previously it was recomputed for every threshold).
yfp_over_trench_median = with_bg[yfp] - with_bg[yfp + "_trench_median"]
# One column per threshold: the number of rows per trench whose excess reaches
# the threshold (summing the boolean mask counts its True entries).
bright_ts_median = pd.DataFrame(
    {
        "bright_ts_median_{}".format(thresh): (yfp_over_trench_median >= thresh)
        .groupby(trench_key)
        .sum()
        for thresh in (5, 8, 10, 20, 30, 50)
    }
)

In [None]:
%%time
# Excess of each row's YFP value over the "_bg" value joined onto it.
# Computed once up front because it does not depend on the threshold
# (previously it was recomputed for every threshold).
yfp_over_bg = with_bg[yfp] - with_bg[yfp + "_bg"]
# One column per threshold: the number of rows per trench whose excess reaches
# the threshold (summing the boolean mask counts its True entries).
bright_ts_bg = pd.DataFrame(
    {
        "bright_ts_bg_{}".format(thresh): (yfp_over_bg >= thresh)
        .groupby(trench_key)
        .sum()
        for thresh in (5, 8, 10, 20, 30, 50)
    }
)

In [None]:
# Number of trenches with more than one row at least 50 above the
# trench/timepoint median.
len(bright_ts_median_t[bright_ts_median_t["bright_ts_median_t_50"] > 1])

In [None]:
# Preview the trenches with more than one row at least 5 above the
# trench/timepoint median.
bright_ts_median_t[bright_ts_median_t["bright_ts_median_t_5"] > 1].head()

In [None]:
%%time
# Per-trench median of the "*_bg" columns over all timepoints.
median_bg = background.groupby(trench_key).median()

In [None]:
# Combine all per-trench tables into one, joining them left to right:
# the three families of bright-row counts first, then the median background.
bright_ts_all = util.multi_join(bright_ts_median_t, bright_ts_median)
bright_ts_all = util.multi_join(bright_ts_all, bright_ts_bg)
bright_ts_all = util.multi_join(bright_ts_all, median_bg)

In [None]:
%%time
# Attach the per-trench summary columns to every cell-sized label row so rows
# can be filtered by properties of their trench.
cell_sized_with_ts = util.multi_join(cell_sized, bright_ts_all)

In [None]:
%%time
# Burst candidates: all rows of trenches that have at least two rows 20 or
# more above the trench/timepoint median.
detected_bursts = cell_sized_with_ts[cell_sized_with_ts["bright_ts_median_t_20"] >= 2]

In [None]:
%%time
# Number of distinct trenches among the burst candidates.
len(detected_bursts.groupby(trench_key))

## New filter

In [None]:
%%time
# Alternative filter: trenches with at least two rows 20 or more above the
# joined background value instead of the trench/timepoint median.
detected_bursts2 = cell_sized_with_ts[cell_sized_with_ts["bright_ts_bg_20"] >= 2]

In [None]:
%%time
# Number of distinct trenches selected by the background-based filter.
len(detected_bursts2.groupby(trench_key))

In [None]:
# List the columns available for filtering.
cell_sized_with_ts.columns

In [None]:
# Distribution of the per-trench median YFP background, restricted to values
# below 180 and drawn with a logarithmic count axis.
x = cell_sized_with_ts[yfp + "_bg"]
x[x < 180].hist(bins=50, log=True)

In [None]:
%%time
# Filters tried previously (background-limited variants):
# detected_bursts3 = cell_sized_with_ts[(cell_sized_with_ts[yfp+'_bg'] <= 200) & (cell_sized_with_ts['bright_ts_bg_20'] >= 2)]
# detected_bursts3 = cell_sized_with_ts[(cell_sized_with_ts[yfp+'_bg'].between(120, 130)) & (cell_sized_with_ts['bright_ts_bg_20'] >= 1)]
# Current filter: trenches with at least three rows 50 or more above the
# trench/timepoint median.
detected_bursts3 = cell_sized_with_ts[cell_sized_with_ts["bright_ts_median_t_50"] >= 3]

In [None]:
# Number of label rows that pass the current filter.
len(detected_bursts3)

In [None]:
%%time
# Number of distinct trenches that pass the current filter.
len(detected_bursts3.groupby(trench_key))

## New visualization

In [None]:
# Stream type derived from labelwise_df's index (presumably one parameter per
# index level -- see ui.MultiIndexStream.define), plus a widget box to browse
# it.
LabelStream = ui.MultiIndexStream.define("LabelStream", labelwise_df.index)
label_stream = LabelStream()
box = ui.dataframe_browser(label_stream)
# Fire an initial event so subscribers render with the starting selection.
label_stream.event()
# Last expression of the cell: display the browser widget.
box

In [None]:
%%output size=100
%%opts Layout [normalize=False]
# Hover read-out for the trench images: pixel position and pixel value.
hover = HoverTool(tooltips=[("(x,y)", "(@x{0[.]0}, @y{0[.]0})"), ("value", "@z")])


# Simpler callbacks used earlier:
# cb = compose(partial(ui.hover_image, hover), ui._trench_img, workflow.get_trench_image)
# cb = workflow.get_trench_image
def cb(v_max):
    """Return an image callback whose z (intensity) range is fixed to [0, v_max].

    The returned pipeline runs right to left: fetch the trench image, wrap it
    with ui._trench_img, fix the displayed z range, then attach the hover tool.
    """
    return compose(
        partial(ui.hover_image, hover),
        lambda img: img.redim.range(z=(0, v_max)),
        ui._trench_img,
        workflow.get_trench_image,
    )


# MCHERRY viewer above the YFP viewer, each with its own intensity ceiling,
# stacked in a single column.
(
    ui.trench_viewer(
        trench_bboxes, label_stream, channel="MCHERRY", image_callback=cb(5000)
    )
    + ui.trench_viewer(
        trench_bboxes, label_stream, channel="YFP", image_callback=cb(400)
    )
).cols(1)

In [None]:
# Burst candidates grouped by trench.
groups = detected_bursts3.groupby(trench_key)
# Trench keys bundled five at a time via util.grouper; each bundle is shown
# together in one plot.
group_set_keys = list(util.grouper(groups.groups.keys(), 5))
# Single-level MultiIndex enumerating the bundles, for the browser widget.
group_index = pd.MultiIndex.from_tuples(
    [(i,) for i in range(len(group_set_keys))]
).set_names(["group_set"])

In [None]:
# Stream type derived from group_index, plus a widget box to step through
# the trench bundles.
GroupStream = ui.MultiIndexStream.define("GroupStream", group_index)
group_stream = GroupStream()
group_box = ui.dataframe_browser(group_stream)
# Fire an initial event so subscribers render with the starting selection.
group_stream.event()
# Last expression of the cell: display the browser widget.
group_box

In [None]:
%%output size=180
# Selections made on the scatter plot arrive through this stream.
sel = Selection1D()


def callback(group_set):
    """Build the burst scatter plot for one bundle of trenches.

    ``group_set`` indexes ``group_set_keys``. The rows of every trench in that
    bundle are concatenated and plotted as YFP intensity against ``t``,
    coloured by trench. As a side effect the subscribers of ``sel`` are
    replaced, so that selections on the new plot are forwarded to
    ``label_stream``.
    """
    df = pd.concat(
        [groups.get_group(trench) for trench in group_set_keys[group_set]]
    )
    hover = HoverTool(
        tooltips=[
            ("t", "@t{0[.]0}"),
            # ('filename', '@filename'),
            ("trench", "@position.@trench_set.@trench"),
            ("label", "@label"),
            ("YFP", "@{YFP/labelwise/p0.9}{0[.]0}"),
        ]
    )
    scatter_style = dict(
        size=3,
        color_index="trench",
        nonselection_alpha=0.3,
        cmap="Category20",
        tools=[hover, TapTool()],
        show_legend=True,
    )
    # The key columns are carried along as value dimensions so the hover tool
    # and the selection callback can read them.
    value_dims = [
        "YFP/labelwise/p0.9",
        "filename",
        "position",
        "trench_set",
        "trench",
        "label",
    ]
    plot = hv.Scatter(df, kdims=["t"], vdims=value_dims).options(
        "Scatter", **scatter_style
    )
    # ui.selection_to_stream(plot, label_stream)
    # Drop the subscriber bound to the previous plot's data before registering
    # one for the new data.
    sel.clear()
    forward_selection = partial(
        ui._selection_to_stream_callback,
        data=plot.data,
        keys=df.index.names,
        stream=label_stream,
    )
    sel.add_subscriber(forward_selection)
    return plot


# Re-render whenever group_stream changes, and listen for selections on the
# resulting dynamic map.
p = hv.DynamicMap(callback, streams=[group_stream])
sel.source = p
p

## Movie output

In [None]:
%%output backend='matplotlib'
#%%opts Layout [normalize=False fig_inches=2 vspace=0 aspect_weight=1 sublabel_format='' tight=True title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(**label_stream.contents) fontsize=20]
#%%opts Scatter [aspect=6]
# Trench currently selected in the browser, as a tuple ordered like trench_key.
key = tuple(getattr(label_stream, attr) for attr in trench_key)
index = detected_bursts.groupby(trench_key).get_group(key).index
# Distinct timepoints of that trench among the burst candidates. Uses the
# public Index.unique(level=...) API, which returns exactly what the private
# index._get_level_values(index._get_level_number("t"), unique=True) call did.
ts = index.unique(level="t")
# ts = list(range(3))

# One frame per timepoint: both channel images on top, and below them the
# scatter plot with a red vertical line marking the current timepoint.
movie = (
    trench_movie(trench_bboxes, key, "MCHERRY", ts)
    + trench_movie(trench_bboxes, key, "YFP", ts)
    + scatter_movie(labelwise_df, label_stream.contents, ts)
    * hv.HoloMap(
        {t: hv.VLine(t).options(color="red", backend="matplotlib") for t in ts}
    )
).cols(1)
# NOTE(review): title_format is filled in once, from the current contents of
# label_stream, when this cell runs -- the "t" shown in the title is therefore
# the selected timepoint, not the timepoint of each frame. Confirm intended.
movie2 = movie.options(
    {
        "Layout": dict(
            normalize=False,
            framewise=True,
            fig_inches=7,
            vspace=0,
            aspect_weight=1,
            sublabel_format="",
            tight=False,
            fontsize=15,
            title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(
                **label_stream.contents
            ),
        ),
        "Scatter": dict(aspect=6, s=20),
    },
    backend="matplotlib",
)
# Render to a fixed scratch file; the next cell embeds it.
m = holomap_to_video(movie2, out="/tmp/jqsmovie.mp4", size=100, dpi=100)

In [None]:
# Embed the rendered movie file directly in the notebook output.
Video("/tmp/jqsmovie.mp4", embed=True)

In [None]:
%%output backend='matplotlib'
%%opts Layout [normalize=False fig_inches=2 vspace=0 aspect_weight=1 sublabel_format='' tight=True title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(**label_stream.contents) fontsize=20]
%%opts Scatter [aspect=6]
# Quick three-frame test render of the same movie, styled through the %%opts
# magics of this cell instead of .options().
# Trench currently selected in the browser, as a tuple ordered like trench_key.
key = tuple(getattr(label_stream, attr) for attr in trench_key)
# NOTE(review): `index` is only needed by the commented-out `ts` line below.
index = detected_bursts.groupby(trench_key).get_group(key).index
# ts = index._get_level_values(index._get_level_number('t'), unique=True)
ts = list(range(3))

# One frame per timepoint: both channel images on top, and below them the
# scatter plot with a red vertical line marking the current timepoint.
movie = (
    trench_movie(trench_bboxes, key, "MCHERRY", ts)
    + trench_movie(trench_bboxes, key, "YFP", ts)
    + scatter_movie(labelwise_df, label_stream.contents, ts)
    * hv.HoloMap(
        {t: hv.VLine(t).options(color="red", backend="matplotlib") for t in ts}
    )
).cols(
    1
)  # .options('Layout', normalize=False)
# NOTE(review): writes to the same path as the full render in the cell above
# and overwrites it.
m = holomap_to_video(movie, out="/tmp/jqsmovie.mp4")

## Other viz

In [None]:
%%output size=180
def cb(**kwargs):
    """Scatter plot of YFP intensity over time for the selected trench.

    ``kwargs`` carries the current stream selection; ``t`` and ``label`` are
    overridden with full slices when calling ``workflow.select_dataframe``.
    Points are coloured by label.
    """
    # Variant that keeps the selected timepoint:
    # df = workflow.select_dataframe(labelwise_df, kwargs, label=slice(None))
    df = workflow.select_dataframe(
        labelwise_df, kwargs, t=slice(None), label=slice(None)
    )
    hover = HoverTool(
        tooltips=[
            ("t", "@t{0[.]0}"),
            # ('filename', '@filename'),
            ("trench", "@position.@trench_set.@trench"),
            ("label", "@label"),
            ("YFP", "@{YFP/labelwise/p0.9}{0[.]0}"),
        ]
    )
    scatter_style = dict(
        size=3,
        color_index="label",
        nonselection_alpha=0.3,
        cmap="Category20",
        tools=[hover, "tap"],
        show_legend=True,
    )
    # The key columns are carried along as value dimensions so the hover tool
    # can read them.
    value_dims = [
        "YFP/labelwise/p0.9",
        "filename",
        "position",
        "trench_set",
        "trench",
        "label",
    ]
    return hv.Scatter(df, kdims=["t"], vdims=value_dims).options(
        "Scatter", **scatter_style
    )


# Hand the callback and the selection stream to ui.viewer for display.
ui.viewer(cb, label_stream)