# Imports

In [None]:
import numpy as np
import pandas as pd
import zarr
import dask
from dask import delayed
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import streamz
import holoviews as hv
from holoviews.streams import Stream, param
from holoviews.operation.datashader import regrid
from bokeh.models.tools import HoverTool
import matplotlib.pyplot as plt
import qgrid
import ipywidgets as widgets
from tqdm import tnrange, tqdm, tqdm_notebook
import warnings
from functools import partial
from cytoolz import compose, get_in
from operator import getitem
import nd2reader
from importlib import reload
import traceback
import hvplot.pandas
import param
import parambokeh
from traitlets import All
import cachetools
from collections import namedtuple
import skimage.morphology
import scipy

# Shorthand for building MultiIndex label slices inside `.loc[...]` lookups.
IDX = pd.IndexSlice

In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports all modules
# (except those excluded with %aimport) before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
import common, trench_detection, util
import ui, diagnostics, metadata
import workflow, image, geometry
import trench_detection.hough, trench_detection.core

In [None]:
# Load the line_profiler IPython extension.
%load_ext line_profiler
# Render HoloViews objects with the Bokeh backend.
hv.extension("bokeh")
# Draw matplotlib figures inline in the notebook output.
%matplotlib inline
# Turn off tqdm's background monitor thread.
tqdm.monitor_interval = 0

# Restore data

In [None]:
# Restore `trench_df4` from IPython's %store persistence; it must have been
# saved with `%store trench_df4` in an earlier session.
%store -r trench_df4

In [None]:
# Alias the restored object under the shorter name `trench_df`.
trench_df = trench_df4

# Config

In [None]:
# Dask cluster whose workers are launched as SLURM batch jobs. With cores=1 and
# processes=1, each job hosts a single one-core worker process.
cluster = SLURMCluster(
    # Where and for how long each worker job runs.
    queue="short",
    walltime="2:00:00",
    # Resources requested per worker job.
    cores=1,
    processes=1,
    memory="16GB",
    # Worker-local scratch directory.
    local_directory="/tmp",
    # Shell line added to the job script before the worker starts; presumably
    # this makes the project checkout importable on the workers.
    env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"'],
    # Options left disabled:
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # interface='ib0',
    # diagnostics_port=('127.0.0.1', 8787),
)
# Client connected to that cluster; the cells below submit work through it.
client = Client(cluster)

In [None]:
# Cosmetic tweak: set the width of one nested element of the cluster widget to 200px.
# NOTE(review): this goes through the private `_widget()` method and a
# hard-coded child path, so it may break if the widget layout changes.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
# Display the cluster widget as the cell output.
cluster

In [None]:
# client = Client()

In [None]:
# Tear-down: ask the cluster to stop the workers for all of its current jobs.
# NOTE(review): this cell sits between cluster setup and data loading, so a
# top-to-bottom run would stop the workers before any work is submitted —
# presumably it is meant to be run manually.
cluster.stop_workers(cluster.jobs)

# Loading data

In [None]:
# ND2 files to process. The commented-out assignments above and below the
# active list are alternative file selections kept for quick switching.
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2']#, '/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2",
    "/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2",
]
# nd2_filenames = ['/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC001.nd2', '/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC002.nd2']

In [None]:
# Build the frame listing for the selected ND2 files; the call returns three
# values. `all_frames` is the table indexed per frame that later cells slice.
# NOTE(review): assigning to `metadata` rebinds the name of the `metadata`
# module imported at the top of the notebook, so that module is no longer
# reachable by name after this cell runs.
all_frames, metadata, parsed_metadata = workflow.get_nd2_frame_list(nd2_filenames)

# Reload

In [None]:
def do_reload():
    """Re-import the project modules so that source edits take effect.

    All imports are local to the function so that it is self-contained: it is
    run on the Dask workers via ``client.run`` as well as being called
    directly in the notebook process.

    ``importlib.reload`` on a package re-executes only the package itself and
    leaves its submodules untouched. The notebook calls
    ``trench_detection.hough.find_trenches``, so the ``trench_detection``
    submodules are reloaded explicitly, before the package, so that anything
    the package re-exports from them is picked up fresh.
    """
    from importlib import reload
    import util, trench_detection, diagnostics, workflow, image
    import trench_detection.core, trench_detection.hough

    reload(util)
    # Submodules first, then the package.
    # NOTE(review): `core` is reloaded before `hough` on the assumption that
    # `hough` may import from `core` — confirm the dependency direction.
    reload(trench_detection.core)
    reload(trench_detection.hough)
    reload(trench_detection)
    reload(diagnostics)
    reload(workflow)
    reload(image)


# Reload the project modules on every Dask worker, then in the notebook
# process itself.
client.run(do_reload)
do_reload()

# Finding trenches

In [None]:
# Subset of frames to run on: all values of the first two index levels, the
# "MCHERRY" entry of the third, and labels up to and including 10 on the
# fourth (`.loc` label slices are inclusive).
# NOTE(review): presumably the levels are (filename, position, channel, t) — confirm.
frames_to_process = all_frames.loc[IDX[:, :, ["MCHERRY"], :10], :]

In [None]:
# Number of frames selected for processing.
len(frames_to_process)

## Frame quality finding

In [None]:
# Per-frame metric: `compose` applies right-to-left, so a frame goes through
# image.psd2 first and image.radial_profile second.
radial_psd2 = compose(image.radial_profile, image.psd2)
# One future per frame index: the inner submit loads the frame on a worker and
# the outer submit reduces it with `radial_psd2`.
# NOTE(review): this cell loads frames with workflow._get_nd2_frame while the
# trench-finding cell uses workflow.get_nd2_frame — confirm which is intended.
frame_psd2s_futures = {
    frame_idx: client.submit(
        radial_psd2,
        client.submit(workflow._get_nd2_frame, **frame_idx._asdict()),
    )
    for frame_idx, _row in util.iter_index(frames_to_process)
}

In [None]:
# Wait for the futures and pull their results back, keeping the same
# index -> value dict structure.
frame_psd2s = client.gather(frame_psd2s_futures)

In [None]:
# Define a stream class from the frame index (index levels turned into plain
# columns) and instantiate it; the viewers below take this stream object.
FrameStream = ui.DataframeStream.define(
    "FrameStream", frames_to_process.index.to_frame(index=False)
)
frame_stream = FrameStream()

# Build the browser widget for the stream, fire an initial event with no
# parameter changes, and display the widget as the cell output.
box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
%%opts Layout [shared_axes=False]
# Side-by-side layout: the log of the gathered per-frame curve for the frame
# selected in `frame_stream`, next to the frame image itself.
# NOTE(review): `dict_viewer` is used unqualified, but `from ui import *` is
# commented out in the import cell, so this name is not defined by anything
# visible in this notebook — presumably it should be `ui.dict_viewer`; confirm.
dict_viewer(
    frame_psd2s, frame_stream, wrapper=lambda k, v: hv.Curve(np.log(v))
) + ui.image_viewer(frame_stream)

## Run trench finding

In [None]:
# locally: get trench_points dict?? (how to organize? use dict proxy to index into it?)
# where do I list all trenches, so that I can map over them?? e.g., compute per-timepoint focus
# turn trench_points into df
# locally: get diag df (by trench_set)
# then dask

In [None]:
# trench_points_futures = {idx: client.submit(get_trenches,
#                                            client.submit(workflow._get_nd2_frame, **idx._asdict())) for idx, row in util.iter_index(frames_to_process)}

In [None]:
# Trench finder wrapped by diagnostics.wrap_diagnostics with
# ignore_exceptions=True and pandas=True. Later cells treat each result as a
# 3-tuple whose last element is None when nothing went wrong.
find_trenches_diag = diagnostics.wrap_diagnostics(
    trench_detection.hough.find_trenches,
    ignore_exceptions=True,
    pandas=True,
)
# One future per frame index: the inner submit loads the frame on a worker and
# the outer submit runs the wrapped trench finder on it.
trench_info_futures = {
    frame_idx: client.submit(
        find_trenches_diag,
        client.submit(workflow.get_nd2_frame, **frame_idx._asdict()),
    )
    for frame_idx, _row in util.iter_index(frames_to_process)
}

In [None]:
# Abort the outstanding trench-finding futures.
# NOTE(review): in a top-to-bottom run this cancels the futures submitted in
# the previous cell before they are gathered — presumably a manual escape
# hatch rather than part of the normal flow.
client.cancel(trench_info_futures)

In [None]:
%%time
# Gather results, passing a predicate that accepts only futures whose status
# is "finished".
# NOTE(review): presumably util.apply_map_futures skips futures that fail the
# predicate (pending, errored, cancelled) — confirm against util.
trench_info = util.apply_map_futures(
    client.gather, trench_info_futures, predicate=lambda x: x.status == "finished"
)

In [None]:
# Number of frames for which a result was gathered.
len(trench_info)

In [None]:
# Keep the third element of every result whose third element is not None,
# keyed by frame index (presumably the captured error — confirm), then
# display the mapping.
errs = {
    frame_idx: result[2]
    for frame_idx, result in trench_info.items()
    if result[2] is not None
}
errs

In [None]:
%%time
# Split the per-frame results into three separate collections:
# trench points, diagnostics, and errors.
trench_points, trench_diag, trench_err = workflow.unzip_trench_info(trench_info)

In [None]:
# Number of entries in the combined trench-points collection.
len(trench_points)

In [None]:
%%time
# Persist the results with IPython's %store so a later session can restore
# them with `%store -r`.
%store trench_points
%store trench_diag

## Analysis

In [None]:
# Peek at the last rows of the diagnostics table.
trench_diag.tail()

In [None]:
# Flag rows whose second-pass Hough angle exceeds 2 in absolute value
# (units not visible here — presumably degrees; confirm), then count them.
bad_angle = trench_diag["find_trench_lines.hough_2.angle"].abs() > 2
bad_angle.sum()

In [None]:
# Flag rows whose second-pass Hough pitch differs from 24 by more than 1
# (24 is a hard-coded expected pitch — presumably in pixels; confirm),
# then count them.
bad_pitch = (trench_diag["find_trench_lines.hough_2.pitch"] - 24).abs() > 1
bad_pitch.sum()

In [None]:
# Diagnostics rows to inspect in the QA section: currently only the bad-pitch ones.
selected = trench_diag[bad_pitch]  # alternative: trench_diag[bad_angle | bad_pitch]

In [None]:
# Display the bad-pitch flag joined (on index) with the trench points.
bad_pitch.to_frame().join(trench_points)

In [None]:
# Count of flagged rows (same value as shown when `bad_pitch` was computed).
bad_pitch.sum()

In [None]:
%%time
# Keep only the trench points whose mask entry is False.
# NOTE(review): presumably util.multi_join aligns the `bad_pitch` flags to the
# trench_points index and returns a boolean mask — confirm against util.
trench_points_good = trench_points[~util.multi_join(trench_points.index, bad_pitch)]

In [None]:
# Peek at the first rows of the filtered trench points.
trench_points_good.head()

In [None]:
# (number of trench points kept, fraction of all trench points kept)
(len(trench_points_good), len(trench_points_good) / len(trench_points))

In [None]:
# Fire a stream event that sets the stream's `_df` parameter to the index of
# the selected rows (index levels turned into plain columns).
frame_stream.event(_df=selected.index.to_frame(index=False))

# Trench finding QA

In [None]:
# Same browser setup as in the frame-quality section, but built from the
# index of the `selected` (flagged) diagnostics rows.
# NOTE: this rebinds `FrameStream` and `frame_stream`; viewers created earlier
# were given the previous stream object.
FrameStream = ui.DataframeStream.define(
    "FrameStream", selected.index.to_frame(index=False)
)
frame_stream = FrameStream()

# Build the browser widget, fire an initial event, and display the widget.
box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
# Image viewer driven by the current `frame_stream`.
ui.image_viewer(frame_stream)

In [None]:
# Frame-info display built from `trench_diag` and the current `frame_stream`.
ui.show_frame_info(trench_diag, frame_stream)

In [None]:
# g = ui.show_grid(df, stream=frame_stream)
# g

In [None]:
%%time
# Load the frame described by the stream's current parameter values and re-run
# trench detection on it in the notebook process.
# NOTE(review): get_param_values() usually also returns the object's own `name`
# parameter, and an earlier cell sets a `_df` parameter on this stream —
# confirm that get_nd2_frame accepts or ignores such extra keyword arguments.
frame = workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))
# frame = workflow.get_nd2_frame(filename='/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2', position=0, channel='MCHERRY', t=0)
# Only the middle element of the result (the diagnostics) is kept. Unlike the
# distributed run, the wrapper is built here without ignore_exceptions/pandas.
_, diag, _ = diagnostics.wrap_diagnostics(trench_detection.hough.find_trenches)(frame)

In [None]:
# Browse the diagnostics captured by the local trench-finding run.
ui.show_plot_browser(diag)

# Segmentation

In [None]:
# Level names of the trench_points index, for reference when writing the
# slice in the next cell.
trench_points.index.names

In [None]:
# Six-level slice of trench_points: one file, labels up to and including 3 on
# the second level, the "MCHERRY" entry, label 0 on the fourth level, and
# everything on the last two levels.
# NOTE(review): presumably the levels are (filename, position, channel, t,
# trench_set, trench) — confirm with the index names shown above.
trenches = trench_points.loc[
    IDX["/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2", :3, ["MCHERRY"], 0, :, :],
    :,
]

In [None]:
# Peek at the first selected trenches.
trenches.head()

In [None]:
# Peek at the last selected trenches.
trenches.tail()

In [None]:
# Pair of arrays (bottom, top), each built by stacking the per-trench entries
# of the corresponding column of `trenches`.
trench_set_points = tuple(
    np.array(list(trenches.loc[:, edge].values)) for edge in ("bottom", "top")
)

In [None]:
# Load one frame, using the index values of the first row of `trenches`
# (level name -> value) as keyword arguments.
# NOTE(review): this passes every index level of `trenches` as a keyword —
# confirm get_nd2_frame accepts the trench-level keys as well.
frame = workflow.get_nd2_frame(**trenches.index.to_frame().iloc[0].to_dict())

In [None]:
# Axis limits derived from the frame's shape.
x_lim, y_lim = geometry.get_image_limits(frame.shape)

In [None]:
# Bounding-box lookup with second argument 0 (presumably the trench index —
# confirm), given the image limits computed above.
geometry.get_trench_bbox(trench_set_points, 0, x_lim, y_lim)

In [None]:
# Print every (group key, sub-frame) pair when grouping the selected trenches
# by file, position and trench set.
for x in trenches.groupby(["filename", "position", "trench_set"]):
    print(x)

# Old

In [None]:
# Legacy cell from the section marked "Old".
# NOTE(review): `get_trenches`, `root_group`, `pos` and `diag_pos` are not
# defined by any cell visible in this notebook (the wildcard imports are
# commented out), so this cell presumably no longer runs as-is.
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
# Legacy cell: run the old trench finder on a single frame with a fresh
# diagnostics container, discarding the return value.
# NOTE(review): `tree`, `get_trenches`, `root_group` and `pos` are not defined
# by any cell visible in this notebook.
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
def f(img_stack):
    """Segment every entry along the first axis of a stack.

    `segment_trench` is applied with diagnostics disabled to each
    `img_stack[t]`, the results are stacked along a new leading axis, and the
    stack is returned as a zarr array compressed with
    DEFAULT_FRAME_COMPRESSOR.

    NOTE(review): `segment_trench` and `DEFAULT_FRAME_COMPRESSOR` are not
    defined by any cell visible in this notebook.
    """
    n_frames = img_stack.shape[0]
    masks = [segment_trench(img_stack[t], diagnostics=None) for t in range(n_frames)]
    stacked = np.stack(masks, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


# Legacy call: map the segmentation helper `f` defined just above over the
# trenches of position range(1), with channel_slice=1 and the full time range.
# NOTE(review): `positionwise_trenchwise_map`, `root_group` and
# `trench_points_pos` are not defined by any cell visible in this notebook.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    positions=range(1),
    time_slice=slice(None),
    channel_slice=1,
    preload=True,
)

In [None]:
def f(img_stack):
    """Return the 95th percentile over axes 1 and 2 for each entry along axis 0.

    The result is a pandas Series with one value per entry of the first axis.
    """
    per_entry_p95 = np.percentile(img_stack, 95, axis=(1, 2))
    # Alternative kept from the original: np.max(img_stack, axis=(1, 2))
    return pd.Series(per_entry_p95)


# Legacy call: map the percentile helper `f` defined just above over the
# trenches of positions range(100), with channel_slice=2 and the full time range.
# NOTE(review): `positionwise_trenchwise_map`, `root_group` and
# `trench_points_pos` are not defined by any cell visible in this notebook.
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    positions=range(100),
    time_slice=slice(None),
    channel_slice=2,
    preload=True,
)