# Imports

In [None]:
import numpy as np
import pandas as pd
import zarr
import dask
from dask import delayed
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
import streamz
import holoviews as hv
from holoviews.streams import Stream, param
from bokeh.models.tools import HoverTool
import matplotlib.pyplot as plt
from tqdm import tnrange, tqdm, tqdm_notebook
import warnings
from functools import partial
from cytoolz import compose
from operator import getitem
import nd2reader

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from processing import *
from trench_detection import *
from trench_segmentation import *
from trench_segmentation.watershed import *
from util import *
from ui import *

In [None]:
# Notebook setup: load the line_profiler IPython extension, use Bokeh as the
# HoloViews plotting backend, and render matplotlib figures inline.
%load_ext line_profiler
hv.notebook_extension("bokeh")
# renderer = hv.renderer('bokeh')
%matplotlib inline
# Disable tqdm's background monitor thread; tqdm documents setting
# monitor_interval = 0 (before any bar is created) to turn it off globally.
tqdm.monitor_interval = 0

# Config

In [None]:
# Start a SLURM-backed Dask cluster and connect a client to it.
# NOTE(review): the keywords threads / processes / diagnostics_port match an
# older dask_jobqueue API (newer releases use cores / processes /
# dashboard_address) — confirm against the installed version before editing.
cluster = SLURMCluster(
    queue="short",
    walltime="1:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # interface='ib0',
    memory="64GB",
    # Worker scratch directory on the compute node.
    local_directory="/tmp",
    # One single-threaded worker process per job.
    threads=1,
    processes=1,
    # Presumably binds the diagnostics dashboard to 127.0.0.1:8787 — confirm.
    diagnostics_port=("127.0.0.1", 8787),
)
client = Client(cluster)

In [None]:
# Widen one element of the cluster's Jupyter widget, then display the cluster.
# NOTE(review): relies on the private _widget() method and a hard-coded child
# path; expect this to break across dask_jobqueue / ipywidgets versions.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
cluster

In [None]:
# Request 4 workers from SLURM.
# NOTE(review): start_workers is an older dask_jobqueue API; newer versions
# use cluster.scale(...) instead — confirm against the installed version.
cluster.start_workers(4)

# Functions

In [None]:
def getitem_r(b, a):
    """Return ``a[b]``: :func:`operator.getitem` with its arguments reversed.

    Putting the key first lets it be bound ahead of time, e.g.
    ``partial(getitem_r, 1)`` yields a function that takes a container and
    returns its element at index 1.
    """
    return getitem(a, b)

# memmap vs. non-memmap read benchmarks

In [None]:
# Open the first acquisition for the read benchmarks below.
# NOTE(review): upstream nd2reader's ND2Reader does not appear to take a
# `memmap` keyword (nor do get_frame_2D / _parser._get_raw_image_data below);
# this presumably relies on a patched/forked nd2reader — confirm.
nd2 = nd2reader.ND2Reader(
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2", memmap=True
)

In [None]:
nd2.sizes

In [None]:
# Preview rows 300:600 of a raw image read through the private parser
# (image group 0, channel offset 0). Element [1] is presumably the image
# array — older nd2reader returns (timestamp, image) here; confirm.
# RevImage comes from the star-imported project modules.
RevImage(nd2._parser._get_raw_image_data(0, 0, 2048, 2048)[1][300:600, :])

In [None]:
# Same preview for channel offset 2, via the memmap read path.
RevImage(nd2._parser._get_raw_image_data(0, 2, 2048, 2048, memmap=True)[1][300:600, :])

In [None]:
%%time
for v in range(nd2.sizes["v"] // 50):
    nd2.get_frame_2D(v=v, c=0, memmap=True)[300:600, :]

In [None]:
%%time
for v in range(nd2.sizes["v"] // 50):
    nd2.get_frame_2D(v=v, c=0, memmap=True)[:, 300:600]

In [None]:
%%time
for v in range(nd2.sizes["v"] // 50):
    nd2.get_frame_2D(v=v, c=0, memmap=False)[300:600, :]

In [None]:
%%time
for v in range(nd2.sizes["v"] // 50):
    nd2.get_frame_2D(v=v, c=0, memmap=False)[:, 300:600]

In [None]:
%%time
for v in range(nd2.sizes["v"] // 100):
    nd2._parser._get_raw_image_data(v, 0, 2048, 2048)[1][300:600, :]

In [None]:
%%time
for v in range(nd2.sizes["v"] // 100):
    nd2._parser._get_raw_image_data(v, 0, 2048, 2048, memmap=True)[1][300:600, :]

In [None]:
nd2._parser._get_raw_image_data(0, 0, 2048, 2048)

# Loading data

In [None]:
# ND2 acquisitions to process.
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2",
]

In [None]:
# Construct one ND2Reader per file on the Dask cluster; the dict values are
# Futures returned by client.submit, keyed by filename.
# NOTE(review): each reader presumably keeps its file open on the worker.
nd2s = {
    filename: client.submit(nd2reader.ND2Reader, filename) for filename in nd2_filenames
}

In [None]:
# NOTE(review): this shows the reader opened in the memmap section (`nd2`),
# not anything from the futures in `nd2s`.
nd2.sizes

In [None]:
# NOTE(review): unfinished scratch — transpose_dict receives the
# wrap_diagnostics callable itself (it is called in the next cell), not a
# dict of per-file results.
trench_positions, trench_diag = transpose_dict(wrap_diagnostics)

In [None]:
# NOTE(review): unfinished scratch — this evaluates to a one-element *set*
# containing the wrapped function; it is not mapped over nd2s.
{wrap_diagnostics(compose(get_trenches, lambda x: x.get_frame_2D()))}

# Scratch

In [None]:
def load_nd2(filename):
    """Open ``filename`` and run trench detection on one frame.

    Calls ``get_trenches`` on ``nd2[1, 30]`` (what those indices select
    depends on the reader's iteration axes — TODO confirm), collecting
    diagnostics into a fresh ``tree()``. The reader is closed before
    returning.

    The original referenced the undefined names ``diag_pos`` / ``pos``
    (NameError) while leaving the local ``diag`` unused. Diagnostics now go
    into ``diag``. They are still not returned; TODO: expose them if needed.
    """
    diag = tree()
    # Run detection inside the ``with`` so the file stays open while frames
    # are read, and is closed afterwards instead of leaking the handle.
    with nd2reader.ND2Reader(filename) as nd2:
        return get_trenches(nd2[1, 30], diagnostics=diag)

In [None]:
# Start with an empty mapping; presumably filled per position later — confirm.
trench_positions = dict()

# Old

In [None]:
# NOTE(review): legacy cells. root_group, pos and diag_pos are not defined in
# any visible cell (possibly left over from an earlier session or a star
# import), so these raise NameError in a fresh kernel.
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
# Trench detection on index [0, 1] of position `pos`, with diagnostics
# collected into a fresh tree().
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
def f(img_stack):
    """Segment every frame of ``img_stack`` and return the masks as a zarr array.

    Each index along axis 0 is passed to ``segment_trench`` with diagnostics
    disabled. The per-frame results are stacked along a new leading axis and
    wrapped in ``zarr.array`` using ``DEFAULT_FRAME_COMPRESSOR``.
    """
    frame_count = img_stack.shape[0]
    masks = []
    for idx in range(frame_count):
        masks.append(segment_trench(img_stack[idx], diagnostics=None))
    stacked = np.stack(masks, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


# Apply the segmentation `f` from this cell via positionwise_trenchwise_map:
# channel_slice=1, all time points, positions=range(1) (first position only).
# Argument semantics are defined by positionwise_trenchwise_map elsewhere.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    positions=range(1),
    time_slice=slice(None),
    channel_slice=1,
    preload=True,
)

In [None]:
def f(img_stack):
    """Return a Series holding the 95th percentile of each frame.

    The percentile is taken over axes 1 and 2 of ``img_stack``. The Series
    has one entry per index along axis 0, with a default integer index.
    """
    per_frame_p95 = np.percentile(img_stack, q=95, axis=(1, 2))
    return pd.Series(per_frame_p95)


# Apply the percentile `f` from this cell via positionwise_trenchwise_map:
# channel_slice=2, all time points, positions 0-99.
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    positions=range(100),
    time_slice=slice(None),
    channel_slice=2,
    preload=True,
)
)