# Imports

In [None]:
import warnings
from functools import partial
from importlib import reload
from operator import getitem

import dask
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import streamz
import zarr
from bokeh.models.tools import HoverTool
from cytoolz import compose
from dask import delayed
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from holoviews.streams import Stream, param
from tqdm import tnrange, tqdm, tqdm_notebook

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
import common
import trench_detection
import ui
import util

In [None]:
# Notebook-wide setup: load the line_profiler IPython extension, select the
# bokeh backend for HoloViews and render matplotlib figures inline.
%load_ext line_profiler
hv.extension("bokeh")
%matplotlib inline
# Sets tqdm's class-level monitor_interval to 0 (presumably to switch off
# tqdm's background monitor thread -- confirm against the installed tqdm).
tqdm.monitor_interval = 0

# Config

In [None]:
# Create a dask-jobqueue SLURM cluster and attach a distributed client to it.
# NOTE(review): env_extra hard-codes a user-specific PYTHONPATH; it has to be
# adjusted when the notebook is run from another account or checkout.
cluster = SLURMCluster(
    queue="short",
    walltime="1:00:00",
    # job_extra=['-p transfer'],
    # job_extra=['--cores-per-socket=8'],
    # interface='ib0',
    memory="64GB",
    # local_directory='/tmp',
    threads=1,
    processes=1,
    # diagnostics_port=('127.0.0.1', 8787),
    env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"'],
)
client = Client(cluster)

In [None]:
# Force one nested child of the cluster widget to a width of 200px, then
# display the cluster. This indexes into the private `_widget()` layout, so it
# is tied to the exact widget structure of the installed version.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
cluster

In [None]:
cluster.stop_workers(cluster.jobs)

In [None]:
# Alternative to the SLURM setup: a local cluster created with zero workers,
# to which a single worker is then added explicitly.
# NOTE: this rebinds both `cluster` and `client`, so running this cell replaces
# the SLURM-backed objects created in the earlier config cell.
cluster = LocalCluster(
    n_workers=0, ncores=1, processes=True, ip="0.0.0.0"
)  # TODO: ncores ignored
cluster.start_worker(ncores=1)
client = Client(cluster)

In [None]:
?LocalCluster

In [None]:
client

# Functions

# Fix diag

In [None]:
def serialize_holoviews(x):
    """Turn ``x`` into the ``(header, frames)`` pair used by the custom serializer.

    The header is always an empty dict; the single frame is whatever
    ``hv.Store.dumps(x)`` returns.
    """
    print("z")  # debug trace: shows when this hook is invoked
    payload = hv.Store.dumps(x)
    return {}, [payload]


def deserialize_holoviews(header, frames):
    """Rebuild an object from the frames produced by ``serialize_holoviews``.

    Parameters
    ----------
    header : dict
        Serialization header. It is not used here.
    frames : list
        Binary payload frames. A payload may arrive split into several frames,
        in which case they are concatenated before loading.

    Returns
    -------
    object
        Whatever ``hv.Store.loads`` returns for the reassembled payload.
    """
    print("h")  # debug trace: shows when this hook is invoked
    if len(frames) > 1:  # this may be cut up for network reasons
        # The frames are binary, so they must be joined as bytes; joining with
        # a str separator raises TypeError as soon as the payload is split.
        frame = b"".join(frames)
    else:
        frame = frames[0]
    return hv.Store.loads(frame)


# Register the two hooks defined in this cell as the distributed
# (de)serializers for hv.element.chart.Scatter. Only this one element class is
# registered here.
from distributed.protocol.serialize import register_serialization

register_serialization(
    hv.element.chart.Scatter, serialize_holoviews, deserialize_holoviews
)

In [None]:
reload(hv.core.dimension)

In [None]:
hv.core.dimension.LabelledData

In [None]:
# Keep references to the pickling hooks currently installed on LabelledData so
# the wrappers defined further down can delegate to them.
# CAUTION: run this cell only before LabelledData.__setstate__ is replaced by
# new_setstate. If it is re-run afterwards, orig_setstate will point at the
# wrapper itself and new_setstate would then call itself recursively.
orig_getstate = hv.core.dimension.LabelledData.__getstate__
orig_setstate = hv.core.dimension.LabelledData.__setstate__

In [None]:
from functools import wraps

# orig_getstate = hv.Dimension.__getstate__
# orig_setstate = hv.Dimension.__setstate__


@wraps(orig_getstate)
def new_getstate(self):
    # Wrapper around the original __getstate__: switch on
    # hv.Store.save_option_state, then delegate. The flag is deliberately left
    # set afterwards (the reset is disabled in this notebook).
    hv.Store.save_option_state = True
    return orig_getstate(self)


# @wraps(orig_setstate)
def new_setstate(self, d):
    """Wrapper around the original ``__setstate__``.

    Stores ``hv.StoreOptions.id_offset()`` in ``hv.Store.load_counter_offset``
    and then delegates to ``orig_setstate``. The offset is left in place
    afterwards; it is not reset to ``None``.
    """
    print("NEW SETSTATE")  # debug trace: shows when the patched hook runs
    offset = hv.StoreOptions.id_offset()
    hv.Store.load_counter_offset = offset
    orig_setstate(self, d)


# Monkeypatch LabelledData: only __setstate__ is replaced (with new_setstate).
# The commented-out lines are the toggles for installing new_getstate and for
# restoring the original hooks.
# hv.core.dimension.LabelledData.__getstate__ = new_getstate
# hv.core.dimension.LabelledData.__getstate__ = orig_getstate
hv.core.dimension.LabelledData.__setstate__ = new_setstate
# hv.core.dimension.LabelledData.__setstate__ = orig_setstate

In [None]:
hv.Store.save_option_state = True

In [None]:
a = hv.Scatter(np.random.random(10)).options(
    color="yellow", size=10, backend="bokeh", clone=False
)

In [None]:
a

In [None]:
import pickle

In [None]:
hv.__file__

In [None]:
b = pickle.loads(pickle.dumps(a))
b

In [None]:
b.id = 3

In [None]:
b

In [None]:
b.__dict__

In [None]:
z = pickle.dumps(a)
hv.Store._custom_options["bokeh"] = {}
pickle.loads(z)

In [None]:
z = hv.Store.dumps(a)
hv.Store._custom_options["bokeh"] = {}
hv.Store.loads(z)

In [None]:
hv.Store._custom_options

In [None]:
import pickle

In [None]:
z = pickle.dumps(a)

In [None]:
hv.Store._custom_options["bokeh"] = {}

In [None]:
hv.Store.load_counter_offset = hv.StoreOptions.id_offset()
pickle.loads(z)

In [None]:
import common

In [None]:
reload(common)

In [None]:
hv.Dimension.__setstate__

In [None]:
def test():
    """Build a small styled Scatter; meant to be executed via ``client.run``.

    Imports the project module ``common`` first (kept for its import-time
    effect), then returns a Scatter of 10 random values with
    ``color="green"`` and ``size=5`` options applied.
    """
    import common  # noqa: F401 -- imported only for its side effects

    scatter = hv.Scatter(np.random.random(10))
    return scatter.options(color="green", size=5)


list(client.run(test).values())[0]  # ['tcp://10.120.17.16:34432']

In [None]:
import pickle

In [None]:
pickle.loads(list(client.run(test).values())[0])

In [None]:
list(client.run(test).values())[0]

In [None]:
hv.Store.load_counter_offset = hv.StoreOptions.id_offset()

In [None]:
hv.Store._custom_options

In [None]:
from dask.distributed import protocol  # as s#.typename(type(a))

In [None]:
protocol.serialize

In [None]:
type(a).__module__

In [None]:
a.print_param_defaults()

In [None]:
a.get_param_values()

In [None]:
?hv.Store.transfer_options

In [None]:
a  # was a dangling "a." left over from tab-completion, which is a SyntaxError

In [None]:
a

In [None]:
?a.options

In [None]:
b = a.options(color="red", size=10)
b

In [None]:
import pickle

In [None]:
hv.Store.loads(hv.Store.dumps(b))

In [None]:
client.close()

In [None]:
client = Client(cluster)

In [None]:
reload(trench_detection)

In [None]:
# Run trench detection on one frame (v=10, c=2) of the first ND2 file and
# collect its diagnostics into `diag`; the return value itself is discarded.
# NOTE: `get_frame` and `nd2_filenames` are only defined in cells further down
# (the "Loading data" section), so this cell fails on a fresh kernel unless
# those cells are run first.
diag = util.tree()
_ = trench_detection.get_trenches(
    partial(get_frame, v=10, c=2)(nd2_filenames[0]), diagnostics=diag
)

In [None]:
diag["label_1"]["trench_anchors"]["trench_anchor_profile"]

# Loading data

In [None]:
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2",
    "/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2",
]

In [None]:
def _read_nd2_sizes(path):
    """Return the ``sizes`` attribute of the ND2 file at ``path``.

    The reader is used as a context manager so that the file is closed again
    instead of leaving one open reader per file behind.
    """
    with nd2reader.ND2Reader(path) as reader:
        return reader.sizes


# Map every ND2 filename to the sizes reported by its reader.
nd2_sizes = {filename: _read_nd2_sizes(filename) for filename in nd2_filenames}

In [None]:
# For every ND2 file, submit the construction of an ND2Reader to the cluster;
# `nd2s` therefore maps filename -> future, not filename -> reader.
# The commented-out line below is the local (non-distributed) variant.
nd2s = {
    filename: client.submit(nd2reader.ND2Reader, filename, memmap=False)
    for filename in nd2_filenames
}
# nd2s = {filename: nd2reader.ND2Reader(filename, memmap=False) for filename in nd2_filenames}

In [None]:
# nd2_sizes = util.gather_futures(client, util.map_futures(partial(client.submit, lambda nd2: nd2.sizes), nd2s))

In [None]:
def do_reload():
    """Reload the project modules ``util`` and ``trench_detection``, in that order.

    Intended to be passed to ``client.run`` so that code edits are picked up
    without restarting the workers.
    """
    import importlib

    for module in (util, trench_detection):
        importlib.reload(module)


client.run(do_reload)

In [None]:
reload(trench_detection)

In [None]:
def get_frame(nd2_filename=None, v=None, c=None):
    """Load a single 2D frame from an ND2 file.

    Parameters
    ----------
    nd2_filename : str
        Path of the ND2 file to open.
    v, c :
        Passed through unchanged to ``get_frame_2D`` as its ``v`` and ``c``
        keyword arguments.

    Returns
    -------
    The frame returned by ``get_frame_2D(v=v, c=c, memmap=False)``.
    """
    # Use the reader as a context manager so the file is closed after the
    # frame is read; previously every call left an open reader behind.
    # NOTE(review): assumes the frame stays valid after the reader is closed,
    # which memmap=False suggests -- confirm.
    with nd2reader.ND2Reader(nd2_filename) as reader:
        return reader.get_frame_2D(v=v, c=c, memmap=False)

In [None]:
a = client.submit(
    compose(
        util.wrap_diagnostics(trench_detection.get_trenches),
        lambda x: x.get_frame_2D(v=0, c=2, memmap=False),
    ),
    nd2s[nd2_filenames[0]],
)

In [None]:
z = a.result()[1]

In [None]:
ui.display_plot_browser(z)

In [None]:
ui.display_plot_browser(a.result()[1])

In [None]:
# Submit trench detection for frame (v=0, c=2) of the first ND2 file, opening
# the reader on the worker from the file path. compose applies the lambda
# first and the wrapped get_trenches second.
client.submit(
    compose(
        util.wrap_diagnostics(trench_detection.get_trenches),
        # Only the `nd2reader` module is imported in this notebook, so the
        # class has to be reached through it; the bare name `ND2Reader` raised
        # NameError when the task ran.
        lambda x: nd2reader.ND2Reader(x).get_frame_2D(v=0, c=2, memmap=False),
    ),
    nd2_filenames[0],
)

In [None]:
_.result()

In [None]:
# Build trench_data[filename][v]: for each file and for v in {0, 1}, one future
# loads the frame via get_frame (channel index sizes["c"] - 1) and a second
# future runs the diagnostics-wrapped get_trenches on that frame.
# `nd2` is unpacked from util.zip_dicts but not used in this variant.
trench_data = {
    filename: {
        v: client.submit(
            util.wrap_diagnostics(trench_detection.get_trenches),
            client.submit(partial(get_frame, v=v, c=sizes["c"] - 1), filename),
        )
        for v in range(2)
    }  # sizes['v']//50)}
    for filename, nd2, sizes in util.zip_dicts(nd2s, nd2_sizes)
}

In [None]:
# Second variant of trench_data (rebinds the name from the previous cell):
# one task per v in range(sizes["v"] // 10), each reading the frame from the
# already-submitted reader future `nd2` and running get_trenches on it.
# NOTE(review): the lambda reads `v` and `sizes` from the enclosing
# comprehension scope rather than binding them as arguments. This is only
# correct if each task captures the values current at submit time -- verify,
# or bind them explicitly (e.g. with functools.partial).
trench_data = {
    filename: {
        v: client.submit(
            compose(
                util.wrap_diagnostics(trench_detection.get_trenches),
                lambda x: x.get_frame_2D(v=v, c=sizes["c"] - 1, memmap=False),
            ),
            nd2,
        )
        for v in range(sizes["v"] // 10)
    }
    for filename, nd2, sizes in util.zip_dicts(nd2s, nd2_sizes)
}

In [None]:
trench_data["/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2"][4]

In [None]:
client.cancel(trench_data)

In [None]:
progress(trench_data)

In [None]:
res = trench_data["/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2"][0].result()[1]

In [None]:
a = res["label_1"]["trench_anchors"]["trench_anchor_profile"]

In [None]:
a

In [None]:
a.Scatter.I  # was a dangling "a.Scatter.I." from tab-completion (SyntaxError)

In [None]:
a.Scatter.I.options(color="red", size=5)

In [None]:
ui.display_plot_browser(
    trench_data["/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2"][14].result()[1]
)

In [None]:
import traceback
from traceback import print_exception, print_tb

In [None]:
# NOTE(review): `t` is not defined anywhere in the visible notebook; it is
# presumably a traceback object left over in the kernel -- TODO define it
# explicitly or remove this cell.
traceback.extract_tb(t)

# Scratch

In [None]:
def load_nd2(filename):
    """Open an ND2 file and run trench detection on ``nd2[1, 30]``.

    Parameters
    ----------
    filename : str
        Path of the ND2 file to open.

    Returns
    -------
    The return value of ``trench_detection.get_trenches``.
    """
    nd2 = nd2reader.ND2Reader(filename)
    # The star imports are commented out in this notebook, so `tree` and
    # `get_trenches` are reached through their modules, as elsewhere in the file.
    diag = util.tree()
    # NOTE(review): the original passed `diag_pos[pos]`, neither of which is
    # defined in this notebook, while the freshly created `diag` went unused.
    # Passing `diag` is the presumed intent -- confirm.
    return trench_detection.get_trenches(nd2[1, 30], diagnostics=diag)

In [None]:
trench_positions = {}

# Old

In [None]:
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
def f(img_stack):
    """Apply ``segment_trench`` to every slice along the first axis of ``img_stack``.

    The per-slice results are stacked along a new first axis and wrapped in a
    zarr array compressed with ``DEFAULT_FRAME_COMPRESSOR``.
    """
    n_frames = img_stack.shape[0]
    segmented = []
    for t in range(n_frames):
        segmented.append(segment_trench(img_stack[t], diagnostics=None))
    stacked = np.stack(segmented, axis=0)
    return zarr.array(stacked, compressor=DEFAULT_FRAME_COMPRESSOR)


# Apply the `f` defined in this cell through positionwise_trenchwise_map to
# root_group["raw"], restricted to channel_slice=1, all time points and
# positions=range(1).
# NOTE(review): positionwise_trenchwise_map, root_group and trench_points_pos
# are not defined in the visible notebook -- leftover from an older version.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

In [None]:
def f(img_stack):
    """Return the 95th percentile of ``img_stack`` over axes 1 and 2 as a Series.

    One value is produced per index along the first axis. (A max-based variant
    was tried previously and is disabled.)
    """
    per_frame = np.percentile(img_stack, 95, axis=(1, 2))
    return pd.Series(per_frame)


# Apply the `f` defined in this cell through positionwise_trenchwise_map to
# root_group["raw"], restricted to channel_slice=2, all time points and
# positions=range(100).
# NOTE: this cell redefines `f`, shadowing the segmentation `f` defined in the
# previous cell; re-running the earlier map afterwards would use this version.
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(100),
)