# Imports

In [None]:
import numpy as np
import pandas as pd
import zarr
import dask
from dask import delayed
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import streamz
import holoviews as hv
from holoviews.streams import Stream, param
from holoviews.operation.datashader import regrid
from bokeh.models.tools import HoverTool
import matplotlib.pyplot as plt
import qgrid
import ipywidgets as widgets
from tqdm import tnrange, tqdm, tqdm_notebook
import warnings
from functools import partial
from cytoolz import compose, get_in
from operator import getitem
import nd2reader
from importlib import reload
import traceback
import holoplot.pandas
import param
import parambokeh
from traitlets import All
import cachetools
from collections import namedtuple
import skimage.morphology
import scipy

# Shorthand for building MultiIndex slices, e.g. all_frames.loc[IDX[:, :, ["MCHERRY"], 0], :].
IDX = pd.IndexSlice

In [None]:
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# The star imports below were replaced by the module-qualified import that
# follows, so project functions must be called as e.g. `ui.x` / `workflow.x`.
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
# NOTE(review): the `metadata` module imported here is shadowed by the
# `metadata` variable assigned in the "Loading data" cell.
import common, trench_detection, util, ui, diagnostics, metadata, workflow, image

In [None]:
%load_ext line_profiler
hv.extension("bokeh")
%matplotlib inline
# Per the tqdm docs, a monitor_interval of 0 disables tqdm's monitor thread.
tqdm.monitor_interval = 0

# Restore data

In [None]:
# Restore `trench_df4` from IPython's %store database (saved by an earlier session).
%store -r trench_df4

In [None]:
trench_df = trench_df4

# Config

In [None]:
# Dask cluster backed by SLURM jobs on the "short" queue: each job runs a
# single-process, single-threaded worker with 32GB and a 5-hour walltime.
# NOTE(review): `threads` and `env_extra` are argument names from older
# dask_jobqueue releases (newer ones use `cores` / `job_script_prologue`);
# confirm against the installed version.
cluster = SLURMCluster(
    queue="short",
    walltime="5:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # interface='ib0',
    memory="32GB",
    local_directory="/tmp",
    threads=1,
    processes=1,
    # diagnostics_port=('127.0.0.1', 8787),
    # Put the project directory on the workers' PYTHONPATH so they can import
    # the project modules (workflow, image, trench_detection, ...).
    env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"'],
)
client = Client(cluster)

In [None]:
# Widen one sub-widget of the cluster's Jupyter widget. This depends on the
# private `_widget()` layout and may break when dask_jobqueue changes.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
cluster

In [None]:
# client = Client()

In [None]:
# Stop the workers for every job the cluster is currently tracking.
cluster.stop_workers(cluster.jobs)

# Loading data

In [None]:
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2']#, '/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# Earlier file selections are kept above and below for reference; this list is
# the one currently in use.
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2",
    "/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2",
]
# nd2_filenames = ['/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC001.nd2', '/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC002.nd2']

In [None]:
# `all_frames` is a frame table with a (at least) 4-level index, sliced with
# IDX in later cells; presumably one row per ND2 frame (confirm in workflow).
# NOTE(review): assigning `metadata` here shadows the `metadata` module
# imported at the top of the notebook.
all_frames, metadata, parsed_metadata = workflow.get_nd2_frame_list(nd2_filenames)

# Reload

In [None]:
def do_reload():
    """Reload the project's processing modules in the current process.

    The notebook calls this both locally and, via ``client.run``, on every
    dask worker, because each process holds its own imported copy of these
    modules. Modules are reloaded in the order they are imported below.
    """
    import importlib

    import util
    import trench_detection
    import diagnostics
    import workflow
    import image

    for module in (util, trench_detection, diagnostics, workflow, image):
        importlib.reload(module)


# Reload the modules on every dask worker (client.run calls the function on
# each worker), then in this kernel.
client.run(do_reload)
do_reload()

# Finding trenches

In [None]:
# Keep only MCHERRY frames whose 4th index level equals 0 (presumably the
# first timepoint; confirm against workflow.get_nd2_frame_list).
frames_to_process = all_frames.loc[IDX[:, :, ["MCHERRY"], 0], :]

In [None]:
len(frames_to_process)

## Frame quality finding

In [None]:
# radial_psd2(img) == image.radial_profile(image.psd2(img)); compose applies
# its functions right-to-left.
radial_psd2 = compose(image.radial_profile, image.psd2)
# One frame-load task plus one PSD task per frame. The load future is passed
# as an argument, so dask feeds its result into radial_psd2 on the cluster.
# `idx` is a namedtuple of index values, expanded into keyword arguments.
# NOTE(review): this cell calls `workflow._get_nd2_frame`, while later cells
# call `workflow.get_nd2_frame`; confirm which loader is intended.
frame_psd2s_futures = {
    idx: client.submit(
        radial_psd2, client.submit(workflow._get_nd2_frame, **idx._asdict())
    )
    for idx, row in util.iter_index(frames_to_process)
}

In [None]:
# Block until every PSD future finishes; returns a dict with the same keys.
frame_psd2s = client.gather(frame_psd2s_futures)

In [None]:
FrameStream = ui.DataframeStream.define(
    "FrameStream", frames_to_process.index.to_frame(index=False)
)
frame_stream = FrameStream()

box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
%%opts Layout [shared_axes=False]
dict_viewer(
    frame_psd2s, frame_stream, wrapper=lambda k, v: hv.Curve(np.log(v))
) + ui.image_viewer(frame_stream)

## Run trench finding

In [None]:
# TODO(design notes):
# - locally: build a trench_points dict? (how to organize it: a dict proxy to index into?)
# - where should all trenches be listed so they can be mapped over (e.g. per-timepoint focus)?
# - turn trench_points into a DataFrame
# - locally: build the diagnostics DataFrame (by trench set)
# - then distribute with dask

In [None]:
# trench_points_futures = {idx: client.submit(get_trenches,
#                                            client.submit(workflow._get_nd2_frame, **idx._asdict())) for idx, row in util.iter_index(frames_to_process)}

In [None]:
# One trench-detection task per selected frame, each fed by a frame-load task.
trench_info_futures = {
    idx: client.submit(
        trench_detection.get_trenches_diag,
        client.submit(workflow.get_nd2_frame, **idx._asdict()),
    )
    for idx, row in util.iter_index(frames_to_process)
}

In [None]:
# NOTE(review): this cancels every detection future submitted above. It is
# presumably only for aborting a submission; skip it in a normal run.
client.cancel(trench_info_futures)

In [None]:
# Gather results, restricted by `predicate` to futures whose status is
# "finished" (presumably so that cancelled or failed ones are skipped;
# confirm in util.apply_map_futures).
trench_info = util.apply_map_futures(
    client.gather, trench_info_futures, predicate=lambda x: x.status == "finished"
)

In [None]:
len(trench_info)

In [None]:
# Frames whose result has a non-None third element (the error slot, which
# unzip_trench_info unpacks as trench_err below).
{k: v[2] for k, v in trench_info.items() if v[2] is not None}

In [None]:
# Split the per-frame results into trench points, diagnostics and errors.
trench_points, trench_diag, trench_err = workflow.unzip_trench_info(trench_info)

In [None]:
len(trench_points)

In [None]:
# NOTE(review): this restores a stored `trench_diag` and overwrites the one
# just computed by unzip_trench_info above.
%store -r trench_diag

## Analysis

In [None]:
trench_diag.tail()

In [None]:
# Frames whose detected trench rotation is more than 2 (presumably degrees) from 0.
bad_angle = trench_diag["trench_rotation.hough_2.angle"].abs() > 2
bad_angle.sum()

In [None]:
# Frames whose trench period is more than 2 (presumably pixels) away from the expected 24.
bad_period = (trench_diag["trench_anchors.periodogram_2.period"] - 24).abs() > 2
bad_period.sum()

In [None]:
selected = trench_diag[bad_period]  # trench_diag[bad_angle | bad_period]

In [None]:
# Point the existing frame browser at just the flagged frames.
frame_stream.event(_df=selected.index.to_frame(index=False))

# Prototyping

In [None]:
# Rebuild the frame stream over only the flagged frames. This replaces the
# earlier `FrameStream` / `frame_stream`.
FrameStream = ui.DataframeStream.define(
    "FrameStream", selected.index.to_frame(index=False)
)
frame_stream = FrameStream()

box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
ui.show_frame_info(trench_diag, frame_stream)

In [None]:
# g = ui.show_grid(df, stream=frame_stream)
# g

In [None]:
ui.image_viewer(frame_stream)

In [None]:
from importlib import reload

reload(image)
reload(workflow)
reload(trench_detection)
reload(diagnostics)
reload(trench_detection.hough)

In [None]:
import trench_detection
import trench_detection.hough

In [None]:
%%time
frame = workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))
_, diag, _ = diagnostics.wrap_diagnostics(trench_detection.hough.find_trenches)(frame)

In [None]:
trench_detection.core.edge_point(np.array((1, 0)), np.pi * 2, (0, 1), (0, 1))

In [None]:
frame_rot = skimage.transform.rotate(frame, -37)

In [None]:
_, diag2, _ = diagnostics.wrap_diagnostics(trench_detection.hough.find_trenches)(
    frame_rot
)

In [None]:
ui.show_plot_browser(diag2)
# ui.show_plot_browser(diag2, 'label_1.find_trench_ends');

# Low-frequency components

In [None]:
# Load the frame currently selected in `frame_stream` for spectral analysis.
frame1 = workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))

In [None]:
plt.imshow(frame1)

In [None]:
reload(image)

In [None]:
def psd2(img):
    """Return the 2-D power spectrum of *img*, zero frequency centred.

    Computes ``|FFT2(img)|**2`` after ``fftshift``, so the DC component sits
    at index ``(rows // 2, cols // 2)``. The result has the same shape as *img*.
    """
    centred_spectrum = np.fft.fftshift(np.fft.fft2(img))
    return np.abs(centred_spectrum) ** 2

In [None]:
# Power spectrum of frame1, normalised to unit mean for the log image below.
a = psd2(frame1)
b = a / a.mean()

In [None]:
# Radial profile of the log power spectrum.
plt.plot(image.radial_profile(np.log(psd2(frame1))))

In [None]:
# Log of the radial power profile, overlaid for two frames.
# NOTE(review): `frame2` is never assigned in this notebook, so the second line
# raises NameError unless it was defined interactively. Load a second frame
# first (e.g. with workflow.get_nd2_frame) or drop the line.
plt.plot(np.log(image.radial_profile(psd2(frame1))))
plt.plot(np.log(image.radial_profile(psd2(frame2))))

In [None]:
plt.imshow(np.log(b))

In [None]:
plt.imshow(psd2(frame1))

In [None]:
# Radial profile of the image itself (not its spectrum).
plt.plot(image.radial_profile(frame1))

# Image processing prototyping

In [None]:
reload(trench_detection)

In [None]:
%%time
_, diag2, _ = diagnostics.wrap_diagnostics(trench_detection.get_trenches)(
    workflow.get_nd2_frame(**dict(frame_stream.get_param_values()))
)

In [None]:
ui.show_plot_browser(diag2, "labeling")

In [None]:
a = get_in("labeling.components".split("."), diag)

In [None]:
b = skimage.morphology.label(a.data)

# Lines

In [None]:
import geometry

In [None]:
x_lim, y_lim = geometry.get_image_limits(frame.shape)

In [None]:
x0 = (x_lim[0], y_lim[0])
x1 = (x_lim[1], y_lim[1])

In [None]:
plt.plot(*zip(x0, x1))

In [None]:
theta = np.pi / 2
sep = 1 / np.cos(np.pi / 4 - theta)
diagonal = np.sqrt(x_lim[1] ** 2 + y_lim[1] ** 2)
s = np.arange(0, diagonal, sep)[np.newaxis, :]
anchors = (np.array((0, 0)) * (1 - s) + s * np.array((x_lim[1], y_lim[1]))).T
# anchors = [(1000,1000)]
for x0, x_m, x1 in trench_detection.core.line_array(
    anchors, theta, x_lim, y_lim, bidirectional=True
):
    plt.plot(*zip(x0, x1))

In [None]:
theta = np.pi / 2 - np.pi / 3
pitch = 24
##########
# anchors
x_min = y_min = 0
# NOTE(review): numpy's shape is (rows, cols), i.e. (y, x) in plt.imshow
# coordinates, so x_max here is the row count. This is harmless for square
# frames, but confirm which axis line_array expects for x_lim.
x_max, y_max = frame.shape
x_lim = (x_min, x_max)
y_lim = (y_min, y_max)
# anchor0, anchor1 = trench_detection.core.get_edge_points(np.pi/2-theta, x_lim, y_lim)
# Anchors (a, a) on the main diagonal, one every `pitch` pixels for a in [0, x_max).
anchors = np.ones(2)[np.newaxis, :] * np.arange(0, x_max, pitch)[:, np.newaxis]
plt.figure(figsize=(12, 12))
plt.imshow(frame)
# Mark the first anchor in green and the last in gray.
plt.gca().add_artist(plt.Circle(anchors[0], 50, color="g"))
plt.gca().add_artist(plt.Circle(anchors[-1], 50, color="gray"))
# lines = line_array(point_linspace(anchor0, anchor1, int((anchor1[0] - anchor0[0])//spacing)), theta, x_lim, y_lim, start=500, stop=700)
lines = trench_detection.core.line_array(
    anchors, theta, x_lim, y_lim, bidirectional=True
)
# NOTE(review): items are unpacked as (x_m, x0, x1) here but as (x0, x_m, x1)
# in the previous prototype cell. Confirm line_array's actual order.
# Draw each x0-x1 segment in white with red circles at x0 and x1.
for x_m, x0, x1 in lines:
    line = np.vstack((x0, x1)).T
    plt.plot(*line, color="w")
    plt.gca().add_artist(plt.Circle(x0, 10, color="r", zorder=2))
    plt.gca().add_artist(plt.Circle(x1, 10, color="r", zorder=2))

In [None]:
# Every other call to line_array in this notebook passes anchors, theta and the
# x/y limits, so a bare `line_array()` would presumably fail with a TypeError.
# Reuse the arguments from the plotting cell above.
lines = trench_detection.core.line_array(
    anchors, theta, x_lim, y_lim, bidirectional=True
)

# Old

In [None]:
# NOTE(review): the cells in this "Old" section use names that are not defined
# in the visible notebook (get_trenches, root_group, pos, diag_pos, tree, ...).
# They are kept for reference and raise NameError if run as-is.
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
def f(img_stack):
    """Segment every timepoint of *img_stack* and pack the masks into zarr.

    *img_stack* is indexed by timepoint along axis 0 (only ``.shape[0]`` and
    integer indexing are used). Each timepoint is segmented with
    ``segment_trench`` (diagnostics disabled). The masks are stacked along a
    new axis 0 and returned as a ``zarr`` array compressed with
    ``DEFAULT_FRAME_COMPRESSOR``.
    """
    frames = (img_stack[t] for t in range(img_stack.shape[0]))
    masks = [segment_trench(frame, diagnostics=None) for frame in frames]
    stacked = np.stack(masks, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


# Legacy: apply the segmentation `f` above trench-by-trench for position 0 only
# (positions=range(1)), using channel_slice=1 and all timepoints.
# NOTE(review): positionwise_trenchwise_map, root_group and trench_points_pos
# are not defined in the visible notebook.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

In [None]:
def f(img_stack):
    """Summarise each image in *img_stack* by its 95th-percentile intensity.

    The percentile is taken over axes 1 and 2, so axis 0 indexes the images.
    Returns a ``pd.Series`` with one value per image (default integer index).
    A per-image maximum, ``np.max(img_stack, axis=(1, 2))``, was tried earlier.
    """
    per_image_p95 = np.percentile(img_stack, 95, axis=(1, 2))
    return pd.Series(per_image_p95)


# Legacy: per-trench 95th-percentile intensity traces (the `f` just above),
# using channel_slice=2 and all timepoints, for positions 0-99.
# NOTE(review): relies on positionwise_trenchwise_map, root_group and
# trench_points_pos, none of which are defined in the visible notebook.
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(100),
)