# Imports

In [None]:
import numpy as np
import pandas as pd
import zarr
import dask
from dask import delayed
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SLURMCluster
import streamz
import holoviews as hv
from holoviews.streams import Stream, param
from bokeh.models.tools import HoverTool
import matplotlib.pyplot as plt
from tqdm import tnrange, tqdm, tqdm_notebook
import warnings
from functools import partial
from cytoolz import compose
from operator import getitem
import nd2reader
from importlib import reload

In [None]:
#%load_ext autoreload
#%autoreload 2

In [None]:
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
import trench_detection, util

In [None]:
# Notebook setup: load the line_profiler IPython extension (%lprun), select the
# Bokeh backend for HoloViews, and render matplotlib figures inline.
%load_ext line_profiler
hv.extension("bokeh")
%matplotlib inline
# monitor_interval = 0 disables tqdm's background monitor thread.
tqdm.monitor_interval = 0

# Config

In [None]:
# Dask cluster backed by SLURM jobs. Each job requests the "short" queue for up
# to 1 hour with 64GB of memory and is configured for one process with one
# thread.
# NOTE(review): `threads=`, `processes=` and `diagnostics_port=` are keyword
# names from an older dask_jobqueue/distributed API; newer releases renamed or
# removed them (e.g. `cores=`). Confirm against the installed version.
# NOTE(review): no job count is given here, so workers presumably have to be
# requested afterwards (e.g. via `cluster.start_workers`).
cluster = SLURMCluster(
    queue="short",
    walltime="1:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # interface='ib0',
    memory="64GB",
    # local_directory='/tmp',
    threads=1,
    processes=1,
    # Presumably binds the diagnostics dashboard to loopback only, port 8787.
    diagnostics_port=("127.0.0.1", 8787),
)
client = Client(cluster)

In [None]:
# Widen a nested element of the cluster's notebook widget to 200px.
# NOTE(review): `_widget()` is a private API, and the children[...] path
# depends on the widget's internal layout. Expect this to break across versions.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
# Display the cluster widget (last expression of the cell).
cluster

In [None]:
# Alternative to the SLURM cell: an in-process local cluster. This rebinds both
# `cluster` and `client`, so run only one of the two cluster-setup cells.
# n_workers=0 starts no workers up front; one worker is started explicitly.
cluster = LocalCluster(n_workers=0, ncores=1)  # TODO: constructor ncores appears ignored; passed to start_worker instead
cluster.start_worker(ncores=1)
client = Client(cluster)

In [None]:
# Display the client summary (last expression of the cell).
client

In [None]:
# Request one additional worker from whichever cluster is active.
cluster.start_workers(1)

In [None]:
# Stop a specific worker.
# NOTE(review): the hard-coded `4` must match an existing worker. Some dask /
# dask_jobqueue versions expect Worker objects or addresses here rather than
# integers; confirm before running.
cluster.stop_workers([4])

# Functions

In [None]:
def getitem_r(b, a):
    """Return ``a[b]``: :func:`operator.getitem` with its arguments swapped.

    Taking the key first makes it convenient for partial application and
    composition, where the key is known up front and the container arrives
    later, e.g. ``partial(getitem_r, "key")``.

    Args:
        b: the key / index / slice to look up.
        a: the container to index.

    Returns:
        ``a[b]``.
    """
    return getitem(a, b)

# Loading data

In [None]:
# Input ND2 acquisitions (presumably on the cluster's shared scratch filesystem,
# since the readers are opened on workers below).
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2",
]

In [None]:
# Construct one ND2Reader per file on a worker: filename -> Future.
# memmap=False is forwarded to the ND2Reader constructor. The readers stay on
# the workers and are passed to later tasks as futures.
nd2s = {
    filename: client.submit(nd2reader.ND2Reader, filename, memmap=False)
    for filename in nd2_filenames
}

In [None]:
# Future for the first file's reader, used by the exploratory cells below.
nd2 = nd2s[nd2_filenames[0]]

In [None]:
# Trench detection on one frame (v=0, c=2) of the first file. cytoolz.compose
# applies right-to-left: the lambda first extracts the 2D frame from the
# ND2Reader (the `nd2` future is resolved on the worker), then get_trenches
# runs on it.
a = client.submit(
    compose(
        trench_detection.get_trenches, lambda x: x.get_frame_2D(v=0, c=2, memmap=False)
    ),
    nd2,
)

In [None]:
# Same task with get_trenches wrapped by util.wrap_diagnostics. This rebinds
# `a`, so the earlier future is no longer referenced.
a = client.submit(
    compose(
        util.wrap_diagnostics(trench_detection.get_trenches),
        lambda x: x.get_frame_2D(v=0, c=2, memmap=False),
    ),
    nd2,
)

In [None]:
# Pick up edits to trench_detection in this (client) process.
# NOTE(review): this presumably does not refresh the copy already imported on
# the workers; confirm before relying on it for remote tasks.
reload(trench_detection)

In [None]:
# Wait for the task and fetch its result; a remote exception is re-raised here.
a.result()

In [None]:
# Run trench detection on frames v=0 and v=1 (channel c=2) of every ND2 file.
# Result shape: {filename: {v: Future}}.
#
# The helpers are reached through their modules (as in the single-frame cell
# above) because the star imports near the top are commented out, so the bare
# names `wrap_diagnostics` / `get_trenches` are not defined by any active import.
#
# `v` is bound as a lambda default so that each task reads its own frame
# regardless of when the lambda is serialized. Otherwise the closure would see
# whatever `v` holds at pickling time (late binding).
trench_data = {
    filename: {
        v: client.submit(
            compose(
                util.wrap_diagnostics(trench_detection.get_trenches),
                lambda x, v=v: x.get_frame_2D(v=v, c=2, memmap=False),
            ),
            nd2_future,
        )
        for v in range(2)
    }
    for filename, nd2_future in nd2s.items()
}

In [None]:
# Fetch the remote traceback of the (first file, v=0) task for debugging.
# NOTE(review): this presumably waits for the task and returns None if it did not fail.
t = trench_data[nd2_filenames[0]][0].traceback()

In [None]:
from traceback import print_tb, print_exception
import traceback

In [None]:
# Gather every result into the same nested dict structure; this raises the
# first remote exception if any task failed.
client.gather(trench_data)

In [None]:
# Convert the traceback object into a list of frame summaries for inspection.
traceback.extract_tb(t)

# Scratch

In [None]:
def load_nd2(filename):
    """Open an ND2 file and run trench detection on a single frame.

    Scratch helper. The reader is opened in a ``with`` block so its file
    handle is released once the frame has been processed. A fresh ``tree()``
    is passed to ``get_trenches`` as its ``diagnostics`` argument (the
    original referenced the undefined ``diag_pos[pos]`` and left ``diag``
    unused).

    Args:
        filename: path to an ``.nd2`` file.

    Returns:
        Whatever ``trench_detection.get_trenches`` returns for ``nd2[1, 30]``.
    """
    # NOTE(review): `tree` is assumed to come from `util` (its star import is
    # commented out near the top). Confirm it is in scope before calling this.
    diag = tree()
    with nd2reader.ND2Reader(filename) as nd2:
        # NOTE(review): confirm which frame the (1, 30) index selects on ND2Reader.
        return trench_detection.get_trenches(nd2[1, 30], diagnostics=diag)

In [None]:
# Empty accumulator for trench positions; no visible cell fills it in yet.
trench_positions = {}

# Old

In [None]:
# NOTE(review): legacy cells from an earlier workflow. `root_group` (presumably
# a zarr group), `pos`, `diag_pos`, `tree` and the bare `get_trenches` are not
# defined by any active cell above, so these raise NameError in a fresh kernel.
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
# Run detection on a different index of the same position, collecting
# diagnostics into a fresh `diag`; the return value is discarded.
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
def f(img_stack):
    """Segment each frame of ``img_stack`` and pack the masks into zarr.

    Frames are taken along axis 0. ``segment_trench`` runs on each frame
    without diagnostics, the per-frame results are re-stacked along axis 0,
    and the stack is wrapped in a zarr array compressed with
    ``DEFAULT_FRAME_COMPRESSOR``.
    """
    frame_count = img_stack.shape[0]
    masks = [
        segment_trench(img_stack[idx], diagnostics=None)
        for idx in range(frame_count)
    ]
    stacked = np.stack(masks, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


# Apply the segmentation `f` (defined just above) to channel 1 across all time
# points, restricted to position 0 (positions=range(1)).
# NOTE(review): presumably positionwise_trenchwise_map crops each trench from
# the raw data (preload=True loading it into memory first) and calls `f` per
# trench. Confirm against its definition.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

In [None]:
def f(img_stack):
    """Summarize each frame of ``img_stack`` by its 95th-percentile value.

    The percentile is taken over axes 1 and 2, giving one value per index
    along axis 0. The values are returned as a ``pd.Series`` with a default
    integer index.
    """
    # Alternative summary (disabled): np.max(img_stack, axis=(1, 2)).
    per_frame = np.percentile(img_stack, q=95, axis=(1, 2))
    return pd.Series(per_frame)


# 95th-percentile intensity traces (via `f` defined just above) for channel 2,
# across all time points, for positions 0-99.
# NOTE(review): positions=range(100) is hard-coded; confirm it matches the
# number of positions in root_group["raw"].
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(100),
)