# Imports

In [None]:
import traceback
import warnings
from functools import partial
from importlib import reload
from operator import getitem

import dask
import holoplot.pandas
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import streamz
import zarr
from bokeh.models.tools import HoverTool
from cytoolz import compose, get_in
from dask import delayed
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from holoviews.streams import Stream, param
from tqdm import tnrange, tqdm, tqdm_notebook

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
# from processing import *
# from trench_detection import *
# from trench_segmentation import *
# from trench_segmentation.watershed import *
# from util import *
# from ui import *
import common
import diagnostics
import metadata
import trench_detection
import ui
import util

In [None]:
# Notebook display / profiling setup.
# line_profiler extension (provides the %lprun magic).
%load_ext line_profiler
# Render HoloViews objects with the Bokeh backend.
hv.extension("bokeh")
%matplotlib inline
# Presumably disables tqdm's background monitor thread — confirm for the installed tqdm version.
tqdm.monitor_interval = 0

# Config

In [None]:
# One single-threaded worker process per SLURM job on the "short" queue.
# PYTHONPATH is exported in each job, presumably so workers can import the
# project's local modules (util, trench_detection, diagnostics, ...).
# Options tried previously and left disabled: job_extra=['-p transfer'],
# job_extra=['--cores-per-socket=8'], interface='ib0',
# diagnostics_port=('127.0.0.1', 8787).
cluster = SLURMCluster(
    queue="short",
    walltime="5:00:00",
    processes=1,
    threads=1,
    memory="64GB",
    local_directory="/tmp",
    env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"'],
)
client = Client(cluster)

In [None]:
# Cosmetic: widen one nested element of the cluster's widget.
# NOTE(review): this reaches through the private `_widget()` API by fixed child
# positions, so it breaks whenever dask_jobqueue changes its widget layout.
cluster._widget().children[1].children[1].children[0].children[0].layout.width = "200px"
# Display the cluster's widget repr.
cluster

In [None]:
# Manual teardown: stop the workers for every job in `cluster.jobs`.
# NOTE(review): do not run as part of Restart & Run All; it shuts down the workers the
# rest of the notebook needs.
cluster.stop_workers(cluster.jobs)

# Functions

# Debugging

# Loading data

In [None]:
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr.nd2', '/n/scratch2/jqs1/fidelity/all/180405_txnerr001.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2']#, '/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# nd2_filenames = ['/n/scratch2/jqs1/fidelity/all/TrErr002_Exp.nd2']
# Input ND2 files for this run. All of them live in one scratch directory.
nd2_filenames = [
    "/n/scratch2/jqs1/fidelity/all/" + basename
    for basename in (
        "180405_txnerr.nd2",
        "180405_txnerr001.nd2",
        "180405_txnerr002.nd2",
        "TrErr002_Exp.nd2",
    )
]
# nd2_filenames = ['/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC001.nd2', '/home/jqs1/scratch/fidelity/180518_triplegrowthcurve/PHASE_GC002.nd2']

In [None]:
# Open each ND2 file on a worker. Values are futures holding the remote
# ND2Reader objects, which later cells pass to client.submit as task arguments.
nd2s = {
    filename: client.submit(nd2reader.ND2Reader, filename, memmap=False)
    for filename in nd2_filenames
}

In [None]:
def _run_on_nd2s(func):
    """Apply ``func`` to every remote ND2Reader in ``nd2s`` and gather the results.

    The returned mapping is keyed like ``nd2s`` (by filename), as produced by
    ``util.gather_futures``.
    """
    return util.gather_futures(
        client, util.map_futures(partial(client.submit, func), nd2s)
    )


# Axis sizes as reported by nd2reader (e.g. the "t" axis used below).
nd2_sizes = _run_on_nd2s(lambda nd2: nd2.sizes)
# Metadata parsed by nd2reader itself; provides "channels".
nd2_parsed_metadata = _run_on_nd2s(lambda nd2: nd2.metadata)
# Metadata from the project's own parser; provides the raw "image_metadata" tree.
nd2_metadata = _run_on_nd2s(metadata.parse_nd2_metadata)

In [None]:
# Channel-name list per file, taken from the nd2reader-parsed metadata.
nd2_channels = dict(
    (filename, parsed_md["channels"])
    for filename, parsed_md in nd2_parsed_metadata.items()
)

In [None]:
def position_dataframe(d):
    """Tabulate stage positions from ND2 metadata records.

    Parameters
    ----------
    d : list of dict (or anything ``pd.DataFrame.from_dict`` accepts)
        Records with keys ``dPosName``, ``dPosX``, ``dPosY``, ``dPosZ`` and
        ``dPFSOffset``. Any other keys are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``position``, ``x``, ``y``, ``z``, ``pfs_offset``, in that order.

    Raises
    ------
    KeyError
        If any of the five expected fields is missing.
    """
    column_names = {
        "dPosName": "position",
        "dPosX": "x",
        "dPosY": "y",
        "dPosZ": "z",
        "dPFSOffset": "pfs_offset",
    }
    table = pd.DataFrame.from_dict(d).rename(columns=column_names)
    return table[["position", "x", "y", "z", "pfs_offset"]]


# Stack one position table per file. pd.concat over a dict prefixes each
# row with its filename key.
nd2_positions = pd.concat(
    {
        filename: position_dataframe(
            list(
                get_in(
                    [
                        "image_metadata",
                        "SLxExperiment",
                        "ppNextLevelEx",
                        "",
                        "uLoopPars",
                        "Points",
                        "",
                    ],
                    md,
                )
            )
        )
        for filename, md in nd2_metadata.items()
    }
)
# Index is now (filename, row number). Replace the row number with the
# position name so rows are addressed as (filename, position).
nd2_positions = nd2_positions.set_index("position", append=True)
nd2_positions.index = nd2_positions.index.droplevel(1)
nd2_positions.index.names = ["filename", "position"]

# Reload

In [None]:
# Ship the local diagnostics.py to every worker.
# NOTE(review): util.py and trench_detection.py are not uploaded here. Workers
# presumably import those from the PYTHONPATH set in the cluster config instead.
client.upload_file("diagnostics.py")

In [None]:
def do_reload():
    """Reload the project modules util, trench_detection and diagnostics (in that order).

    The imports are local so the function is self-contained when it is sent to
    and run on the dask workers.
    """
    from importlib import reload

    import diagnostics
    import trench_detection
    import util

    for module in (util, trench_detection, diagnostics):
        reload(module)


# Run the reload on every worker.
client.run(do_reload)

In [None]:
# Reload util in the local (client) kernel after editing it.
reload(util)

In [None]:
# Reload diagnostics in the local (client) kernel after editing it.
reload(diagnostics)

In [None]:
# Reload trench_detection in the local (client) kernel after editing it.
reload(trench_detection)

# Finding trenches

In [None]:
def _read_frame(nd2, t, v, c):
    """Read one 2D frame (time ``t``, position index ``v``, channel index ``c``).

    Takes every coordinate as an explicit argument, so tasks do not depend on
    lambda closures over comprehension variables.
    """
    return nd2.get_frame_2D(t=t, v=v, c=c, memmap=False)


# Submit trench detection for each file:
#   - at most the first 10 positions, bounded by the file's position count so
#     `.iloc[v]` cannot go out of range. Keys are the position names.
#   - the MCHERRY channel.
#   - at most the first 50 time points.
# Values are futures for trench_detection.get_trenches_diagnostics results.
trench_data = {
    filename: {
        nd2_positions.loc[filename].iloc[v].name: {
            channel: {
                t: client.submit(
                    trench_detection.get_trenches_diagnostics,
                    client.submit(_read_frame, nd2, t, v, channels.index(channel)),
                )
                for t in range(min(sizes["t"], 50))
            }
            for channel in ("MCHERRY",)
        }
        for v in range(min(len(nd2_positions.loc[filename]), 10))
    }
    # The metadata value is unused here. It is named `_md` so it does not
    # shadow the `metadata` module.
    for filename, nd2, sizes, _md, channels in util.zip_dicts(
        nd2s, nd2_sizes, nd2_metadata, nd2_channels
    )
}

In [None]:
# Manual abort: cancel every trench-detection future submitted above.
# NOTE(review): under Restart & Run All this cancels the work just submitted.
client.cancel(trench_data)

In [None]:
# Progress bar for the trench-detection futures.
progress(trench_data)

In [None]:
# On the workers, convert each trench-detection diagnostics result into a
# DataFrame, presumably preserving trench_data's nested dict layout.
trench_rows = util.map_futures(
    partial(client.submit, diagnostics.wrapped_diagnostics_to_dataframe), trench_data
)
# Previous pipeline (before wrapped_diagnostics_to_dataframe), kept for reference:
# trench_rows = util.map_futures(partial(client.submit,
#                                       compose(util.expand_diagnostics_by_label,
#                                               util.diagnostics_to_dataframe,
#                                               partial(util.getitem_r, 1))),
#                               trench_data)

In [None]:
# Manual abort: cancel the DataFrame-conversion futures.
# NOTE(review): under Restart & Run All this cancels the work just submitted.
client.cancel(trench_rows)

In [None]:
# Progress bar for the DataFrame-conversion futures.
progress(trench_rows)

In [None]:
# NOTE(review): duplicate of the cancel cell above. Manual abort only; under
# Restart & Run All it cancels trench_rows before they are gathered.
client.cancel(trench_rows)

In [None]:
# Pull the per-frame DataFrames back to the client.
trench_rows_combined = util.gather_futures(client, trench_rows)

In [None]:
# Concatenate the gathered per-frame DataFrames into one frame. Presumably
# util.map_collections applies pd.concat over the nested dict levels (up to
# depth 4), turning the dict keys into index levels. Confirm in util.
trench_df = util.map_collections(
    partial(pd.concat, axis=0), trench_rows_combined, max_level=4
)
# Drop the second-to-last index level and label the five that remain.
trench_df.index = trench_df.index.droplevel(-2).set_names(
    ["filename", "position", "channel", "t", "trench_set"]
)

In [None]:
# Inspect the combined per-trench-set table.
trench_df

In [None]:
# Mark rows whose detected trench period is NOT strictly between 23 and 25.
# The test is written as "not inside the window", so rows with a NaN period
# are also marked bad.
bad_periods = ~(
    trench_df["trench_anchors.periodogram_1.period"].pipe(
        lambda period: (period > 23) & (period < 25)
    )
)

In [None]:
# FROM: https://stackoverflow.com/questions/23937433/efficiently-joining-two-dataframes-based-on-multiple-levels-of-a-multiindex?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
def multi_join(left, right):
    """Inner-join two MultiIndexed frames on the levels named in ``right``'s index.

    ``right.index.names`` must all appear among ``left``'s index levels (or
    columns). The result is re-indexed by ``left.index.names``.
    """
    shared_levels = right.index.names
    merged = pd.merge(
        left.reset_index(), right.reset_index(), on=shared_levels, how="inner"
    )
    return merged.set_index(left.index.names)

In [None]:
# Attach the stage coordinates (x, y, z, pfs_offset) from nd2_positions to each
# trench row, matching on (filename, position). Rows without a match are dropped.
trench_df2 = multi_join(trench_df, nd2_positions)

In [None]:
# Rows whose period falls outside (23, 25), with their stage coordinates.
trench_df2[bad_periods]

In [None]:
# For rows with a plausible period: detected period vs. stage x position,
# one scatter per (filename, channel).
trench_df2[~bad_periods].reset_index().holoplot(
    y="trench_anchors.periodogram_1.period",
    x="x",
    by=["filename", "channel"],
    kind="scatter",
)

In [None]:
# For rows with a plausible period: detected period vs. time point,
# one scatter per (filename, channel, position), without a legend.
trench_df2[~bad_periods].reset_index().holoplot(
    y="trench_anchors.periodogram_1.period",
    x="t",
    by=["filename", "channel", "position"],
    kind="scatter",
    legend=False,
)

In [None]:
# Peek at the index layout (filename, position, channel, t, trench_set).
# Fixed: `df` is not defined anywhere in this notebook; the table is trench_df.
trench_df.head().reset_index()

In [None]:
# Index tuples (filename, position, channel, t, trench_set) of rows with a bad period.
# Fixed: use trench_df, the frame bad_periods was computed from. `df` is undefined.
idxs = trench_df[bad_periods].index

In [None]:
# Number of rows flagged as bad.
len(idxs)

In [None]:
def _frame_kwargs(idx):
    """Translate a trench_df index tuple into ``get_frame_2D`` keyword arguments.

    trench_df is indexed by (filename, position, channel, t, trench_set), and
    ``position`` is a position *name*. trench_data keyed positions as
    ``nd2_positions.loc[filename].iloc[v].name``, so ``v`` is recovered with
    ``get_loc`` (this assumes position names are unique within a file). The channel
    index comes from nd2_channels, the same source the trench_data cell uses.
    """
    filename, position, channel, t, _trench_set = idx
    return dict(
        v=nd2_positions.loc[filename].index.get_loc(position),
        t=t,
        c=nd2_channels[filename].index(channel),
    )


# Fetch the MCHERRY frame behind every bad-period row, on the workers.
# Fixed: the previous version passed the position name as `v` and the channel
# name as `t`, and looked up "channels" in nd2_metadata instead of nd2_channels.
frame_futures = [
    client.submit(nd2reader.ND2Reader.get_frame_2D, nd2s[idx[0]], **_frame_kwargs(idx))
    for idx in idxs
]

In [None]:
# Progress bar for the frame-fetch futures.
progress(frame_futures)

In [None]:
# Bring the fetched frames back to the client (same order as idxs).
# NOTE(review): the next cell recomputes `frames` locally and overwrites this result.
frames = client.gather(frame_futures)

In [None]:
# Local (non-dask) alternative to the cells above: read each bad-period frame
# directly, in the same order as idxs. Each file is closed after its frame is read.
# Fixed: map the trench_df index (filename, position, channel, t, trench_set)
# correctly. Previously the position name was passed as `v`, the channel name as
# `t`, and channels were looked up in nd2_metadata instead of nd2_channels.
frames = []
for idx in idxs:
    filename, position, channel, t, _trench_set = idx
    with nd2reader.ND2Reader(filename) as nd2:
        frames.append(
            nd2.get_frame_2D(
                # Assumes position names are unique within a file.
                v=nd2_positions.loc[filename].index.get_loc(position),
                t=t,
                c=nd2_channels[filename].index(channel),
            )
        )

In [None]:
# Browse the bad-period frames, keyed by the string form of their trench_df index.
hv.HoloMap({str(idx): ui.RevImage(frame) for idx, frame in zip(idxs, frames)})

# Scratch

In [None]:
# Scratch: open one file locally for metadata inspection.
# NOTE(review): this reader is never closed.
n = nd2reader.ND2Reader("/n/scratch2/jqs1/fidelity/all/180405_txnerr002.nd2")

In [None]:
import metadata

In [None]:
# Parse the scratch reader's metadata with the project parser.
m = metadata.parse_nd2_metadata(n)

In [None]:
# Top-level keys of the parsed metadata.
m.keys()

# Old (stale: these cells use names such as `get_trenches` and `root_group` that are no longer defined)

In [None]:
# NOTE(review): stale. get_trenches, root_group, pos and diag_pos are not defined
# in this notebook (they presumably came from the star imports commented out at
# the top), so this cell fails on a fresh kernel.
get_trenches(root_group["raw"][str(pos)][1, 30], diagnostics=diag_pos[pos])

In [None]:
# NOTE(review): stale. tree, get_trenches, root_group and pos are not defined in
# this notebook, so this cell fails on a fresh kernel.
diag = tree()
_ = get_trenches(root_group["raw"][str(pos)][0, 1], diagnostics=diag)

In [None]:
# NOTE(review): stale cell. segment_trench, DEFAULT_FRAME_COMPRESSOR,
# positionwise_trenchwise_map, root_group and trench_points_pos are not defined
# in this notebook, so this fails on a fresh kernel. The next cell also redefines `f`.
def f(img_stack):
    """Segment each img_stack[t] along axis 0 and return the masks as a zarr array."""
    ary = np.stack(
        [
            segment_trench(img_stack[t], diagnostics=None)
            for t in range(img_stack.shape[0])
        ],
        axis=0,
    )
    # Store the stacked masks with the project's frame compressor.
    ary = zarr.array(ary, compressor=DEFAULT_FRAME_COMPRESSOR)
    return ary


# Presumably applies `f` per trench for position 0 only, channel 1, all time points.
trench_seg_masks = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=1,
    preload=True,
    time_slice=slice(None),
    positions=range(1),
)

In [None]:
# NOTE(review): stale cell. positionwise_trenchwise_map, root_group and
# trench_points_pos are not defined in this notebook. This also redefines the
# `f` from the cell above.
def f(img_stack):
    """Return a Series with the 95th percentile of each img_stack[i] over axes 1 and 2."""
    return pd.Series(np.percentile(img_stack, 95, axis=(1, 2)))
    # return pd.Series(np.max(img_stack, axis=(1,2)))


# Presumably applies `f` per trench for positions 0-99, channel 2, all time points.
trench_traces_all = positionwise_trenchwise_map(
    root_group["raw"],
    trench_points_pos,
    f,
    channel_slice=2,
    preload=True,
    time_slice=slice(None),
    positions=range(100),
)