# Imports

In [None]:
import asyncio
import os
import traceback
import warnings
from functools import partial
from glob import glob
from importlib import reload

import dask
import distributed
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy
import zarr
from cytoolz import *
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
########
import nd2reader
from matriarch import *

# import common, trench_detection, util, data_io, processing
# import ui, diagnostics, metadata
# import workflow, image, geometry
# import trench_detection.hough, trench_detection.core
# import trench_segmentation.watershed

In [None]:
%load_ext line_profiler
%load_ext memory_profiler
# %load_ext snakeviz
hv.extension("bokeh", "matplotlib")
%matplotlib inline
# tqdm.monitor_interval = 0
# asyncio.get_event_loop().set_debug(False)
# import logging
# logging.basicConfig(level=logging.DEBUG)
import warnings

warnings.simplefilter("ignore")

# Loading data

In [None]:
# nd2_filenames = ['/home/jqs1/scratch/190509/190509_YFP_mScarlet_repressilators_fast.nd2']
# nd2_filenames = ['/home/jqs1/scratch/190509/190509_YFP_mScarlet_repressilators_faster.nd2']
nd2_filenames = [
    "/home/jqs1/scratch/190504/basilisk/190504_YFP_mScarlet_repressilators_fast.nd2"
]
# nd2_filenames = ['/home/jqs1/scratch/190504/basilisk/190504_YFP_mScarlet_repressilators002.nd2']
# nd2_filenames = ['/home/jqs1/scratch/190411_FP_Ti3/190411_mV_SCFP_repr.nd2']

In [None]:
all_frames, metadata, parsed_metadata = workflow.get_nd2_frame_list(nd2_filenames)
image_limits = workflow.get_filename_image_limits(metadata)

# Config

In [None]:
dask.config.get("distributed.worker.memory")

In [None]:
dask.config.config["distributed"]["worker"]["memory"] = {
    "target": 0.9,
    "spill": None,
    "pause": None,
    "terminate": 0.95,
}

In [None]:
# dask.config.config['distributed']['worker']['profile'] = {'interval': '10s', 'cycle': '10s'}
# {'interval': '10ms', 'cycle': '1000ms'}

In [None]:
cluster = SLURMCluster(
    queue="short",
    walltime="04:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # job_extra=['--exclude=compute-e-16-181,compute-e-16-186'],
    # interface='ib0',
    memory="12GB",  # TODO!!!
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/matriarch/log",
    cores=1,
    processes=1,
)  # ,
# diagnostics_port=('127.0.0.1', 8787),
# env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"',
#'export PYTHONTRACEMALLOC=25',
#'export MALLOC_CONF=prof:true,prof_leak:true,lg_prof_interval:31,prof_final:true',
#           'export LD_PRELOAD="/home/jqs1/lib/libjemalloc.so.2"'])
client = Client(cluster)  # , direct_to_workers=True)

In [None]:
cluster.scale(4)

In [None]:
cluster

In [None]:
cluster.adapt(minimum=0, maximum=300)

In [None]:
cluster.stop_jobs(cluster.running_jobs.keys())

In [None]:
cluster.scheduler.stop_services()
cluster.scheduler.stop()

In [None]:
client.restart()

# Reload

In [None]:
def do_reload():
    """Reload the project's ``workflow`` module in the current process.

    The notebook runs this on every dask worker (``client.run(do_reload)``)
    and then locally, to pick up source edits without restarting. The other
    project modules are imported (loaded if not already) but not reloaded;
    add them to ``to_reload`` to reload them as well.
    """
    from importlib import reload

    import diagnostics
    import image
    import trench_detection
    import util
    import workflow

    # Previously toggled candidates: util, trench_detection.hough,
    # diagnostics, image.
    to_reload = (workflow,)
    for module in to_reload:
        reload(module)

client.run(do_reload)
do_reload()

# Trench detection

In [None]:
pitches[pitches > 19]

In [None]:
FrameStream = ui.MultiIndexStream.define("FrameStream", all_frames.index)
frame_stream = FrameStream()
box = ui.dataframe_browser(frame_stream)
frame_stream.event()
box

In [None]:
%output size=200
ui.image_viewer(frame_stream)

In [None]:
cluster.scale(0)

In [None]:
%%time
key = tuple(frame_stream.contents[k] for k in ("filename", "position", "channel", "t"))
# key = ('/home/jqs1/scratch/190509/190509_YFP_mScarlet_repressilators_faster.nd2', 35, 'MCHERRY', 0)
frame = workflow.get_nd2_frame(*key)
# find_trenches_diag = diagnostics.wrap_diagnostics(trench_detection.find_trenches, ignore_exceptions=False, pandas=False)
# trench_points, trench_diag, trench_err = find_trenches_diag(frame)

In [None]:
# %%output size=150
ui.show_plot_browser(trench_diag)

In [None]:
%%time
diag = {}
img = frame[500:600, :500]
img_labels = trench_segmentation.segment(img, diagnostics=diag)

In [None]:
diag["img_k1_frangi"].redim.range(z=(0, 1e-1))

In [None]:
diag["img_k1_frangi"].redim.range(z=(0, 1e-1))

In [None]:
diag["img_k1_frangi"].redim.range(z=(0, 1e-1))

In [None]:
diag["img_k1"]

In [None]:
ui.show_plot_browser(diag)

# Drift correction

In [None]:
nd2 = nd2reader.ND2Reader(nd2_filenames[0])

In [None]:
nd2.sizes

In [None]:
nd2.metadata["channels"]

In [None]:
%%time
imgs = [nd2.get_frame_2D(v=10, c=0, t=t) for t in range(nd2.sizes["t"])]

In [None]:
from skimage.feature import register_translation

In [None]:
%%time
shift, error, diffphase = register_translation(imgs[0], imgs[1])

In [None]:
%%time
shifts = [
    register_translation(imgs[t], imgs[t + 1], return_error=False)
    for t in range(len(imgs) - 1)
]

In [None]:
%%time
shifts_from_0 = [
    register_translation(imgs[0], imgs[t + 1], return_error=False)
    for t in range(len(imgs) - 1)
]

In [None]:
%%time
s = 2
shifts3 = [
    register_translation(imgs[t], imgs[t + s], return_error=False)
    for t in range(len(imgs) - s)
]

In [None]:
%%time
s = 2
# Register the mean of each block of ``s`` consecutive frames against the mean
# of the following block of ``s`` frames. Blocks step by ``s`` so they do not
# overlap (comparable to ``shifts3[::s]``), and iteration stops while a full
# second block is still available. Both means reduce over the frame axis only
# (axis=0) so each argument stays a 2D image.
shifts_avg = [
    register_translation(
        np.mean(imgs[t : t + s], axis=0),
        np.mean(imgs[t + s : t + 2 * s], axis=0),
        return_error=False,
    )
    for t in range(0, len(imgs) - 2 * s + 1, s)
]

In [None]:
plt.plot(np.cumsum(np.array(shifts3)[::s], axis=0))
plt.plot(np.arange(len(shifts_from_0)) / s, np.array(shifts_from_0))

In [None]:
plt.plot(np.cumsum(np.array(shifts3)[::s], axis=0))

In [None]:
plt.plot(np.array(shifts_from_0))

In [None]:
plt.plot(np.cumsum(shifts, axis=0))

In [None]:
plt.plot(np.array(shifts))

In [None]:
shift

In [None]:
# Frames per timepoint = product of the channel and position counts.
# ``np.product`` was deprecated in NumPy 1.25 and removed in 2.0; ``np.prod``
# is the identical reduction.
period = np.prod(list(util.get_keys(nd2.sizes, "c", "v").values()))

In [None]:
np.asarray(nd2._parser._raw_metadata.x_data)[::period]

In [None]:
plt.plot(np.asarray(nd2._parser._raw_metadata.x_data)[::period])

In [None]:
plt.plot(np.asarray(nd2._parser._raw_metadata.y_data)[0::period])

In [None]:
plt.plot(np.asarray(nd2._parser._raw_metadata.z_data)[0::period])

# Data reduction

In [None]:
selected_frames = all_frames.loc[IDX[:, 100:102, :, :], :]

## Debug trench detection

In [None]:
segmentation_channel = "CFP"
# Last timepoint in the selection, as an actual ``t`` label usable in IDX[...].
# The previous ``index.labels[...][-1]`` returned an integer *code* into
# ``index.levels`` (equal to the timepoint only when the t levels are exactly
# 0..N-1), and ``MultiIndex.labels`` was renamed ``codes`` in pandas 0.24 and
# removed in pandas 1.0.
segmentation_t = selected_frames.index.get_level_values("t").max()

find_trenches_diag = diagnostics.wrap_diagnostics(
    trench_detection.find_trenches, ignore_exceptions=True, pandas=True
)


def do_find_trenches(*key):
    """Run diagnostic-wrapped trench detection on a single ND2 frame.

    ``key`` is unpacked into ``workflow.get_nd2_frame`` to load the frame.
    Returns whatever ``find_trenches_diag`` returns for that frame.
    """
    return find_trenches_diag(workflow.get_nd2_frame(*key))


# Submit one trench-detection task per (filename, position) to the dask client.
# For each position, select the frame(s) in ``segmentation_channel`` at
# ``segmentation_t`` and submit ``do_find_trenches`` on the first match's index
# key. ``trenches`` maps (filename, position) -> future; the futures are
# collected later with ``client.gather(trenches)``.
trenches = {}
for filename, filename_frames in selected_frames.groupby("filename"):
    for position, frames in filename_frames.groupby("position"):
        key = (filename, position)
        frame_to_segment = frames.loc[
            IDX[:, :, [segmentation_channel], segmentation_t], :
        ]  # TODO: make pluggable
        trenches_future = client.submit(do_find_trenches, *frame_to_segment.index[0])
        # For local (non-distributed) debugging, call directly instead:
        #         trenches_future = do_find_trenches(*frame_to_segment.index[0])
        trenches[key] = trenches_future

In [None]:
cluster.scale(30)

In [None]:
trenches_res = client.gather(trenches)

In [None]:
trenches_df = pd.DataFrame({k: v[1] for k, v in trenches_res.items()}).T

In [None]:
list(trenches_res.values())[30][1]

In [None]:
pitches = trenches_df.loc[:, "label_1.find_trench_lines.hough_1.peak_func.pitch"]

In [None]:
len(pitches)

In [None]:
pitches[pitches > 19]

## New trench detection+segmentation+analysis

#### Config

In [None]:
def filter_trenches(trenches, *, enabled=False, pitch=18, pitch_tolerance=1):
    """Drop trenches whose detection looks unreliable.

    Filtering is disabled by default: with ``enabled=False`` the input is
    returned unchanged (the behavior this notebook currently relies on).
    Previously the filter below was dead code after an early ``return``;
    it is now reachable by passing ``enabled=True``.

    Parameters
    ----------
    trenches : pandas.DataFrame or None
        Per-trench table with MultiIndex columns. The filter reads
        ``("diag", "find_trench_lines.hough_2.peak_func.pitch")`` and
        ``("upper_left", "x")``.
    enabled : bool
        Apply the filter. Default ``False`` (pass-through).
    pitch : float
        Expected detected pitch. Other values tried previously: 20.9, 24.
    pitch_tolerance : float
        Maximum allowed absolute deviation of the detected pitch from
        ``pitch`` (inclusive).

    Returns
    -------
    pandas.DataFrame or None
        ``None`` if ``trenches`` is ``None``. Otherwise the rows whose pitch
        is within tolerance and whose ``("upper_left", "x")`` is not null
        (or all rows when disabled).
    """
    # TODO!!!!! filtering is disabled by default until the criteria are validated.
    if not enabled:
        return trenches
    if trenches is None:
        return None
    detected_pitch = trenches[("diag", "find_trench_lines.hough_2.peak_func.pitch")]
    pitch_ok = (detected_pitch - pitch).abs() <= pitch_tolerance
    has_corner = trenches[("upper_left", "x")].notna()
    # NOTE: an older version also required ('info', 'hough_value') > 90;
    # we shouldn't be filtering based on hough_value at all.
    # TODO: filter based on minimum trench length
    # TODO: filter based on trench peak brightness
    return trenches[pitch_ok & has_corner]

In [None]:
pixelwise_funcs = {
    "mean": np.mean,
    "min": np.min,
    "max": np.max,
    ("p0.3", "p0.5", "p0.7", "p0.9", "p0.95"): partial(
        np.percentile, q=(30, 50, 70, 90, 95)
    ),
}
trenchwise_funcs = {"sharpness": image.sharpness, **pixelwise_funcs}


def _measurement_func(label_image, intensity_image):
    if label_image is not None:
        eroded_label_image = (
            util.repeat_apply(skimage.morphology.binary_erosion, 2)(label_image != 0)
            * label_image
        )
    if intensity_image is None:
        if label_image is None:
            return None  # can't measure anything
        minlength = label_image.max() + 1
        mask_labelwise_df = pd.DataFrame(
            {
                ("noerode", "size"): np.bincount(label_image.flat, minlength=minlength),
                ("erode2", "size"): np.bincount(
                    eroded_label_image.flat, minlength=minlength
                ),
            }
        )
        mask_labelwise_df.index.name = "label"
        return dict(mask_labelwise=mask_labelwise_df)
    trenchwise_df = workflow.map_frame(trenchwise_funcs, intensity_image)
    res = dict(trenchwise=trenchwise_df)
    if label_image is None:
        return res  # only measure trenchwise
    labelwise = {
        "noerode": workflow.map_frame_over_labels(
            pixelwise_funcs, label_image, intensity_image
        ),
        "erode2": workflow.map_frame_over_labels(
            pixelwise_funcs, eroded_label_image, intensity_image
        ),
    }
    labelwise_df = pd.concat(labelwise, axis=1)
    res["labelwise"] = labelwise_df
    return res  # measure trenchwise and labelwise

In [None]:
def filename_func(
    output_name="out",
    extension=None,
    kind=None,
    name=None,
    filename=None,
    position=None,
):
    """Build an output path beside ``filename`` inside ``output_name``.

    Layout (``name``/``extension`` segments are omitted when ``None``):

    - no position: ``<dir>/<output_name>/<basename>.<kind>[.name][.extension]``
    - with position: ``<dir>/<output_name>/<basename>.<kind>/pos<N>[.name][.extension]``
    """
    directory, basename = os.path.split(filename)
    # A leading "" makes the join yield ".name.ext", or "" when both are None.
    suffix_parts = [""]
    suffix_parts.extend(part for part in (name, extension) if part is not None)
    suffix = ".".join(suffix_parts)
    stem = f"{basename}.{kind}"
    if position is not None:
        segments = (directory, output_name, stem, f"pos{position:d}" + suffix)
    else:
        segments = (directory, output_name, stem + suffix)
    return os.path.join(*segments)

In [None]:
client.restart()

In [None]:
util.apply_map_futures(
    client.gather, all_analysis_futures, predicate=lambda x: x.status == "error"
)