# Imports

In [None]:
import numpy as np
import pandas as pd
import zarr
import dask
from dask import delayed
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import streamz
import holoviews as hv
from holoviews.streams import Stream, param
from holoviews.operation.datashader import regrid
from bokeh.models.tools import HoverTool
import matplotlib.pyplot as plt
import qgrid
import ipywidgets as widgets
from tqdm import tnrange, tqdm, tqdm_notebook
import warnings
from functools import partial
from cytoolz import compose, get_in
from operator import getitem
import nd2reader
from importlib import reload
import traceback
import holoplot.pandas
import param
import parambokeh
from traitlets import All
import cachetools
from collections import namedtuple
import skimage.morphology
import scipy

IDX = pd.IndexSlice

In [None]:
# Enable IPython's autoreload extension; mode 2 asks it to re-import modules
# before running each cell, so edits to the project modules are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
import common, trench_detection, util, ui, diagnostics, metadata, workflow, image

In [None]:
# Notebook-wide setup: profiler magic, HoloViews bokeh backend, inline matplotlib.
%load_ext line_profiler
hv.extension("bokeh")
%matplotlib inline
# NOTE(review): presumably turns off tqdm's background monitor thread — confirm.
tqdm.monitor_interval = 0

# Restore data

In [None]:
# Reload `trench_df4` from IPython's %store persistence (it must have been
# saved with `%store trench_df4` in an earlier session).
%store -r trench_df4

In [None]:
# Work under the shorter name used by the rest of the notebook.
trench_df = trench_df4

# Config

In [None]:
# Create a dask-jobqueue SLURM cluster description and connect a distributed
# client to it. Each job requests one single-threaded worker process.
cluster = SLURMCluster(
    queue="short",
    walltime="5:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # interface='ib0',
    memory="64GB",
    local_directory="/tmp",
    threads=1,
    processes=1,
    # diagnostics_port=('127.0.0.1', 8787),
    # Presumably makes the project checkout importable on the workers — confirm.
    env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"'],
)
client = Client(cluster)

In [None]:
# Set the width of a deeply nested child of the cluster widget, then display
# the cluster.
# NOTE(review): this reaches into the private `_widget()` tree by position and
# may break if the widget layout changes between dask-jobqueue versions.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
cluster

In [None]:
# Alternative kept for reference: a default local client instead of SLURM.
# client = Client()

In [None]:
# Manual teardown, kept commented out so it is not run by accident.
# cluster.stop_workers(cluster.jobs)

# Loading data

In [None]:
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2']#, '/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# Active set of ND2 files to process (alternatives are kept commented out
# around this list).
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2",
    "/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2",
]
# nd2_filenames = ['/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC001.nd2', '/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC002.nd2']

In [None]:
# Build the frame index and metadata for the selected files.
# NOTE(review): assigning to `metadata` here rebinds the name of the project
# module `metadata` imported at the top of the notebook; after this cell the
# module is no longer reachable under that name.
all_frames, metadata, parsed_metadata = workflow.get_nd2_frame_list(nd2_filenames)

# Reload

In [None]:
def do_reload():
    """Reload the project modules ``util``, ``trench_detection`` and ``diagnostics``.

    All imports are performed inside the function body so that the function
    is self-contained in whichever process it is called.
    """
    import importlib

    import util
    import trench_detection
    import diagnostics

    # Reload in the same order as the imports above.
    for module in (util, trench_detection, diagnostics):
        importlib.reload(module)


# Reload the project modules on the cluster workers (via `client.run`) and
# then in this notebook process, so both sides use the same code.
client.run(do_reload)
do_reload()

# Finding trenches

In [None]:
# Select a subset of frames by index level: everything on level 0, labels up
# to 50 on level 1, "MCHERRY" on level 2 and 0 on level 3.
# NOTE(review): the levels are presumably (filename, position, channel, t),
# judging by the attributes read from `all_frames.reset_index()` in the
# trench-finding cell — confirm.
frames_to_process = all_frames.loc[IDX[:, :50, ["MCHERRY"], 0], :]

In [None]:
# Number of frames selected.
len(frames_to_process)

## Frame quality finding

In [None]:
# `compose` applies right-to-left: psd2 first, then image.radial_profile.
# NOTE(review): `psd2` is defined in a later cell of this notebook (under
# "Low-frequency components"); that cell must be run before this one.
radial_psd2 = compose(image.radial_profile, psd2)
# One future per entry yielded by util.iter_index: load the frame on a worker,
# then compute its radial power spectrum there. `row` is not used.
# NOTE(review): `idx._asdict()` implies idx is namedtuple-like — confirm.
frame_psd2s_futures = {
    idx: client.submit(
        radial_psd2, client.submit(workflow._get_nd2_frame, **idx._asdict())
    )
    for idx, row in util.iter_index(frames_to_process)
}

In [None]:
# Block until the futures finish and pull the results back to the notebook.
frame_psd2s = client.gather(frame_psd2s_futures)

In [None]:
# Stream + browser widget for picking one of the selected frames.
FrameStream = ui.DataframeStream.define(
    "FrameStream", frames_to_process.index.to_frame(index=False)
)
frame_stream = FrameStream()

box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
%%opts Layout [shared_axes=False]
# Show the log radial spectrum of the chosen frame next to the frame itself.
# NOTE(review): `dict_viewer` is used unqualified, but the wildcard imports at
# the top of the notebook are commented out; presumably `ui.dict_viewer` —
# confirm.
dict_viewer(
    frame_psd2s, frame_stream, wrapper=lambda k, v: hv.Curve(np.log(v))
) + ui.image_viewer(frame_stream)

## Run trench finding

In [None]:
# Submit trench detection for every row of `all_frames`: an inner future loads
# the frame, an outer future runs trench_detection.get_trenches_diag on it.
# Keys are the original index entries (`row.Index` is the positional index
# after reset_index()).
# NOTE(review): this iterates `all_frames`, not `frames_to_process` — confirm
# that processing every frame is intended.
trench_data = {
    all_frames.index[row.Index]: client.submit(
        trench_detection.get_trenches_diag,
        client.submit(
            workflow._get_nd2_frame,
            row.filename,
            row.position,
            row.channel,
            row.t,
            memmap=False,
        ),
        find_angle_setwise=True,
    )
    for row in all_frames.reset_index().itertuples()
}

In [None]:
# Manual abort: cancels the futures submitted above.
# NOTE(review): running the notebook top-to-bottom executes this right after
# submission; it looks like a cell meant to be run by hand only.
client.cancel(trench_data)

In [None]:
# Convert each diagnostics result into dataframe rows on the workers.
# NOTE(review): the exact traversal done by util.apply_map_futures is not
# visible here.
trench_rows = util.apply_map_futures(
    partial(client.map, diagnostics.wrapped_diagnostics_to_dataframe), trench_data
)

In [None]:
# Gather only the futures whose status is "finished".
trench_rows_combined = util.apply_map_futures(
    client.gather, trench_rows, predicate=lambda x: x.status == "finished"
)

## Analysis

In [None]:
# Concatenate the gathered per-frame rows into one dataframe, then drop the
# second-to-last index level and name the remaining five levels.
trench_df = util.map_collections(
    partial(pd.concat, axis=0, sort=True), trench_rows_combined, max_level=4
)
trench_df.index = trench_df.index.droplevel(-2)
trench_df.index.names = ["filename", "position", "channel", "t", "trench_set"]

In [None]:
# Flag rows whose detected rotation angle exceeds 2 in absolute value, and
# count them.
bad_angle = trench_df["trench_rotation.hough_2.angle"].abs() > 2
bad_angle.sum()

In [None]:
# Flag rows whose detected period differs from 24 by more than 2, and count
# them.
bad_period = (trench_df["trench_anchors.periodogram_2.period"] - 24).abs() > 2
bad_period.sum()

In [None]:
# Rows failing either check.
(bad_angle | bad_period).sum()

In [None]:
# Total number of rows, for comparison with the counts above.
len(trench_df)

In [None]:
# Number of rows per (filename, position, channel), counted via the "t" column.
num_ts = (
    trench_df.reset_index()
    .groupby(["filename", "position", "channel"])
    .agg({"t": "count"})
    .rename(columns={"t": "num_ts"})
)

In [None]:
# Alternative selection; overwritten by the next cell when run in order.
selected = frames_to_process

In [None]:
# Keep only the rows flagged by the period check.
selected = trench_df[bad_period]

In [None]:
# Same per-group count as num_ts, restricted to the selected rows.
num_selected_ts = (
    selected.reset_index()
    .groupby(["filename", "position", "channel"])
    .agg({"t": "count"})
    .rename(columns={"t": "num_selected_ts"})
)

In [None]:
# Attach both count columns to trench_df.
# NOTE(review): the join semantics of util.multi_join are not visible here.
trench_df = util.multi_join(util.multi_join(trench_df, num_ts), num_selected_ts)

In [None]:
# frame_stream.event(_df=selected.index.to_frame(index=False))

# Prototyping

In [None]:
# Rebuild the frame stream and browser widget over the `selected` rows.
FrameStream = ui.DataframeStream.define(
    "FrameStream", selected.index.to_frame(index=False)
)
frame_stream = FrameStream()

box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
# Show the trench_df information for the frame currently chosen in the stream.
ui.show_frame_info(trench_df, frame_stream)

In [None]:
# g = ui.show_grid(df, stream=frame_stream)
# g

In [None]:
# Display the currently chosen frame.
ui.image_viewer(frame_stream)

In [None]:
# Pick up local edits to these two project modules.
reload(image)
reload(trench_detection)

In [None]:
%%time
# Load the chosen frame locally and re-run trench detection with diagnostics
# capture; only the middle element of the returned 3-tuple is kept, as `diag`.
frame = workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))
_, diag, _ = diagnostics.wrap_diagnostics(trench_detection.get_trenches)(frame)

In [None]:
# Browse the diagnostics recorded under "label_1.trench_anchors".
ui.show_plot_browser(diag, "label_1.trench_anchors")

In [None]:
# Pull the second data column of the anchor-profile curve out of the nested
# diagnostics.
y = get_in(
    "label_1.trench_anchors.trench_anchor_profile".split("."), diag
).Curve.I.data[:, 1]

In [None]:
plt.plot(y)

In [None]:
# Periodogram of the profile: sample frequencies `f` and power `Pxx`.
# NOTE(review): only `import scipy` appears at the top of the notebook;
# `scipy.signal` resolves only if the submodule has been imported somewhere
# (or scipy loads it lazily) — consider an explicit `import scipy.signal`.
f, Pxx = scipy.signal.periodogram(y)

In [None]:
plt.plot(f, Pxx)

In [None]:
# Index of the strongest spectral component.
idx = Pxx.argmax()

In [None]:
# Dominant frequency.
f[idx]

In [None]:
# Corresponding period.
# NOTE(review): if the peak is the zero-frequency bin this divides by zero.
1 / f[idx]

In [None]:
# List what was recorded for the hough_2 rotation step.
diag["label_1"]["trench_rotation"]["hough_2"].keys()

In [None]:
z = get_in("label_1.trench_rotation.hough_2.log_hough".split("."), diag).data

In [None]:
# Middle column of z.
z2 = z[:, z.shape[1] // 2]

In [None]:
# Plot it with leading and trailing zeros stripped.
plt.plot(np.trim_zeros(z2))

In [None]:
# Same periodogram analysis, applied to the trimmed column.
f2, Pxx2 = scipy.signal.periodogram(np.trim_zeros(z2))

In [None]:
hv.Curve(zip(f2, Pxx2))

In [None]:
idx2 = Pxx2.argmax()

In [None]:
# Dominant period; same zero-frequency caveat as for `1 / f[idx]`.
1 / f2[idx2]

# Low-frequency components

In [None]:
# Load the frame currently chosen in the browser widget.
frame1 = workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))

In [None]:
plt.imshow(frame1)

In [None]:
# Pick up local edits to the project `image` module.
reload(image)

In [None]:
def psd2(img):
    """Return the centred 2-D power spectrum of ``img``.

    Takes the 2-D FFT of ``img``, shifts the zero-frequency term to the
    centre of the array, and returns the squared magnitude of every
    coefficient.
    """
    spectrum = np.fft.fft2(img)
    centred = np.fft.fftshift(spectrum)
    return np.abs(centred) ** 2

In [None]:
# Power spectrum of frame1, and the same spectrum divided by its mean.
a = psd2(frame1)
b = a / a.mean()

In [None]:
# Radial profile of the log spectrum (log taken before averaging) ...
plt.plot(image.radial_profile(np.log(psd2(frame1))))

In [None]:
# ... versus log of the radial profile (log taken after averaging).
# NOTE(review): `frame2` is not defined anywhere in the visible notebook; this
# cell fails unless it was bound in an earlier session.
plt.plot(np.log(image.radial_profile(psd2(frame1))))
plt.plot(np.log(image.radial_profile(psd2(frame2))))

In [None]:
plt.imshow(np.log(b))

In [None]:
plt.imshow(psd2(frame1))

In [None]:
# Radial profile of the raw frame, for comparison.
plt.plot(image.radial_profile(frame1))

# Image processing prototyping

In [None]:
# Pick up local edits to trench_detection before re-running it.
reload(trench_detection)

In [None]:
%%time
# Re-run trench detection with diagnostics on the currently chosen frame;
# the diagnostics are kept as `diag2`.
_, diag2, _ = diagnostics.wrap_diagnostics(trench_detection.get_trenches)(
    workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))
)

In [None]:
ui.show_plot_browser(diag2, "labeling")

In [None]:
# NOTE(review): this reads from `diag` (the earlier run), not the `diag2`
# computed just above — confirm which one is intended.
a = get_in("labeling.components".split("."), diag)

In [None]:
# Label the data of that diagnostics entry with skimage.
b = skimage.morphology.label(a.data)

# Old

In [None]:
# NOTE(review): legacy cells. `get_trenches`, `root_group`, `pos`, `diag_pos`
# and `tree` are not defined in the visible notebook (the wildcard imports at
# the top are commented out), so these cells will not run as-is.
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
# Run trench detection on a single frame, collecting diagnostics into `diag`.
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
def f(img_stack):
    """Segment every frame of ``img_stack`` and return the result as a zarr array.

    ``segment_trench`` is called once per index along axis 0 (with
    ``diagnostics=None``). The per-frame results are stacked along a new
    leading axis and wrapped with ``zarr.array`` using
    ``DEFAULT_FRAME_COMPRESSOR``.
    """
    num_frames = img_stack.shape[0]
    masks = []
    for t in range(num_frames):
        masks.append(segment_trench(img_stack[t], diagnostics=None))
    stacked = np.stack(masks, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


# Apply the segmentation helper `f` through positionwise_trenchwise_map, with
# channel_slice=1, all time points, and positions restricted to range(1).
# NOTE(review): `positionwise_trenchwise_map`, `root_group` and
# `trench_points_pos` are not defined in the visible notebook.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

In [None]:
def f(img_stack):
    """Return the 95th percentile of each frame of ``img_stack`` as a ``pd.Series``.

    The percentile is taken jointly over axes 1 and 2, leaving one value per
    index along axis 0.
    """
    # Earlier alternative, kept for reference: np.max(img_stack, axis=(1, 2))
    per_frame = np.percentile(img_stack, 95, axis=(1, 2))
    return pd.Series(per_frame)


# Apply the percentile helper `f` through positionwise_trenchwise_map, with
# channel_slice=2, all time points, and positions restricted to range(100).
# NOTE(review): `positionwise_trenchwise_map`, `root_group` and
# `trench_points_pos` are not defined in the visible notebook.
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(100),
)