# Imports

In [None]:
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.feather as feather
import zarr
import dask
from dask import delayed
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import streamz
import streamz.dataframe as sdf
import holoviews as hv
from holoviews.streams import Stream, param, Selection1D
from holoviews.operation.datashader import regrid
from bokeh.models.tools import HoverTool, TapTool
import matplotlib.pyplot as plt
import qgrid
import ipywidgets as widgets
from tqdm import tnrange, tqdm, tqdm_notebook
import warnings
from functools import partial
from cytoolz import *
from operator import getitem
import nd2reader
from importlib import reload
import traceback
import hvplot.pandas
#import param
#import parambokeh
#from traitlets import All
import cachetools
from collections import namedtuple, defaultdict
from collections.abc import Mapping, Sequence
from numbers import Number
import skimage.morphology
import scipy
from glob import glob
import os
import asyncio
from IPython.display import Video

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
########
from paulssonlab.image_analysis import *

# import common, trench_detection, util, data_io, processing
# import ui, diagnostics, metadata
# import workflow, image, geometry
# import trench_detection.hough, trench_detection.core
# import trench_segmentation.watershed

In [None]:
# %load_ext line_profiler
# %load_ext memory_profiler
# %load_ext snakeviz
# hv.extension("bokeh", "matplotlib")
# %matplotlib inline
# tqdm.monitor_interval = 0
# asyncio.get_event_loop().set_debug(False)
# import logging

# logging.basicConfig(level=logging.DEBUG)
# import warnings

# warnings.simplefilter("ignore")

# Loading data

In [None]:
nd2_filenames = ["/home/jqs1/scratch/jqs1/microscopy/211027/211027_fb_library_strong_sigw.nd2"]

In [None]:
all_frames, metadata = workflow.get_nd2_frame_list(nd2_filenames)
image_limits = workflow.get_filename_image_limits(metadata)

# Config

In [None]:
dask.config.get("distributed.worker.memory")

In [None]:
# Override the worker memory thresholds through the public config API.
# Setting individual dotted keys leaves every other entry of the
# "distributed.worker.memory" section untouched, whereas assigning a whole
# new dict to dask.config.config[...]["memory"] would discard them.
dask.config.set(
    {
        "distributed.worker.memory.target": 0.9,
        "distributed.worker.memory.spill": None,
        "distributed.worker.memory.pause": None,
        "distributed.worker.memory.terminate": 0.95,
    }
)

In [None]:
# dask.config.config['distributed']['worker']['profile'] = {'interval': '10s', 'cycle': '10s'}
# {'interval': '10ms', 'cycle': '1000ms'}

In [None]:
# Create a dask-jobqueue SLURM cluster and connect a distributed Client to it.
# Jobs are requested on the "short" queue with a 4 hour walltime; each job is
# configured with one core, one process and 10GB of memory.
cluster = SLURMCluster(
    queue="short",
    walltime="04:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # job_extra=['--exclude=compute-e-16-181,compute-e-16-186'],
    # interface='ib0',
    memory="10GB",  # TODO!!!
    # worker scratch space and job logs
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/matriarch/log",
    cores=1,
    processes=1,
    # diagnostics_port=('127.0.0.1', 8787),
    # shell `export` lines handed to dask-jobqueue: put the project on
    # PYTHONPATH and preload jemalloc
    env_extra=[
        'export PYTHONPATH="/home/jqs1/projects/matriarch"',
        #'export PYTHONTRACEMALLOC=25',
        #'export MALLOC_CONF=prof:true,prof_leak:true,lg_prof_interval:31,prof_final:true',
        'export LD_PRELOAD="/home/jqs1/lib/libjemalloc.so.2"',
    ],
)
client = Client(cluster)  # , direct_to_workers=True)

In [None]:
cluster.scale(300)

In [None]:
cluster.adapt(minimum=0, maximum=300)

In [None]:
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
cluster

In [None]:
cluster.stop_jobs(cluster.running_jobs.keys())

In [None]:
cluster.scheduler.stop_services()
cluster.scheduler.stop()

In [None]:
client.restart()

# Reload

In [None]:
def do_reload():
    """Import the project modules in this process and reload ``workflow``.

    Only ``workflow`` is reloaded; add further ``importlib.reload`` calls
    (e.g. for ``util``, ``trench_detection.hough``, ``diagnostics`` or
    ``image``) when those modules have been edited.
    """
    import importlib

    import util
    import trench_detection
    import diagnostics
    import workflow
    import image

    importlib.reload(workflow)


# Run the reload through the client first, then in the notebook process itself.
client.run(do_reload)
do_reload()

# Trench detection

In [None]:
FrameStream = ui.MultiIndexStream.define("FrameStream", all_frames.index)
frame_stream = FrameStream()
box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
%%time
# key = frame_stream.contents
key = ("/n/scratch2/jqs1/fidelity/all/181010_rpos_bigsnake.nd2", 35, "MCHERRY", 0)
frame = workflow.get_nd2_frame(*key)
find_trenches_diag = diagnostics.wrap_diagnostics(
    trench_detection.find_trenches, ignore_exceptions=False, pandas=False
)
trench_points, trench_diag, trench_err = find_trenches_diag(frame)

In [None]:
#%%output size=150
ui.show_plot_browser(trench_diag["label_2"])

# Data reduction

In [None]:
# selected_frames = all_frames.loc[IDX[:,:1,:,:2],:]
selected_frames = all_frames.loc[IDX[:, :, :, :], :]

## New trench detection+segmentation+analysis

#### Config

In [None]:
def filter_trenches(trenches, pitch=20.9, pitch_tolerance=1):
    """Keep only trenches whose detected pitch is close to the expected pitch.

    Parameters
    ----------
    trenches : pandas.DataFrame or None
        Trench table. Must contain the columns
        ``("diag", "find_trench_lines.hough_2.peak_func.pitch")`` and
        ``("upper_left", "x")``.
    pitch : float, optional
        Expected trench pitch. Defaults to 20.9, the value that used to be
        hard-coded here (24 was used for other datasets).
    pitch_tolerance : float, optional
        Largest accepted absolute deviation of the detected pitch from
        ``pitch``. Defaults to 1.

    Returns
    -------
    pandas.DataFrame or None
        Rows of ``trenches`` whose pitch is within tolerance and whose
        ``("upper_left", "x")`` entry is not null; ``None`` if ``trenches``
        is ``None``.
    """
    if trenches is None:
        return None
    # NOTE: an earlier variant additionally required ("info", "hough_value") > 90.
    # TODO: we shouldn't be filtering based on hough_value at all!!
    detected_pitch = trenches[("diag", "find_trench_lines.hough_2.peak_func.pitch")]
    pitch_ok = (detected_pitch - pitch).abs() <= pitch_tolerance
    has_corner = ~trenches[("upper_left", "x")].isnull()
    # TODO: filter based on minimum trench length
    return trenches[pitch_ok & has_corner]

In [None]:
# Reductions applied to pixel intensities, keyed by the name of the resulting
# measurement. The tuple key goes with a single np.percentile call that
# evaluates five percentiles at once (30, 50, 70, 90, 95).
# NOTE(review): presumably workflow.map_frame / map_frame_over_labels unpack
# the tuple key into one output per name -- confirm against those helpers.
pixelwise_funcs = {
    "mean": np.mean,
    "min": np.min,
    "max": np.max,
    ("p0.3", "p0.5", "p0.7", "p0.9", "p0.95"): partial(
        np.percentile, q=(30, 50, 70, 90, 95)
    ),
}
# Reductions for whole trench crops: everything above plus image.sharpness.
trenchwise_funcs = {"sharpness": image.sharpness, **pixelwise_funcs}


def _measurement_func(label_image, intensity_image):
    if label_image is not None:
        eroded_label_image = (
            util.repeat_apply(skimage.morphology.binary_erosion, 2)(label_image != 0)
            * label_image
        )
    if intensity_image is None:
        if label_image is None:
            return None  # can't measure anything
        minlength = label_image.max() + 1
        mask_labelwise_df = pd.DataFrame(
            {
                ("noerode", "size"): np.bincount(label_image.flat, minlength=minlength),
                ("erode2", "size"): np.bincount(
                    eroded_label_image.flat, minlength=minlength
                ),
            }
        )
        mask_labelwise_df.index.name = "label"
        return dict(mask_labelwise=mask_labelwise_df)
    trenchwise_df = workflow.map_frame(trenchwise_funcs, intensity_image)
    res = dict(trenchwise=trenchwise_df)
    if label_image is None:
        return res  # only measure trenchwise
    labelwise = {
        "noerode": workflow.map_frame_over_labels(
            pixelwise_funcs, label_image, intensity_image
        ),
        "erode2": workflow.map_frame_over_labels(
            pixelwise_funcs, eroded_label_image, intensity_image
        ),
    }
    labelwise_df = pd.concat(labelwise, axis=1)
    res["labelwise"] = labelwise_df
    return res  # measure trenchwise and labelwise

#### Boilerplate

In [None]:
def _measure(
    trenches,
    frames,
    measurement_func,
    segmentation_channel="MCHERRY",
    measure_channels=None,
    segmentation_func=trench_segmentation.watershed.segment_trench,
    include_frame=True,
    frame_bits=8,
    frame_downsample=4,
    filename=None,
    position=None,
):
    """Crop, segment and measure all trenches of one (filename, position) group.

    Parameters
    ----------
    trenches, frames
        Passed through to ``processing._get_trench_crops``.
    measurement_func : callable or None
        Called as ``measurement_func(segmentation_mask, crop)`` for every crop
        and as ``measurement_func(segmentation_mask, None)`` once per newly
        computed mask. ``None`` disables measurements.
    segmentation_channel : str
        Channel whose crop is passed to ``segmentation_func``.
    measure_channels : container or None
        If given, crops of channels not in this container are skipped.
    segmentation_func : callable or None
        Called on the ``segmentation_channel`` crop of a trench/timepoint.
        ``None`` disables segmentation (masks are then ``None``).
    include_frame, filename, position
        Passed through to ``processing._get_trench_crops``.
    frame_bits, frame_downsample
        Arguments for ``image.quantize`` / ``image.downsample`` in the frame
        transformation handed to ``processing._get_trench_crops``.

    Returns
    -------
    dict
        ``"images"``: ``{"raw": trench_crops}`` plus ``"segmentation"`` when
        ``segmentation_func`` is not ``None``; ``"measurements"``: a dict of
        DataFrames, only present when ``measurement_func`` is not ``None``.
    """
    # NOTE(review): assuming `compose` is the right-to-left cytoolz compose,
    # frames are downsampled, then quantized, then passed to zarrify -- confirm.
    frame_transformation = compose(
        processing.zarrify,
        partial(image.quantize, bits=frame_bits),
        partial(image.downsample, factor=frame_downsample),
    )
    trench_crops = processing._get_trench_crops(
        trenches,
        frames,
        include_frame=include_frame,
        frame_transformation=frame_transformation,
        filename=filename,
        position=position,
    )
    # flattened_crops = {(*k[0], *k[1:]): v for k, v in util.flatten_dict(trench_crops).items() if k[0] != '_frame'}
    # flattened_crops = {k: v for k, v in util.flatten_dict(trench_crops).items() if k[0] != '_frame'}
    # print('trenches>',list(trench_crops[1][92]['YFP'].keys()))
    # print('trenches>',list(trench_crops[(1,92)]['MCHERRY'].keys()))
    # print('trenches only>',list(flattened_crops.keys()))
    # 0/0
    res = {}
    # masks keyed by (trench_set, trench_idx, segmentation_channel, t)
    segmentation_masks = {}
    # measurement results keyed by (channel or "mask", (trench_set, trench_idx, t))
    measurements = {}
    # segment
    # trench_crops is walked as trench_set -> trench_idx -> channel -> t -> crop
    for trench_set, crops_trench_channel_t in trench_crops.items():
        # the "_frame" entry is not a trench set and is not measured
        if trench_set == "_frame":
            continue
        for trench_idx, crops_channel_t in crops_trench_channel_t.items():
            for channel, crops_t in crops_channel_t.items():
                for t, crop in crops_t.items():
                    if measure_channels is not None and channel not in measure_channels:
                        continue
                    segmentation_key = (trench_set, trench_idx, segmentation_channel, t)
                    segmentation_mask = segmentation_masks.get(segmentation_key, None)
                    # NOTE(review): a mask of None is indistinguishable from "not yet
                    # computed", so a segmentation_func that returns None is re-run
                    # (and the mask re-measured) for every channel of this trench/t.
                    if segmentation_mask is None and segmentation_func is not None:
                        segmentation_mask = segmentation_func(
                            trench_crops[trench_set][trench_idx][segmentation_channel][
                                t
                            ]
                        )
                        segmentation_masks[segmentation_key] = segmentation_mask
                        # measure mask
                        if measurement_func is not None:
                            measurements[
                                ("mask", (trench_set, trench_idx, t))
                            ] = measurement_func(segmentation_mask, None)
                    # measure
                    if measurement_func is not None:
                        measurements[
                            (channel, (trench_set, trench_idx, t))
                        ] = measurement_func(segmentation_mask, crop)
    if measurement_func is not None:
        # reorder the key levels of the nested measurement results
        # (util.map_dict_levels semantics are not visible here)
        measurement_dfs = util.map_dict_levels(
            lambda k: (k[1], k[0], *k[2:]), measurements
        )
        for name, dfs in measurement_dfs.items():
            dfs = util.unflatten_dict(dfs)
            # Series results become one row each (concat as columns, then
            # transpose); DataFrame results are stacked along the index
            if isinstance(util.get_one(dfs, level=2), pd.Series):
                df = pd.concat(
                    {
                        channel: pd.concat(channel_dfs, axis=1).T
                        for channel, channel_dfs in dfs.items()
                    },
                    axis=1,
                )
            else:
                df = pd.concat(
                    {
                        channel: pd.concat(channel_dfs, axis=0)
                        for channel, channel_dfs in dfs.items()
                    },
                    axis=1,
                )
            # name the three leading index levels, keep any remaining names
            df.index.names = ["trench_set", "trench", "t", *df.index.names[3:]]
            measurement_dfs[name] = df
        res["measurements"] = measurement_dfs
    images = dict(raw=trench_crops)
    if segmentation_func is not None:
        images["segmentation"] = util.unflatten_dict(segmentation_masks)
    res["images"] = images
    return res


# Wrap _measure with the project's groupby helper for ("filename", "position").
measure = processing.iterate_over_groupby(["filename", "position"])(_measure)

In [None]:
def filename_func(extension=None, kind=None, name=None, filename=None, position=None):
    """Build the output path for a derived file of ``filename``.

    Without ``position`` the result is ``{filename}.{kind}[.name][.extension]``.
    With ``position`` the result is the directory ``{filename}.{kind}``
    joined with ``pos{position}[.name][.extension]``.
    """
    # the leading empty string makes every present component start with "."
    suffix = ".".join(part for part in ("", name, extension) if part is not None)
    root = f"{filename}.{kind}"
    if position is None:
        return os.path.join(root + suffix)
    return os.path.join(root, "pos{:d}".format(position) + suffix)

In [None]:
def _trench_diag_to_dataframe(trench_diag, sep="."):
    """Turn a diagnostics Series into a DataFrame indexed by ``trench_set``.

    The Series is converted to a one-row frame, expanded with
    ``diagnostics.expand_diagnostics_by_label``, stripped of its first index
    level, and its last index level is renamed to ``"trench_set"``.
    ``sep`` is accepted but not used.
    """
    one_row = trench_diag.to_frame().T
    expanded = diagnostics.expand_diagnostics_by_label(one_row)
    expanded.index = expanded.index.droplevel(0)
    kept_names = list(expanded.index.names[:-1])
    expanded.index.names = kept_names + ["trench_set"]
    return expanded


#     if len(expanded_df):
#         expanded_df.index = expanded_df.index.droplevel(0)
#         expanded_df.index.names = [*expanded_df.index.names[:-1], 'trench_set']
#     else:
#         expanded_df = pd.concat([df], keys=[-1], names=['trench_set'])
#     return expanded_df


def _trench_info_to_dataframe(trench_info):
    trench_points, trench_diag, trench_err = trench_info
    if trench_err is not None:
        # TODO: write trench_err
        return None
    trench_diag = _trench_diag_to_dataframe(trench_info[1])
    # FROM: https://stackoverflow.com/questions/14744068/prepend-a-level-to-a-pandas-multiindex
    trench_diag = pd.concat([trench_diag], axis=1, keys=["diag"])
    trenches = pd.concat(
        [trench_points, util.multi_join(trench_info[0].index, trench_diag)], axis=1
    )
    return trenches


def _trenches_to_bboxes(trenches, image_limits):
    """Append the columns from ``workflow.get_trench_bboxes`` to ``trenches``.

    ``trenches`` is returned unchanged when no bounding boxes are produced.
    """
    bboxes = workflow.get_trench_bboxes(trenches, image_limits)
    if bboxes is None:
        return trenches
    return pd.concat([trenches, bboxes], axis=1)


# Diagnostics-wrapped trench finder used by the batch pipeline; wrapped with
# ignore_exceptions=True and pandas=True.
find_trenches_diag = diagnostics.wrap_diagnostics(
    trench_detection.find_trenches, pandas=True, ignore_exceptions=True
)


def do_find_trenches(*key):
    """Load the ND2 frame addressed by ``key`` and run trench detection on it.

    Returns whatever ``find_trenches_diag`` returns; elsewhere in this
    notebook that is unpacked as ``(trench_points, trench_diag, trench_err)``.
    """
    return find_trenches_diag(workflow.get_nd2_frame(*key))


def do_trenches_to_bboxes(trench_info, key=None, index_names=("filename", "position")):
    """Convert a trench-detection result into a trench table with bounding boxes.

    Returns ``None`` when the detection result cannot be converted (i.e.
    ``_trench_info_to_dataframe`` returns ``None``). When ``key`` is given it
    is prepended to the index, with ``index_names`` naming the new levels.
    Bounding boxes are computed against the notebook-global ``image_limits``.
    """
    trenches = _trench_info_to_dataframe(trench_info)
    if trenches is not None:
        if key is not None:
            trenches = pd.concat([trenches], keys=[key], names=index_names)
        trenches = _trenches_to_bboxes(trenches, image_limits=image_limits)
    return trenches


def do_get_trench_err(trench_info):
    """Return ``trench_info`` if it carries an error, else ``None``.

    ``trench_info`` is a ``(trench_points, trench_diag, trench_err)`` triple.

    Raises
    ------
    ValueError
        If an error is present but ``trench_points`` is not ``None``.
    """
    points, _diag, err = trench_info
    if err is not None:
        if points is not None:
            raise ValueError("expecting trench_points to be None")
        return trench_info
    return None


import pickle


def do_serialize_to_disk(
    data, filename, overwrite=True, skip_nones=True, format="pickle"
):
    """Serialize a dict to ``filename`` and return what was written.

    Parameters
    ----------
    data : dict
        Mapping to serialize.
    filename : str
        Destination path.
    overwrite : bool
        If false, an existing file is never replaced.
    skip_nones : bool
        If true, entries whose value is ``None`` are dropped before writing.
    format : {"pickle", "arrow"}
        Serialization backend.

    Returns
    -------
    dict
        The (possibly filtered) data that was written.

    Raises
    ------
    ValueError
        If ``format`` is not supported. Raised before the file is opened, so
        no empty or truncated file is left behind.
    FileExistsError
        If ``overwrite`` is false and ``filename`` already exists.
    """
    if format not in ("arrow", "pickle"):
        raise ValueError(f"unsupported serialization format: {format!r}")
    if skip_nones:
        data = {k: v for k, v in data.items() if v is not None}
    # "xb" creates the file exclusively: the existence check and the creation
    # are a single step, so there is no window between checking and opening
    mode = "wb" if overwrite else "xb"
    with open(filename, mode) as f:
        if format == "arrow":
            # NOTE(review): pa.serialize is deprecated in recent pyarrow
            # releases -- confirm it exists in the deployed version
            buf = pa.serialize(data).to_buffer()
            f.write(buf)
        else:
            pickle.dump(data, f)
    return data


def do_save_trenches(trenches, filename, overwrite=True):
    """Concatenate per-position trench tables and write them to parquet.

    ``trenches`` is an iterable of DataFrames accepted by ``pd.concat``.
    Returns the concatenated DataFrame.
    """
    combined = pd.concat(trenches)
    processing.write_dataframe_to_parquet(
        filename, combined, overwrite=overwrite, merge=False
    )
    return combined


def do_measure_and_write(trenches, frames, return_none=True, write=True, **kwargs):
    """Filter trenches, measure them and optionally write the results.

    Parameters
    ----------
    trenches : pandas.DataFrame or None
        Trench table; ``None`` short-circuits and returns ``None``.
    frames
        Passed to ``measure``.
    return_none : bool
        If true, return ``None`` instead of the measurement result.
    write : bool
        If true, write images and measurements (parquet dataframes) through
        ``processing.write_images_and_measurements``.
    **kwargs
        Forwarded to ``measure``.
    """
    if trenches is None:
        return None
    res = measure(filter_trenches(trenches), frames, **kwargs)
    if write:
        processing.write_images_and_measurements(
            res,
            filename_func=filename_func,
            dataframe_format="parquet",
            write_images=True,
            write_measurements=True,
        )
    return None if return_none else res

#### Execute

In [None]:
# Futures accumulated over all files: analysis and debugging futures are keyed
# by (filename, position), the save futures by filename.
all_analysis_futures = {}
save_trenches_futures = {}
save_trench_err_futures = {}

all_trench_bboxes_futures = {}  # TODO: just for debugging

for filename, filename_frames in selected_frames.groupby("filename"):
    # futures of the current file only, keyed by (filename, position)
    trench_bboxes_futures = {}
    trench_err_futures = {}
    for position, frames in filename_frames.groupby("position"):
        key = (filename, position)
        # TODO: make pluggable
        frame_to_segment = frames.loc[IDX[:, :, ["MCHERRY"], 0], :]
        # trench detection runs on the first selected frame of this position
        trenches_future = client.submit(
            do_find_trenches, *frame_to_segment.index[0], priority=10
        )
        trench_err_futures[key] = client.submit(do_get_trench_err, trenches_future)
        trench_bboxes_future = client.submit(
            do_trenches_to_bboxes, trenches_future, key, priority=10
        )
        trench_bboxes_futures[key] = trench_bboxes_future
        all_trench_bboxes_futures[key] = trench_bboxes_future
        analysis_future = client.submit(
            do_measure_and_write,
            trench_bboxes_future,
            frames,
            measurement_func=_measurement_func,
            # measurement_func=None,
            # segmentation_func=None,
            return_none=True,
            write=True,
            priority=-10,
        )
        all_analysis_futures[key] = analysis_future
    # save trenches (per-position tables in sorted key order)
    trenches_filename = filename_func(
        kind="trenches", extension="parquet", filename=filename
    )
    sorted_bboxes_futures = [
        future for _, future in sorted(trench_bboxes_futures.items())
    ]
    save_trenches_futures[filename] = client.submit(
        do_save_trenches, sorted_bboxes_futures, trenches_filename, priority=100
    )
    # save trench detection errors
    trench_errs_filename = filename_func(
        kind="trench_errs", extension="pickle", filename=filename
    )
    save_trench_err_futures[filename] = client.submit(
        do_serialize_to_disk, trench_err_futures, trench_errs_filename, priority=100
    )
# OPTIONALLY: stream analysis to master

In [None]:
# Gather the bounding-box futures that ended in an error. Use the dict that is
# accumulated across all files: trench_bboxes_futures is re-created for every
# filename in the submission loop and therefore only covers the last file.
util.apply_map_futures(
    client.gather, all_trench_bboxes_futures, predicate=lambda x: x.status == "error"
)

In [None]:
%%time
t = do_find_trenches(
    "/n/scratch2/jqs1/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2",
    35,
    "MCHERRY",
    0,
)

In [None]:
t[1]["label_1.find_trench_lines.hough_2.peak_func.pitch"]

In [None]:
%%time
tt = do_trenches_to_bboxes(
    t, ("/n/scratch2/jqs1/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2", 35)
)

In [None]:
tt

In [None]:
client.gather(save_trenches_futures)

In [None]:
%%time
do_find_trenches(
    "/n/scratch2/jqs1/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2",
    1,
    "MCHERRY",
    0,
)

In [None]:
a.loc[
    IDX[("/n/scratch2/jqs1/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2", 1)],
    :,
]

In [None]:
client.gather(util.get_one(all_analysis_futures))

In [None]:
{k: v for k, v in all_analysis_futures.items() if v.status != "pending"}

In [None]:
all_analysis_futures

In [None]:
util.apply_map_futures(
    client.gather, all_analysis_futures, predicate=lambda x: x.status == "error"
)

In [None]:
client.gather(all_analysis_futures)

In [None]:
a = util.get_one(client.gather(save_trenches_futures))

In [None]:
z = a[("diag", "find_trench_lines.hough_2.peak_func.pitch")]
z[z < 50].plot.hist(bins=100, log=True)

In [None]:
z = a[("info", "hough_value")]
z.plot.hist(bins=100)

In [None]:
client.gather(save_trench_err_futures)

In [None]:
client.restart()

# Analysis

## New loader

In [None]:
%%time
labelwise_df = data_io.read_parquet(
    "/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2.measurements/pos1050.labelwise.parquet"
)
mask_labelwise_df = data_io.read_parquet(
    "/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2.measurements/pos0.mask_labelwise.parquet"
)
trenchwise_df = data_io.read_parquet(
    "/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2.measurements/pos0.trenchwise.parquet"
)

In [None]:
%%time
from concurrent import futures
from util import tqdm_auto


def _load_measurements(base_filename):
    labelwise_filename = base_filename + ".labelwise.parquet"
    if not os.path.exists(labelwise_filename):
        return
    # labelwise_df = data_io.read_parquet(labelwise_filename, progress_bar=None)
    # open(labelwise_filename, 'rb').read()
    with pa.OSFile(labelwise_filename, "rb") as f:
        # res = f.read()
        # buf = f.read()
        # buf = f.read_buffer()
        res = pq.ParquetFile(f).read(use_pandas_metadata=True, use_threads=False)
        # res = res.to_pandas()
        # res = data_io.read_parquet(buf, progress_bar=None)
        # res = pq.read_pandas(buf)
        # res = data_io.read_parquet(f.read_buffer(), progress_bar=None)
    return
    # mask_labelwise_df = data_io.read_parquet(base_filename+'.mask_labelwise.parquet')
    # trenchwise_df = data_io.read_parquet(base_filename+'.trenchwise.parquet')
    # return labelwise_df#.to_pandas()


def load_measurements(filename, nthreads=False, progress_bar=tqdm_auto):
    """Run ``_load_measurements`` for positions 0-99 under ``filename``.

    Parameters
    ----------
    filename : str
        Directory containing the per-position ``pos{N}`` measurement files.
    nthreads : int or False
        Number of worker threads; a falsy value loads sequentially.
    progress_bar : callable or None
        tqdm-like wrapper called as ``progress_bar(iterable, total=...)``;
        ``None`` disables progress reporting.

    Returns
    -------
    None
        Results of the individual loads are discarded.
    """
    tasks = [
        (_load_measurements, os.path.join(filename, "pos{:d}".format(i)))
        for i in range(100)
    ]
    n_tasks = len(tasks)
    if nthreads:
        # the context manager shuts the pool down when the block exits, even if
        # a load raises, instead of leaving the executor open
        with futures.ThreadPoolExecutor(max_workers=nthreads) as ex:
            completed_tasks = futures.as_completed([ex.submit(*t) for t in tasks])
            if progress_bar is not None:
                completed_tasks = progress_bar(completed_tasks, total=n_tasks)
            for future in completed_tasks:
                # result() re-raises any exception from the worker thread
                future.result()
    else:
        if progress_bar is not None:
            tasks = progress_bar(tasks, total=n_tasks)
        for func, *args in tasks:
            func(*args)
    return


load_measurements(
    "/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2.measurements"
)

In [None]:
%%time
from concurrent import futures
from util import tqdm_auto


def _load_measurements(base_filename, kind):
    labelwise_filename = base_filename + ".{}.parquet".format(kind)
    if not os.path.exists(labelwise_filename):
        return
    with pa.OSFile(labelwise_filename, "rb") as f:
        res = pq.ParquetFile(f).read(use_pandas_metadata=True, use_threads=False)
    res = res.to_pandas()
    return res
    # labelwise_df = data_io.read_parquet(labelwise_filename, progress_bar=None)
    # mask_labelwise_df = data_io.read_parquet(base_filename+'.mask_labelwise.parquet')
    # trenchwise_df = data_io.read_parquet(base_filename+'.trenchwise.parquet')
    # return labelwise_df#.to_pandas()


def load_measurements(
    parquet_filename,
    kind,
    nthreads=False,
    progress_bar=tqdm_auto,
    filename=None,
    positions=None,
):
    """Load the per-position measurement tables of one kind.

    Parameters
    ----------
    parquet_filename : str
        Directory containing the ``pos{N}.<kind>.parquet`` files.
    kind : str
        Measurement kind, used in the file name (e.g. ``"labelwise"``).
    nthreads
        Currently ignored; kept for interface compatibility.
    progress_bar : callable or None
        tqdm-like wrapper called as ``progress_bar(iterable, total=...)``;
        ``None`` disables progress reporting.
    filename
        Value used as the first element of every result key.
    positions : iterable of int, optional
        Positions to load. Defaults to ``range(100)``, the previously
        hard-coded set.

    Returns
    -------
    dict
        Maps ``(filename, position)`` to the loaded DataFrame, or to ``None``
        for positions whose file does not exist.
    """
    positions = range(100) if positions is None else list(positions)
    n_positions = len(positions)
    if progress_bar is not None:
        positions = progress_bar(positions, total=n_positions)
    res = {}
    for pos in positions:
        base_filename = os.path.join(parquet_filename, "pos{:d}".format(pos))
        res[(filename, pos)] = _load_measurements(base_filename, kind)
    return res

In [None]:
labelwise_dfs = load_measurements(
    "/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2.measurements",
    "labelwise",
    filename="/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2",
)
mask_labelwise_dfs = load_measurements(
    "/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2.measurements",
    "mask_labelwise",
    filename="/home/jqs1/scratch/fidelity/all/180928_txnerr_bigsnake.stripe-1.256m.nd2",
)

In [None]:
%%time
labelwise_df = pd.concat(labelwise_dfs)

In [None]:
%%time
mask_labelwise_df = pd.concat(mask_labelwise_dfs)

In [None]:
mask_labelwise_df.index = labelwise_df.index

In [None]:
%%time
labelwise_df = pd.concat([labelwise_df, mask_labelwise_df], axis=1)

## Load data

In [None]:
%%time
framewise_df = data_io.read_parquet(
    "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_0.sorted.parquet4"
).to_pandas()

In [None]:
%%time
trenchwise_df = data_io.read_parquet(
    "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_1.sorted.parquet4"
).to_pandas()

In [None]:
trenchwise_df.columns = ["/".join(col).strip() for col in trenchwise_df.columns.values]

In [None]:
cols = [
    "filename",
    "position",
    "channel",
    "t",
    "trench_set",
    "trench",
    "label",
    "('YFP', 'labelwise', 'p0.9')",
    "('MCHERRY', 'labelwise', 'p0.9')",
    "('YFP', 'regionprops', 'area')",
]

In [None]:
%%time
labelwise_df = data_io.read_parquet(
    "/n/scratch2/jqs1/fidelity/all/output/analysis_full_stream11_2.sorted3.parquet4",
    columns=cols,
).to_pandas()

In [None]:
# TODO: otherwise computing is_unique is costly when we want to get_loc with full key
labelwise_df.index.__dict__["_cache"] = {"lexsort_depth": 6, "is_unique": True}
# labelwise_df.index.lexsort_depth # prime the cache
# if '_cache' not in labelwise_df.index.__dict__:
#     labelwise_df.index.__dict__['_cache'] = {}
# labelwise_df.index._cache['is_unique'] = True

In [None]:
labelwise_df.columns = ["/".join(col).strip() for col in labelwise_df.columns.values]

## Burst detection

In [None]:
yfp = "YFP/labelwise/p0.9"
mcherry = "MCHERRY/labelwise/p0.9"
area = "YFP/regionprops/area"
trench_key = ["filename", "position", "trench_set", "trench"]
trench_t_key = ["filename", "position", "trench_set", "trench", "t"]

In [None]:
%%time
# labelwise_selected = labelwise_df.loc[IDX['/n/scratch2/jqs1/fidelity/all/TrErr002_noBF.nd2',:],:]
# labelwise_selected = labelwise_df.loc[IDX['/n/scratch2/jqs1/fidelity/all/180405_txnerr_loweronly_fast.nd2',:],:]
# labelwise_selected = labelwise_df.loc[IDX['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2.nd2',:],:]
labelwise_selected = labelwise_df

In [None]:
len(labelwise_df) / len(labelwise_selected)

In [None]:
# col = 'YFP/p0.3'
# trenchwise_yfp_bg = trenchwise_df[col].rename(col+'_trenchwise')

In [None]:
labelwise_selected.index.names

In [None]:
%%time
background = labelwise_selected.loc[IDX[:, :, :, :, :, 0], [yfp, mcherry]]
background.index = background.index.droplevel("label")
background.columns = [c + "_bg" for c in background.columns]

In [None]:
%%time
cell_sized = labelwise_selected[labelwise_selected[area].between(100, 200)].loc[
    IDX[:, :, :, :, :, 1:], :
]

In [None]:
%%time
trench_t_median = cell_sized.groupby(trench_t_key).median()
trench_t_median.columns = [c + "_trench_t_median" for c in trench_t_median.columns]

In [None]:
%%time
trench_median = cell_sized.groupby(trench_key).median()
trench_median.columns = [c + "_trench_median" for c in trench_median.columns]

In [None]:
%%time
with_bg = util.multi_join(cell_sized, background)

In [None]:
%%time
with_bg = util.multi_join(with_bg, trench_t_median)

In [None]:
%%time
with_bg = util.multi_join(with_bg, trench_median)

In [None]:
%%time
bright_ts_median_t = pd.DataFrame(
    {
        "bright_ts_median_t_{}".format(thresh): (
            (with_bg[yfp] - with_bg[yfp + "_trench_t_median"]) >= thresh
        )
        .groupby(trench_key)
        .sum()
        for thresh in (
            5,
            8,
            10,
            20,
            30,
            50,
        )
    }
)

In [None]:
%%time
bright_ts_median = pd.DataFrame(
    {
        "bright_ts_median_{}".format(thresh): (
            (with_bg[yfp] - with_bg[yfp + "_trench_median"]) >= thresh
        )
        .groupby(trench_key)
        .sum()
        for thresh in (
            5,
            8,
            10,
            20,
            30,
            50,
        )
    }
)

In [None]:
%%time
bright_ts_bg = pd.DataFrame(
    {
        "bright_ts_bg_{}".format(thresh): (
            (with_bg[yfp] - with_bg[yfp + "_bg"]) >= thresh
        )
        .groupby(trench_key)
        .sum()
        for thresh in (
            5,
            8,
            10,
            20,
            30,
            50,
        )
    }
)

In [None]:
len(bright_ts_median_t[bright_ts_median_t["bright_ts_median_t_50"] > 1])

In [None]:
bright_ts_median_t[bright_ts_median_t["bright_ts_median_t_5"] > 1].head()

In [None]:
%%time
median_bg = background.groupby(trench_key).median()

In [None]:
bright_ts_all = util.multi_join(
    util.multi_join(
        util.multi_join(bright_ts_median_t, bright_ts_median), bright_ts_bg
    ),
    median_bg,
)

In [None]:
%%time
cell_sized_with_ts = util.multi_join(cell_sized, bright_ts_all)

In [None]:
%%time
detected_bursts = cell_sized_with_ts[cell_sized_with_ts["bright_ts_median_t_20"] >= 2]

In [None]:
%%time
len(detected_bursts.groupby(trench_key))

## New filter

In [None]:
%%time
detected_bursts2 = cell_sized_with_ts[cell_sized_with_ts["bright_ts_bg_20"] >= 2]

In [None]:
%%time
len(detected_bursts2.groupby(trench_key))

In [None]:
cell_sized_with_ts.columns

In [None]:
x = cell_sized_with_ts[yfp + "_bg"]
x[x < 180].hist(bins=50, log=True)

In [None]:
%%time
# detected_bursts3 = cell_sized_with_ts[(cell_sized_with_ts[yfp+'_bg'] <= 200) & (cell_sized_with_ts['bright_ts_bg_20'] >= 2)]
# detected_bursts3 = cell_sized_with_ts[(cell_sized_with_ts[yfp+'_bg'].between(120, 130)) & (cell_sized_with_ts['bright_ts_bg_20'] >= 1)]
detected_bursts3 = cell_sized_with_ts[
    (cell_sized_with_ts["bright_ts_median_t_50"] >= 3)
]

In [None]:
len(detected_bursts3)

In [None]:
%%time
len(detected_bursts3.groupby(trench_key))

## New visualization

In [None]:
LabelStream = ui.MultiIndexStream.define("LabelStream", labelwise_df.index)
label_stream = LabelStream()
box = ui.dataframe_browser(label_stream)
label_stream.event()
box

In [None]:
%%output size=100
%%opts Layout [normalize=False]
hover = HoverTool(
    tooltips=[
        ("(x,y)", "(@x{0[.]0}, @y{0[.]0})"),
        ("value", "@z"),
    ]
)
# cb = compose(partial(ui.hover_image, hover), ui._trench_img, workflow.get_trench_image)
cb = lambda v_max: compose(
    partial(ui.hover_image, hover),
    lambda x: x.redim.range(z=(0, v_max)),
    ui._trench_img,
    workflow.get_trench_image,
)
# cb = workflow.get_trench_image
(
    ui.trench_viewer(
        trench_bboxes, label_stream, channel="MCHERRY", image_callback=cb(5000)
    )
    + ui.trench_viewer(
        trench_bboxes, label_stream, channel="YFP", image_callback=cb(400)
    )
).cols(1)

In [None]:
groups = detected_bursts3.groupby(trench_key)
group_set_keys = list(util.grouper(groups.groups.keys(), 5))
group_index = pd.MultiIndex.from_tuples([(i,) for i in range(len(group_set_keys))])
group_index.names = ["group_set"]

In [None]:
GroupStream = ui.MultiIndexStream.define("GroupStream", group_index)
group_stream = GroupStream()
group_box = ui.dataframe_browser(group_stream)
group_stream.event()
group_box

In [None]:
%%output size=180
sel = Selection1D()


def callback(group_set):
    """Build the YFP-vs-t scatter plot for one set of trench groups.

    Concatenates the groups listed in ``group_set_keys[group_set]``, colors
    the points by trench, and re-subscribes the shared ``sel`` selection
    stream so that a selection is forwarded to ``label_stream``.
    """
    selected_groups = [groups.get_group(key) for key in group_set_keys[group_set]]
    df = pd.concat(selected_groups)
    value_dims = [
        "YFP/labelwise/p0.9",
        "filename",
        "position",
        "trench_set",
        "trench",
        "label",
    ]
    hover = HoverTool(
        tooltips=[
            ("t", "@t{0[.]0}"),
            # ('filename', '@filename'),
            ("trench", "@position.@trench_set.@trench"),
            ("label", "@label"),
            ("YFP", "@{YFP/labelwise/p0.9}{0[.]0}"),
        ]
    )
    tap = TapTool()
    scatter_options = dict(
        size=3,
        color_index="trench",
        nonselection_alpha=0.3,
        cmap="Category20",
        tools=[hover, tap],
        show_legend=True,
    )
    plot = hv.Scatter(df, kdims=["t"], vdims=value_dims)
    plot = plot.options("Scatter", **scatter_options)
    # ui.selection_to_stream(plot, label_stream)
    # reset the stream before subscribing for the new plot (presumably this
    # drops the subscriber registered for the previous plot)
    sel.clear()
    forward_selection = partial(
        ui._selection_to_stream_callback,
        data=plot.data,
        keys=df.index.names,
        stream=label_stream,
    )
    sel.add_subscriber(forward_selection)
    return plot


p = hv.DynamicMap(callback, streams=[group_stream])
sel.source = p
p

## Movie output

In [None]:
%%output backend='matplotlib'
#%%opts Layout [normalize=False fig_inches=2 vspace=0 aspect_weight=1 sublabel_format='' tight=True title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(**label_stream.contents) fontsize=20]
#%%opts Scatter [aspect=6]
# Trench currently selected in the label browser
key = tuple(getattr(label_stream, attr) for attr in trench_key)
index = detected_bursts.groupby(trench_key).get_group(key).index
# Unique timepoints of that trench, via the public Index.unique API instead of
# the private _get_level_values/_get_level_number helpers
ts = index.unique(level="t")
# ts = list(range(3))

# One frame per timepoint: both channel crops stacked above the scatter plot,
# with a red vertical line marking the current timepoint
movie = (
    trench_movie(trench_bboxes, key, "MCHERRY", ts)
    + trench_movie(trench_bboxes, key, "YFP", ts)
    + scatter_movie(labelwise_df, label_stream.contents, ts)
    * hv.HoloMap(
        {t: hv.VLine(t).options(color="red", backend="matplotlib") for t in ts}
    )
).cols(1)
movie2 = movie.options(
    {
        "Layout": dict(
            normalize=False,
            framewise=True,
            fig_inches=7,
            vspace=0,
            aspect_weight=1,
            sublabel_format="",
            tight=False,
            fontsize=15,
            title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(
                **label_stream.contents
            ),
        ),
        "Scatter": dict(aspect=6, s=20),
    },
    backend="matplotlib",
)
m = holomap_to_video(movie2, out="/tmp/jqsmovie.mp4", size=100, dpi=100)

In [None]:
Video("/tmp/jqsmovie.mp4", embed=True)

In [None]:
%%output backend='matplotlib'
%%opts Layout [normalize=False fig_inches=2 vspace=0 aspect_weight=1 sublabel_format='' tight=True title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(**label_stream.contents) fontsize=20]
%%opts Scatter [aspect=6]
key = tuple(getattr(label_stream, attr) for attr in trench_key)
index = detected_bursts.groupby(trench_key).get_group(key).index
# ts = index._get_level_values(index._get_level_number('t'), unique=True)
ts = list(range(3))

movie = (
    trench_movie(trench_bboxes, key, "MCHERRY", ts)
    + trench_movie(trench_bboxes, key, "YFP", ts)
    + scatter_movie(labelwise_df, label_stream.contents, ts)
    * hv.HoloMap(
        {t: hv.VLine(t).options(color="red", backend="matplotlib") for t in ts}
    )
).cols(
    1
)  # .options('Layout', normalize=False)
m = holomap_to_video(movie, out="/tmp/jqsmovie.mp4")

## Other viz

In [None]:
%%output size=180
def cb(**kwargs):
    """Scatter plot of YFP vs t for the rows selected by ``kwargs``.

    ``kwargs`` is passed to ``workflow.select_dataframe`` together with
    ``t=slice(None)`` and ``label=slice(None)``; points are colored by label.
    """
    df = workflow.select_dataframe(
        labelwise_df, kwargs, t=slice(None), label=slice(None)
    )
    # df = workflow.select_dataframe(labelwise_df, kwargs, label=slice(None))
    value_dims = [
        "YFP/labelwise/p0.9",
        "filename",
        "position",
        "trench_set",
        "trench",
        "label",
    ]
    hover = HoverTool(
        tooltips=[
            ("t", "@t{0[.]0}"),
            # ('filename', '@filename'),
            ("trench", "@position.@trench_set.@trench"),
            ("label", "@label"),
            ("YFP", "@{YFP/labelwise/p0.9}{0[.]0}"),
        ]
    )
    scatter_options = dict(
        size=3,
        color_index="label",
        nonselection_alpha=0.3,
        cmap="Category20",
        tools=[hover, "tap"],
        show_legend=True,
    )
    scatter = hv.Scatter(df, kdims=["t"], vdims=value_dims)
    return scatter.options("Scatter", **scatter_options)


ui.viewer(cb, label_stream)