# Imports

In [None]:
import traceback
import warnings
from functools import partial
from importlib import reload
from operator import getitem

import cachetools
import dask
import holoplot.pandas
import holoviews as hv
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import param
import parambokeh
import qgrid
import skimage.morphology
import streamz
import zarr
from bokeh.models.tools import HoverTool
from cytoolz import compose, get_in
from dask import delayed
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from holoviews.operation.datashader import regrid
from holoviews.streams import Stream, param
from tqdm import tnrange, tqdm, tqdm_notebook
from traitlets import All

In [None]:
# Automatically reload edited modules before each cell runs (autoreload mode 2),
# so edits to the project modules imported below are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
import common
import diagnostics
import image
import metadata
import trench_detection
import ui
import util
import workflow

In [None]:
# Enables the %lprun line-by-line profiling magic.
%load_ext line_profiler
# Render HoloViews output with the Bokeh plotting backend.
hv.extension("bokeh")
%matplotlib inline
# Disable tqdm's background monitor thread.
tqdm.monitor_interval = 0

In [None]:
# Restore dataframes persisted from an earlier session via `%store`.
# trench_df3 (joined below as the "_sw" columns) is never computed in this
# notebook, so it must come from here; trench_df2 is recomputed in
# "Finding trenches", but restoring it lets that section be skipped.
%store -r trench_df2 trench_df3

# Config

In [None]:
# Dask cluster whose workers are submitted as SLURM jobs on the "short" queue:
# one single-threaded worker process per job, 64GB memory, 5h walltime.
cluster = SLURMCluster(
    queue="short",
    walltime="5:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # interface='ib0',
    memory="64GB",
    local_directory="/tmp",
    threads=1,
    processes=1,
    # diagnostics_port=('127.0.0.1', 8787),
    # Put this directory on the workers' PYTHONPATH -- presumably the checkout
    # containing the project modules (common, util, ...) so workers can import
    # them; confirm.
    env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"'],
)
client = Client(cluster)

In [None]:
# Widen one element of the cluster's notebook widget. This walks a hard-coded
# child path inside the private `_widget()` layout and may break if
# dask_jobqueue changes that layout.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
cluster

In [None]:
# Alternative to the SLURM client: a local dask cluster. Running this cell
# rebinds `client`, so later `client.submit`/`client.run` calls go to the local
# cluster instead of the SLURM one.
client = Client()

In [None]:
# Stop the SLURM worker jobs of `cluster` (cluster.jobs presumably holds the
# currently running jobs -- confirm for the dask_jobqueue version in use).
cluster.stop_workers(cluster.jobs)

# Functions

# Debugging

# Loading data

In [None]:
# Input ND2 files. The commented-out assignments are alternative file sets;
# the active list processes all four files below.
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2']#, '/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2",
    "/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2",
]
# nd2_filenames = ['/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC001.nd2', '/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC002.nd2']
# nd2_filenames = ['/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC001.nd2', '/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC002.nd2']

In [None]:
# Open each ND2 file in this kernel, keyed by filename (memmap=False). The
# commented-out variant opens the readers on dask workers instead. The readers
# are never closed explicitly; they stay open for the session.
# nd2s = {filename: client.submit(nd2reader.ND2Reader, filename, memmap=False) for filename in nd2_filenames}
nd2s = {
    filename: nd2reader.ND2Reader(filename, memmap=False) for filename in nd2_filenames
}

In [None]:
# Per-file metadata, keyed by filename: axis sizes (nd2.sizes), nd2reader's
# parsed metadata (nd2.metadata), and the project's metadata.parse_nd2_metadata.
# util.map_futures is applied here to plain readers rather than futures;
# presumably it just maps the function over the dict values -- confirm.
# The commented-out lines are the dask-future variants.
# nd2_sizes = util.apply_map_futures(client.gather, util.map_futures(partial(client.submit, lambda nd2: nd2.sizes), nd2s))
# nd2_parsed_metadata = util.apply_map_futures(client.gather, util.map_futures(partial(client.submit, lambda nd2: nd2.metadata), nd2s))
# nd2_metadata = util.apply_map_futures(client.gather, util.map_futures(partial(client.submit, metadata.parse_nd2_metadata), nd2s))
nd2_sizes = util.map_futures(lambda nd2: nd2.sizes, nd2s)
nd2_parsed_metadata = util.map_futures(lambda nd2: nd2.metadata, nd2s)
nd2_metadata = util.map_futures(metadata.parse_nd2_metadata, nd2s)

In [None]:
# Channel-name list per file, taken from nd2reader's parsed metadata.
nd2_channels = {
    filename: md["channels"] for filename, md in nd2_parsed_metadata.items()
}

In [None]:
# Presumably a table mapping channel names to per-file channel indices -- confirm.
channels_to_idx = workflow.get_channels_to_indices(nd2_channels)

In [None]:
# Per-position metadata derived from the project-parsed metadata; joined onto
# the trench dataframe in "Finding trenches".
nd2_positions = workflow.get_position_metadata(nd2_metadata)

# Reload

In [None]:
# Send the local diagnostics.py to every dask worker (Client.upload_file) so the
# workers import the current version of that module.
client.upload_file("diagnostics.py")

In [None]:
def do_reload():
    """Re-import the project modules util, trench_detection and diagnostics.

    All imports live inside the function body so it is self-contained and can
    be shipped to dask workers with ``client.run`` as well as called locally.
    Modules are reloaded in the order util, trench_detection, diagnostics.
    """
    from importlib import reload

    import diagnostics
    import trench_detection
    import util

    for module in (util, trench_detection, diagnostics):
        reload(module)


# Reload the modules on every dask worker, then in this kernel.
client.run(do_reload)
do_reload()

In [None]:
# Reload util locally once more (do_reload above already reloads it).
reload(util)

# Finding trenches

In [None]:
# Limits on how much of each file is processed.
MAX_POSITIONS = 100
MAX_TIMEPOINTS = 5000
TRENCH_CHANNELS = ("MCHERRY",)


def _read_nd2_frame(nd2, t, v, c):
    """Read one 2D frame (timepoint t, position v, channel index c) from an ND2 reader.

    The frame coordinates are explicit arguments rather than variables captured
    from the enclosing comprehension. Captured comprehension variables are
    rebound on every iteration, so with a closure each task would only get the
    right coordinates if dask serialized the function before the next
    iteration. Passing them as task arguments makes every task independent of
    that timing.
    """
    return nd2.get_frame_2D(t=t, v=v, c=c, memmap=False)


# Nested dict: filename -> position -> channel -> timepoint -> future of
# trench_detection.get_trenches_diagnostics(frame).
trench_data = {
    filename: {
        v: {
            channel: {
                t: client.submit(
                    trench_detection.get_trenches_diagnostics,
                    client.submit(
                        _read_nd2_frame, nd2, t, v, channels.index(channel)
                    ),
                )
                for t in range(min(sizes["t"], MAX_TIMEPOINTS))
            }
            for channel in TRENCH_CHANNELS
        }
        # Never request a position beyond the file's own position count;
        # nd2reader omits the "v" axis for single-position files.
        for v in range(min(sizes.get("v", 1), MAX_POSITIONS))
    }
    # `_md` (project metadata) is unused here; named so it does not shadow the
    # `metadata` module.
    for filename, nd2, sizes, _md, channels in util.zip_dicts(
        nd2s, nd2_sizes, nd2_metadata, nd2_channels
    )
}

In [None]:
# Turn each per-frame diagnostics future into dataframe rows on the cluster.
# util.apply_map_futures presumably applies the given function to the futures
# at the leaves of the nested trench_data dict -- confirm.
trench_rows = util.apply_map_futures(
    partial(client.map, diagnostics.wrapped_diagnostics_to_dataframe), trench_data
)

In [None]:
# Gather results into this kernel. The predicate presumably restricts gathering
# to futures whose status is "finished" (pending or errored ones are skipped)
# -- confirm.
trench_rows_combined = util.apply_map_futures(
    client.gather, trench_rows, predicate=lambda x: x.status == "finished"
)

In [None]:
# Concatenate the nested row frames into one dataframe (util.map_collections
# presumably turns each dict level into an index level, up to max_level=4 --
# confirm), then drop the second-to-last index level and name the rest.
trench_df = util.map_collections(
    partial(pd.concat, axis=0, sort=True), trench_rows_combined, max_level=4
)
trench_df.index = trench_df.index.droplevel(-2)
trench_df.index.names = ["filename", "position", "channel", "t", "trench_set"]

In [None]:
# Attach per-position metadata.
trench_df2 = util.multi_join(trench_df, nd2_positions)

In [None]:
# Attach channel-name -> index information.
trench_df2 = util.multi_join(trench_df2, channels_to_idx)

# Compare setwise Hough vs. framewise Hough

In [None]:
trench_both = trench_df2.join(trench_df3, how="inner", lsuffix="", rsuffix="_sw")

In [None]:
num_ts = (
    trench_both.reset_index()
    .groupby(["filename", "position", "channel"])
    .agg({"t": "count"})
    .rename(columns={"t": "num_ts"})
)

In [None]:
bad_angle = trench_both["trench_rotation.hough_1.angle_sw"].abs() > 2

In [None]:
bad_period = ~(
    (trench_both["trench_anchors.periodogram_1.period_sw"] < 25)
    & (trench_both["trench_anchors.periodogram_1.period_sw"] > 23)
)

In [None]:
selected = trench_both[bad_angle]

In [None]:
num_selected_ts = (
    selected.reset_index()
    .groupby(["filename", "position", "channel"])
    .agg({"t": "count"})
    .rename(columns={"t": "num_selected_ts"})
)

In [None]:
trench_both2 = util.multi_join(util.multi_join(trench_both, num_ts), num_selected_ts)

# Prototyping

In [None]:
# Frames to inspect: rows with a bad setwise angle, restricted to
# (filename, position, channel) groups that have more than one such timepoint.
df = trench_both2[(trench_both2["num_selected_ts"] > 1) & bad_angle]

In [None]:
# Stream class built from the index columns of `df` (presumably one stream
# parameter per index column -- confirm) plus a browser widget for stepping
# through those frames; `frame_stream.event()` triggers the stream once.
FrameStream = ui.DataframeStream.define("FrameStream", df.index.to_frame(index=False))
frame_stream = FrameStream()

box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
# ui.show_frame_info(trench_both2, frame_stream)

In [None]:
# g = ui.show_grid(df, stream=frame_stream)
# g

In [None]:
# Image viewer driven by frame_stream (presumably shows the selected frame).
ui.image_viewer(frame_stream)

In [None]:
# Re-apply the filter and push the new set of frames into the existing stream
# through its `_df` parameter, without recreating the browser.
df = trench_both2[(trench_both2["num_selected_ts"] > 1) & bad_angle]
frame_stream.event(_df=df.index.to_frame(index=False))

In [None]:
# Pick up edits to the image and trench_detection modules.
reload(image)
reload(trench_detection)

In [None]:
%%time
# Load the frame currently selected in frame_stream and run trench detection on
# it with diagnostics captured; the middle element of the wrapped call's result
# is presumably the diagnostics tree.
frame = workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))
_, diag, _ = diagnostics.wrap_diagnostics(trench_detection.get_trenches)(frame)

In [None]:
# Browse the diagnostic plots recorded under "labeling".
ui.show_plot_browser(diag, "labeling")

In [None]:
# Browse the diagnostic plots recorded under "label_1".
ui.show_plot_browser(diag, "label_1")

In [None]:
a = get_in("label_1.trench_rotation.hough_1.diff_h_std".split("."), diag)

In [None]:
y = a.Curve.I.data.y

In [None]:
import scipy

In [None]:
hv.Curve(scipy.ndimage.filters.gaussian_filter1d(y, 2))

In [None]:
np.convolve(np., y,mode='valid')

# Image processing prototyping

In [None]:
# Pick up edits to trench_detection before re-running it below.
reload(trench_detection)

In [None]:
%%time
# Re-run trench detection on the currently selected frame; results go into
# diag2 so the earlier `diag` is kept for comparison.
_, diag2, _ = diagnostics.wrap_diagnostics(trench_detection.get_trenches)(
    workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))
)

In [None]:
ui.show_plot_browser(diag2, "labeling")

In [None]:
a = get_in("labeling.components".split("."), diag)

In [None]:
b = skimage.morphology.label(a.data)

In [None]:
ui.RevImage(b)

In [None]:
plt.imshow(b == 1)

In [None]:
Counter(b.flat)

In [None]:
np.bincount(a.data.flat)

In [None]:
np.bincount(b.flat)

# Old (kept for reference; uses names that are not defined in this notebook)

In [None]:
# NOTE(review): get_trenches, root_group, pos and diag_pos are not defined
# anywhere in this notebook (elsewhere detection is called as
# trench_detection.get_trenches); presumably they came from the commented-out
# star imports at the top. This cell raises NameError as-is.
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
# Same call collecting diagnostics into a fresh tree() (tree is also undefined
# here). Running this cell rebinds `diag`.
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
def f(img_stack):
    """Segment each frame of ``img_stack`` and return the masks as a compressed zarr array.

    Frames are indexed along axis 0. Every frame goes through ``segment_trench``
    with diagnostics disabled, and the per-frame results are stacked along a new
    axis 0. The stack is stored with ``DEFAULT_FRAME_COMPRESSOR``.

    NOTE(review): segment_trench and DEFAULT_FRAME_COMPRESSOR are not defined in
    this notebook (presumably they came from the commented-out star imports), so
    calling f as-is raises NameError.
    """
    frame_count = img_stack.shape[0]
    masks = [
        segment_trench(img_stack[idx], diagnostics=None) for idx in range(frame_count)
    ]
    stacked = np.stack(masks, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


# Apply f (segmentation) for position 0, channel index 1, all timepoints, with
# preloading; positionwise_trenchwise_map presumably calls f once per trench
# image stack -- confirm.
# NOTE(review): positionwise_trenchwise_map, root_group and trench_points_pos
# are not defined anywhere in this notebook; this cell raises NameError as-is.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

In [None]:
def f(img_stack):
    """Return the 95th-percentile intensity of each frame as a pandas Series.

    The percentile is reduced over axes 1 and 2. For a 3-D stack of frames
    along axis 0, this yields one value per frame, and the Series index is the
    frame position 0..n-1. (An earlier variant returned the per-frame maximum
    instead.)
    """
    p95_per_frame = np.percentile(img_stack, 95, axis=(1, 2))
    return pd.Series(p95_per_frame)


# Per-frame 95th-percentile intensity traces (the f defined just above) for
# positions 0-99, channel index 2, all timepoints, with preloading;
# positionwise_trenchwise_map presumably calls f once per trench -- confirm.
# NOTE(review): positionwise_trenchwise_map, root_group and trench_points_pos
# are not defined anywhere in this notebook; this cell raises NameError as-is.
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(100),
)