In [None]:
import numpy as np
import nd2reader
import dask
import dask.array as da
from dask import delayed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import zarr
import glob
import os
from cytoolz import compose

In [None]:
import segmentation
import matriarch_stub
from matriarch_stub import recursive_sequence_map

In [None]:
def nd2_to_dask(filename, channel, rechunk=True):
    """Expose one channel of an ND2 file as a lazily-loaded dask array.

    Every (position, time) frame is wrapped in its own ``dask.delayed`` call to
    ``segmentation.get_nd2_frame``. Only the first frame is read eagerly; the
    rest are read when the array is computed.

    Parameters
    ----------
    filename : str
        Path to the ND2 file.
    channel : int
        Channel index, forwarded unchanged to ``segmentation.get_nd2_frame``.
    rechunk : bool, optional
        When True (default), let dask pick chunk sizes along axes 0 and 1.
        When False, keep the chunks produced by stacking the per-frame arrays.

    Returns
    -------
    dask.array.Array
        Presumably shaped (positions, timepoints, *frame_shape). Positions come
        from ``sizes["v"]`` and timepoints from ``sizes["t"]``, but this depends
        on how ``recursive_sequence_map`` stacks the nested lists -- TODO confirm.
    """
    reader = segmentation.get_nd2_reader(filename)
    # Read the first frame eagerly: its dtype and shape are declared for every lazy frame.
    sample = segmentation.get_nd2_frame(filename, 0, channel, 0)
    # pure=True gives deterministic task keys for identical (file, position, channel, t).
    load_frame = delayed(segmentation.get_nd2_frame, pure=True)

    def as_lazy_frame(position, t):
        """Wrap one delayed frame read as a single-chunk dask array."""
        return da.from_delayed(
            load_frame(filename, position, channel, t),
            dtype=sample.dtype,
            shape=sample.shape,
        )

    nested_frames = [
        [as_lazy_frame(position, t) for t in range(reader.sizes["t"])]
        for position in range(reader.sizes["v"])
    ]
    stacked = recursive_sequence_map(da.stack, nested_frames, max_level=1)
    if not rechunk:
        return stacked
    return stacked.rechunk({0: "auto", 1: "auto"})

# Cluster

In [None]:
# Let the scheduler tolerate up to 20 worker deaths while a task is running
# before it marks that task as failed. dask.config.set is the supported way to
# change configuration; called outside a `with` block, it applies globally.
dask.config.set({"distributed.scheduler.allowed-failures": 20})
# Optional worker memory thresholds (fractions of the worker memory limit),
# currently disabled:
# dask.config.set({
#     "distributed.worker.memory.target": 0.4,
#     "distributed.worker.memory.spill": 0.5,
#     "distributed.worker.memory.pause": 0.9,
#     "distributed.worker.memory.terminate": 0.95,
# })

In [None]:
# One single-core, 4 GB SLURM job per Dask worker on the "short" partition.
# Options tried previously and currently disabled:
#   job_extra=["--exclude=compute-e-16-181,compute-e-16-186"]  (skip specific nodes)
#   diagnostics_port=("127.0.0.1", 8787)
#   env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"']
cluster = SLURMCluster(
    cores=1,
    processes=1,
    memory="4GB",
    queue="short",
    walltime="03:00:00",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
)
client = Client(address=cluster)

In [None]:
# Display the cluster object (its rich repr in Jupyter) to inspect its current state.
cluster

In [None]:
# Number of Dask workers to request. Each SLURM job runs a single worker
# process (processes=1, cores=1 in the cluster settings).
N_WORKERS = 120
cluster.scale(N_WORKERS)

# Run: segment every ND2 file and write raw + segmented zarr arrays

In [None]:
# TODO: Dask workers keep restarting because of memory pressure; build a minimal reproducible example and file a GitHub issue.

In [None]:
# Segmentation functions applied per frame. Non-"BF" channels use `segment`
# directly. `segment_phase` is used for the "BF" channel: cytoolz.compose
# applies right-to-left, so each frame is inverted and then segmented.
segment = segmentation.segment
segment_phase = compose(segment, segmentation.invert)

In [None]:
# Raw ND2 acquisitions to process. glob returns files in arbitrary (filesystem)
# order, so sort them for a reproducible task order across runs.
RAW_DATA_DIR = "/n/scratch2/jqs1/fidelity/190401"
filenames = sorted(glob.glob(os.path.join(RAW_DATA_DIR, "*.nd2")))

In [None]:
# Build one lazy write job per (file, channel). Each job stores two zarr arrays:
# the raw frames and their per-frame segmentation.
OUTPUT_DIR = "/n/groups/paulsson/jqs1/molecule-counting/190410segmentation"


def _zarr_output_path(nd2_filename, channel, kind):
    """Return the zarr path for one channel/kind ("raw"/"segmented") of an ND2 file.

    The file stem is taken with ``os.path.splitext``, so only the trailing
    extension is stripped. ``str.replace(".nd2", "")`` would also remove a
    ".nd2" occurring elsewhere in the name.
    """
    stem = os.path.splitext(os.path.basename(nd2_filename))[0]
    return os.path.join(OUTPUT_DIR, "{}-{}-{}.zarr".format(stem, channel, kind))


tasks = []
for filename in filenames:
    channels = segmentation.get_nd2_reader(filename).metadata["channels"]
    for channel_idx, channel in enumerate(channels):
        # No rechunk: keep the per-frame chunks, since apply_gufunc needs each
        # (i, j) core frame to live in a single chunk.
        frames = nd2_to_dask(filename, channel_idx, rechunk=False)
        # "BF" frames are inverted before segmentation; other channels are segmented as-is.
        segment_func = segment_phase if channel == "BF" else segment
        # Apply the 2-D segmentation independently to every frame over the leading axes.
        segmented_frames = da.apply_gufunc(
            segment_func,
            "(i,j)->(i,j)",
            frames,
            output_dtypes=np.uint16,
            vectorize=True,
        )
        # compute=False: to_zarr returns a lazy write that is submitted to the cluster later.
        tasks.append(
            [
                ary.to_zarr(
                    _zarr_output_path(filename, channel, kind),
                    compute=False,
                    overwrite=True,
                    compressor=matriarch_stub.DEFAULT_COMPRESSOR,
                    order=matriarch_stub.DEFAULT_ORDER,
                )
                for kind, ary in (("raw", frames), ("segmented", segmented_frames))
            ]
        )

In [None]:
# Submit every (file, channel) write job to the cluster; keep the futures so the
# work stays alive and can be gathered later.
futures = []
for task in tasks:
    futures.append(client.compute(task))

In [None]:
# Show a progress bar for the submitted write jobs. Do not `del futures` here:
# releasing the only references lets the scheduler cancel the pending work, and
# the gather cell that follows would raise NameError on a fresh Run All.
progress(futures)

In [None]:
# Block until every write job finishes. errors="raise" (the default) re-raises
# the first task failure here instead of returning partial results.
client.gather(futures, errors="raise")

In [None]:
# Restart all Dask workers. Per the distributed Client.restart docs, this
# discards active work and any data held on the cluster.
client.restart()