In [None]:
import numpy as np
import pandas as pd
import nd2reader
import matplotlib.pyplot as plt
import holoviews as hv
from holoviews.operation.datashader import regrid
import skimage.filters
import skimage.feature
import scipy.ndimage
import peakutils
from tqdm import tnrange, tqdm_notebook
import dask
import dask.array as da
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from cytoolz import partial, compose, juxt
from itertools import repeat
from glob import glob
import cachetools
import numpy_indexed
import pickle
import pyarrow as pa
import warnings
import os
from numbers import Integral
from dask.delayed import Delayed

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# from segmentation import *
# from util import *
# from matriarch_stub import *
import segmentation
import matriarch_stub

In [None]:
# Plot defaults: 8x8-inch matplotlib figures; HoloViews renders with the Bokeh backend.
plt.rcParams["figure.figsize"] = (8, 8)
hv.extension("bokeh")

# Run: start the Dask cluster

In [None]:
# Optional (disabled) overrides for scheduler retries and worker memory thresholds.
# dask.config.config['distributed']['scheduler']['allowed-failures'] = 20
# dask.config.config['distributed']['worker']['memory'] = {'target': 0.4,
#                                                         'spill': 0.5,
#                                                         'pause': 0.9,
#                                                         'terminate': 0.95}

In [None]:
# Dask workers submitted as SLURM jobs: each job gets cores=1, processes=1,
# 4GB memory and a 3h walltime on the "short" queue; worker logs are written
# to log_directory (hard-coded user path).
cluster = SLURMCluster(
    queue="short",
    walltime="03:00:00",
    memory="4GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
    cores=1,
    processes=1,
)
# Leftover keyword arguments from an earlier version of the call above (disabled):
# diagnostics_port=('127.0.0.1', 8787),
# env_extra=['export PYTHONPATH=\"/home/jqs1/projects/matriarch\"'])
client = Client(cluster)

In [None]:
# Display the cluster object (Jupyter renders its summary).
cluster

In [None]:
# NOTE(review): requests a single worker; increase for the full run.
cluster.scale(1)

# Test

In [None]:
# Planning notes for the aggregation rewrite; this cell contains no executable code.
# TODO
# 0) make benchmarking notebook
# 1) benchmark manual reduceat vs. npi reduction vs. npi non-reduction vs np.unique/bincount for single frame, image stack
# 2) numba gufunc to make it work on numpy image stacks
# 3) dask gufunc to make it work on dask arrays
# 4) dry run without regionprops
# 5) arbitrary sequence of traces using initial segmentation (replace sandwich)
# 6) named_funcs_as_juxt: decorator to turn {'func1': func1, ('q0.5', 'q0.7'): partial(np.percentile, q=(0.5,0.7))} into a multiple-valued func
# 7) BENCHMARK: try readahead buffering/chunk size
# 8) fix regionprops memory usage
# 9) zarrification of labels

# better group_by, only argsort once per labels img, handle multiple funcs as a juxt, also funcs with multiple return values; uses chunks
# use optimized mean? BENCHMARK
# convert dask arrays to delayed before calling short_circuit_none (otherwise we wait until all frames are in RAM)
# don't process FOV if too many labels
# readahead buffering/benchmark buffer size vs chunk size; also VS non-dask array
# dask array correction
# sandwich -> arbitrary sequence of traces (using same segmentation)
# regionprops (use measure func??)
# TODO: if we zarrify labels, we need to turn back into ndarray before map_over_labels

In [None]:
# Load one test ND2 file as an array stack. The arguments 0 and "GFP-PENTA" are
# passed through to segmentation.nd2_to_dask (presumably position index and
# channel name -- confirm against that helper).
ary = segmentation.nd2_to_dask(
    "/n/scratch2/jqs1/190922/190922_photobleaching_greens/GFP_photobleaching_100pct_100ms_0001.nd2",
    0,
    "GFP-PENTA",
)

In [None]:
ary

In [None]:
# Materialize the first frame and segment it; `labels` is reused by the
# aggregation experiments in the following cells.
seg_img = ary[0].compute()
labels = segmentation.segment(seg_img)

In [None]:
def aggregate(func, labels, img_stack, skip0=True):
    """Reduce each labeled region of every frame in an image stack.

    Pixels are grouped by their value in ``labels`` (one stable argsort of the
    flattened label image). ``func`` is then called on each group with
    ``axis=1``, so it reduces over that label's pixels separately for every
    frame.

    Parameters
    ----------
    func : callable
        Reduction called as ``func(values, axis=1)``, where ``values`` has
        shape ``(n_frames, n_pixels_in_label)``; e.g. ``np.mean``.
    labels : ndarray
        Label image with exactly one element per pixel of a frame of
        ``img_stack``, flattened in the same (C) order as each frame.
    img_stack : array-like
        Stack indexed as ``img_stack[frame, ...]``.
    skip0 : bool, default True
        If true, leave label 0 out of the result.

    Returns
    -------
    dict
        Maps each label value present in ``labels`` to ``func``'s result for
        that label. Returns an empty dict when ``labels`` is empty.

    Raises
    ------
    ValueError
        If ``labels.size`` differs from the number of pixels per frame.
    """
    keys = labels.ravel()
    n_pixels = int(np.prod(img_stack.shape[1:]))
    if keys.size != n_pixels:
        # Without this check a smaller label image silently selects a subset
        # of pixels and a larger one fails with an IndexError further down.
        raise ValueError(
            f"labels has {keys.size} elements but each frame of img_stack has "
            f"{n_pixels} pixels"
        )
    if keys.size == 0:
        # `sorted_[slices[:-1]]` below would index into an empty array.
        return {}
    # Stable sort keeps each label's pixels in their original raster order.
    sorter = np.argsort(keys, kind="mergesort")
    sorted_ = keys[sorter]
    # Group boundaries are the positions where the sorted label value changes.
    flag = sorted_[:-1] != sorted_[1:]
    slices = np.concatenate(([0], np.flatnonzero(flag) + 1, [keys.size]))
    unique = sorted_[slices[:-1]]
    values = img_stack.reshape((img_stack.shape[0], -1))[:, sorter]
    groups = np.split(values, slices[1:-1], axis=1)
    return {
        key: func(group, axis=1)
        for key, group in zip(unique, groups)
        if key != 0 or not skip0
    }

In [None]:
def aggregate_dask(func, labels, img_stack, skip0=False):
    """Reduce each labeled region of every frame, lazily for dask input.

    Pixels are grouped by their value in ``labels``, and ``func(values,
    axis=1)`` is applied to each group. Here ``values`` has shape
    ``(n_frames, n_pixels_in_label)``.

    If ``img_stack`` is a dask array, the reduction runs inside
    ``map_blocks`` and each returned value is a lazy 1-D dask array with one
    entry per frame. Otherwise ``func`` runs eagerly on the NumPy data.

    Parameters
    ----------
    func : callable
        Reduction accepting an ``axis`` keyword (e.g. ``np.mean``); it must
        return one value per frame (row).
    labels : ndarray
        Label image with exactly one element per pixel of a frame.
    img_stack : ndarray or dask.array.Array
        Stack indexed as ``img_stack[frame, ...]``.
    skip0 : bool, default False
        If true, leave label 0 out of the result. (``aggregate`` defaults to
        True.)

    Returns
    -------
    dict
        Maps each label value to the per-frame reduction for that label.
        Returns an empty dict when ``labels`` is empty.

    Raises
    ------
    ValueError
        If ``labels.size`` differs from the number of pixels per frame.
    """
    keys = labels.ravel()
    n_pixels = int(np.prod(img_stack.shape[1:]))
    if keys.size != n_pixels:
        raise ValueError(
            f"labels has {keys.size} elements but each frame of img_stack has "
            f"{n_pixels} pixels"
        )
    if keys.size == 0:
        return {}
    # Stable sort keeps each label's pixels in their original raster order.
    sorter = np.argsort(keys, kind="mergesort")
    sorted_ = keys[sorter]
    flag = sorted_[:-1] != sorted_[1:]
    slices = np.concatenate(([0], np.flatnonzero(flag) + 1, [keys.size]))
    unique = sorted_[slices[:-1]]
    boundaries = slices[1:-1]
    values = img_stack.reshape((img_stack.shape[0], -1))[:, sorter]

    def reduce_block(block):
        # (frames_in_chunk, n_pixels) -> (frames_in_chunk, n_labels)
        groups = np.split(block, boundaries, axis=1)
        # TODO: why is this (commented) so much slower??
        # reductions = [np.mean(x, axis=1) for x in groups]
        # return np.array(reductions).T
        return np.hstack([func(group, axis=1)[:, np.newaxis] for group in groups])

    try:
        import dask.array as _da
    except ImportError:  # dask not installed, so img_stack cannot be a dask array
        is_dask = False
    else:
        is_dask = isinstance(img_stack, _da.Array)

    if is_dask:
        reduced = values.map_blocks(
            reduce_block,
            drop_axis=1,
            new_axis=1,
            chunks=(values.chunks[0], unique.shape[0]),
        )
        # Column i holds label unique[i]. Zipping over `reduced` itself would
        # iterate its rows (frames) and pair labels with the wrong data.
        groups = [reduced[:, i] for i in range(unique.shape[0])]
    else:
        groups = [func(group, axis=1) for group in np.split(values, boundaries, axis=1)]
    return {key: group for key, group in zip(unique, groups) if key != 0 or not skip0}

In [None]:
def multiaggregate_dask(func, labels, img_stack, skip0=False):
    """Per-label reduction where ``func`` may return several statistics.

    Unlike ``aggregate``, ``func`` is called as ``func(values)`` with no
    ``axis`` keyword. ``values`` is a ``(n_frames, n_pixels_in_label)`` array.
    The dask path declares the output chunks with frames on the last axis,
    so ``func`` is assumed to return an array whose last axis is frames,
    e.g. ``(n_stats, n_frames)``.

    Parameters
    ----------
    func : callable
        Multi-statistic reduction as described above.
    labels : ndarray
        Label image with exactly one element per pixel of a frame.
    img_stack : ndarray or dask.array.Array
        Stack indexed as ``img_stack[frame, ...]``.
    skip0 : bool, default False
        If true, leave label 0 out of the result.

    Returns
    -------
    dict
        Maps each label value to ``func``'s result for that label. For dask
        input the values are lazy slices of a single ``map_blocks`` result.
        Returns an empty dict when ``labels`` is empty.

    Raises
    ------
    ValueError
        If ``labels.size`` differs from the number of pixels per frame.
    """
    keys = labels.ravel()
    n_pixels = int(np.prod(img_stack.shape[1:]))
    if keys.size != n_pixels:
        raise ValueError(
            f"labels has {keys.size} elements but each frame of img_stack has "
            f"{n_pixels} pixels"
        )
    if keys.size == 0:
        return {}
    # Stable sort keeps each label's pixels in their original raster order.
    sorter = np.argsort(keys, kind="mergesort")
    sorted_ = keys[sorter]
    flag = sorted_[:-1] != sorted_[1:]
    slices = np.concatenate(([0], np.flatnonzero(flag) + 1, [keys.size]))
    unique = sorted_[slices[:-1]]
    boundaries = slices[1:-1]
    values = img_stack.reshape((img_stack.shape[0], -1))[:, sorter]

    try:
        import dask.array as _da
    except ImportError:  # dask not installed, so img_stack cannot be a dask array
        is_dask = False
    else:
        is_dask = isinstance(img_stack, _da.Array)

    if not is_dask:
        groups = [func(group) for group in np.split(values, boundaries, axis=1)]
    else:
        # Probe func on a dummy 1-frame x 1-pixel input to learn the shape of
        # its non-frame output axes, which map_blocks needs up front. The
        # probe is only needed (and only run) on this lazy path.
        probe = np.asarray(func(np.ones((1,) * values.ndim)))

        def reduce_block(block):
            # (frames_in_chunk, n_pixels) -> (n_labels, *stat_axes, frames_in_chunk)
            groups = np.split(block, boundaries, axis=1)
            return np.stack([func(group) for group in groups], axis=0)

        chunks = (unique.shape[0], *probe.shape[:-1], values.chunks[0])
        # The pixel axis is dropped; all new axes go in front of the frame axis.
        new_axis = tuple(range(len(chunks) - 1))
        groups = values.map_blocks(
            reduce_block, drop_axis=1, new_axis=new_axis, chunks=chunks
        )
    # Axis 0 of the dask result is the label axis, so zipping pairs correctly.
    return {key: group for key, group in zip(unique, groups) if key != 0 or not skip0}

In [None]:
def g(x):
    """Return the row-wise mean, sum and median of ``x`` stacked as rows.

    For a 2-D ``x`` the result has shape ``(3, n_rows)``.
    """
    row_mean = x.mean(axis=1)
    row_sum = x.sum(axis=1)
    row_median = np.median(x, axis=1)
    return np.array([row_mean, row_sum, row_median])


# Multi-statistic aggregation over the first 10 frames. `z` maps each label
# value (including 0, since skip0 defaults to False) to g's stacked
# mean/sum/median result. NOTE: `z` is reassigned to a different array in a
# later cell.
z = multiaggregate_dask(g, labels, ary[:10])
z[0].shape

In [None]:
z[0].compute()

In [None]:
_.shape

In [None]:
# Repeat the sort-based grouping steps from `aggregate` at top level so the
# intermediate arrays can be inspected interactively.
keys = labels.ravel()
sorter = np.argsort(keys, kind="mergesort")
sorted_ = keys[sorter]
flag = sorted_[:-1] != sorted_[1:]
slices = np.concatenate(([0], np.flatnonzero(flag) + 1, [keys.size]))
unique = sorted_[slices[:-1]]

In [None]:
# Every frame's pixels reordered so each label's pixels are contiguous columns.
values = ary.reshape((ary.shape[0], -1))[:, sorter]

In [None]:
# First 100 frames; left uncomputed (the `.compute()` is commented out).
v = values[:100]  # .compute()

In [None]:
v.shape

In [None]:
# `f` is not defined until the next cell, so calling `f(v).shape` here raised
# NameError on a fresh kernel; the next cell already ends with `f(v).shape`.

In [None]:
def f(x, bounds=None):
    """Placeholder multi-statistic reduction for the map_blocks experiment.

    Splits ``x`` into column groups (one per label, given label-sorted
    columns). For each group it returns the per-frame mean repeated three
    times, standing in for three different statistics.

    Parameters
    ----------
    x : ndarray
        2-D block of pixel values, frames x label-sorted pixels.
    bounds : sequence of int, optional
        Column indices at which to split ``x``. Defaults to the
        notebook-level ``slices[1:-1]`` group boundaries.

    Returns
    -------
    ndarray
        Shape ``(3, n_frames, n_groups)``.
    """
    if bounds is None:
        bounds = slices[1:-1]
    groups = np.split(x, bounds, axis=1)
    # TODO: why is this so much slower??
    # reductions = [np.mean(x, axis=1) for x in groups]
    # return np.array(reductions).T
    reductions = []
    for group in groups:
        # Compute the mean once; the three rows are identical placeholders.
        mean = np.mean(group, axis=1)
        reductions.append(np.array([mean, mean, mean]))
    # Stack groups along a new last axis -> (3, n_frames, n_groups).
    return np.stack(reductions, axis=-1)


# Disabled map_blocks call: its 2-entry chunks do not match f's 3-D output
# (the next cell uses the corrected 3-entry chunk spec).
# z = v.map_blocks(f, drop_axis=1, new_axis=(0,2), chunks=(v.chunks[0], unique.shape[0]))
# z
# Call f directly on the 100-frame slice `v` to check its output shape.
f(v).shape

In [None]:
def f(x, bounds=None):
    """Placeholder multi-statistic reduction used with ``map_blocks`` below.

    NOTE(review): this re-defines ``f`` from the previous cell.

    Splits ``x`` into column groups (one per label, given label-sorted
    columns). For each group it returns the per-frame mean repeated three
    times, standing in for three different statistics.

    Parameters
    ----------
    x : ndarray
        2-D block of pixel values, frames x label-sorted pixels.
    bounds : sequence of int, optional
        Column indices at which to split ``x``. Defaults to the
        notebook-level ``slices[1:-1]`` group boundaries.

    Returns
    -------
    ndarray
        Shape ``(3, n_frames, n_groups)``.
    """
    if bounds is None:
        bounds = slices[1:-1]
    groups = np.split(x, bounds, axis=1)
    # TODO: why is this so much slower??
    # reductions = [np.mean(x, axis=1) for x in groups]
    # return np.array(reductions).T
    reductions = []
    for group in groups:
        # Compute the mean once; the three rows are identical placeholders.
        mean = np.mean(group, axis=1)
        reductions.append(np.array([mean, mean, mean]))
    # Stack groups along a new last axis -> (3, n_frames, n_groups).
    return np.stack(reductions, axis=-1)


# Apply f to each frame chunk of `v`. The pixel axis (1) is dropped and new
# axes are inserted at 0 and 2, so the chunks declare a
# (3, n_frames, n_unique_labels) result.
# NOTE: this rebinds `z`, which previously held the multiaggregate_dask dict.
z = v.map_blocks(
    f, drop_axis=1, new_axis=(0, 2), chunks=(3, v.chunks[0], unique.shape[0])
)
z
# f(v).shape

In [None]:
zz = z.compute()

In [None]:
zz.shape

In [None]:
zz[0]

In [None]:
# zz[0] is the first of f's three (identical) statistics: the per-frame mean of
# each label group, shape (n_frames, n_labels) per the map_blocks chunks.
# matplotlib draws one line per column, i.e. one trace per label.
fig, ax = plt.subplots()
ax.plot(np.log(zz[0]))
ax.set(xlabel="frame", ylabel="log(mean pixel value)", title="Per-label mean vs. frame");

In [None]:
def aggregate_dask(func, labels, img_stack, skip0=True):
    """Reduce each labeled region of every frame in an image stack.

    NOTE(review): this cell re-defines ``aggregate_dask``. Compared with the
    earlier definition, it defaults ``skip0`` to True and has no
    ``map_blocks`` path. On a top-to-bottom run it replaces that definition.

    Pixels are grouped by their value in ``labels``, and ``func(group,
    axis=1)`` is applied to each group's ``(n_frames, n_pixels_in_label)``
    values.

    Parameters
    ----------
    func : callable
        Reduction accepting an ``axis`` keyword, e.g. ``np.mean``.
    labels : ndarray
        Label image with exactly one element per pixel of a frame.
    img_stack : array-like
        Stack indexed as ``img_stack[frame, ...]``.
    skip0 : bool, default True
        If true, leave label 0 out of the result.

    Returns
    -------
    dict
        Maps each label value to ``func``'s result for that label. Returns an
        empty dict when ``labels`` is empty.

    Raises
    ------
    ValueError
        If ``labels.size`` differs from the number of pixels per frame.
    """
    keys = labels.ravel()
    n_pixels = int(np.prod(img_stack.shape[1:]))
    if keys.size != n_pixels:
        raise ValueError(
            f"labels has {keys.size} elements but each frame of img_stack has "
            f"{n_pixels} pixels"
        )
    if keys.size == 0:
        return {}
    # Stable sort keeps each label's pixels in their original raster order.
    sorter = np.argsort(keys, kind="mergesort")
    sorted_ = keys[sorter]
    flag = sorted_[:-1] != sorted_[1:]
    slices = np.concatenate(([0], np.flatnonzero(flag) + 1, [keys.size]))
    unique = sorted_[slices[:-1]]
    # (A stray `_` expression statement previously sat here; it read IPython's
    # last-output variable and raised NameError outside IPython.)
    values = img_stack.reshape((img_stack.shape[0], -1))[:, sorter]
    groups = np.split(values, slices[1:-1], axis=1)
    return {
        key: func(group, axis=1)
        for key, group in zip(unique, groups)
        if key != 0 or not skip0
    }

In [None]:
# Smoke test: per-label means of the first frame via `aggregate` (label 0 skipped by default).
aggregate(np.mean, labels, ary[0:1])

## Run: build and submit processing graphs for all files

In [None]:
# Named per-label reductions passed to segmentation.process_file (name -> callable).
# Only the mean is active; the commented entries are disabled alternatives.
funcs = {"mean": np.mean}  # ,
#'median': np.median}
#'p0.05': partial(np.percentile, q=5),
#'p0.20': partial(np.percentile, q=20),
#'p0.70': partial(np.percentile, q=70),
#'p0.95': partial(np.percentile, q=95)}

In [None]:
base_filename = "/n/scratch2/jqs1"
# Earlier dataset selections, kept for reference:
# fluorescence_filenames = (glob(os.path.join(base_filename, '190411/Noah_Runs/*.nd2')) +
#                           glob(os.path.join(base_filename, '190507/*.nd2')) +
#                           glob(os.path.join(base_filename, '190508/*.nd2')) +
#                           glob(os.path.join(base_filename, '190514/*.nd2')) +
#                           glob(os.path.join(base_filename, '190515/*.nd2')) +
#                           glob(os.path.join(base_filename, '190516/*.nd2')))
# fluorescence_filenames = glob(os.path.join(base_filename, '190523/*ti5*.nd2'))
# fluorescence_filenames = glob(os.path.join(base_filename, '190401/*.nd2')) + glob(os.path.join(base_filename, '190411/Noah_Runs/*.nd2'))

# glob() returns paths in arbitrary filesystem order. Sort them so that any
# later slicing of this list selects the same files on every run.
fluorescence_filenames = sorted(
    glob(os.path.join(base_filename, "190922/*/*photobleaching*.nd2"))
)
# Phase-contrast and sandwich runs are disabled for this dataset. The old
# globs used `base13_filename`, which no visible cell defines:
#   phase:    glob(os.path.join(base13_filename, 'phase/*_0001.nd2')) + glob('/n/scratch2/jqs1/fidelity/190325/phase/*/*_0001.nd2')
#   sandwich: glob(os.path.join(base13_filename, 'sandwich/*_0001.nd2'))
phase_filenames = []
sandwich_filenames = []

In [None]:
# Limit the run to the first three matched files.
# NOTE(review): which files these are depends on the list's order; glob order
# is filesystem-dependent unless the list was sorted.
fluorescence_filenames = fluorescence_filenames[:3]

In [None]:
# Disabled: dark-frame calibration (uses `base13_filename`, not defined in the visible cells).
# dark_frames = segmentation.nd2_to_dask(os.path.join(base13_filename, 'calibration/dark_100ms.nd2'), 0, 0)
# dark_frame = dark_frames.mean(axis=0)
# # TODO: hack
# #dark_frame = dark_frame.compute()
# #dark_frame = client.persist(dark_frame)
# #dark_frame = client.scatter(dark_frame, broadcast=True)
# dark_frame = dark_frame.to_delayed()[0,0]

In [None]:
# Disabled: per-channel flat-field calibration built from the calibration ND2 files.
# flat_fields = {}
# for filename in glob(os.path.join(base13_filename, 'calibration/*flat*100ms*.nd2')):
#     channel = segmentation.get_nd2_reader(filename).metadata['channels'][0]
#     flat_field = segmentation.nd2_to_dask(filename, 0, 0).mean(axis=0)
#     # TODO: hack
#     #flat_field = flat_field.compute()
#     #flat_field = client.scatter(flat_field, broadcast=True)
#     #flat_field = client.persist(flat_field)
#     flat_field = flat_field.to_delayed()[0,0]
#     flat_fields[channel] = flat_field

In [None]:
# Empty calibration inputs passed to segmentation.process_file (presumably
# meaning no dark-frame / flat-field correction -- confirm in that helper).
dark_frame = None
flat_fields = {}

In [None]:
# One processing graph per input file, keyed by the file used for
# segmentation (the fluorescence file itself when there is no separate one).
data_graph = {}


def _swap_suffix(filename, new_suffix, old_suffix="_0001.nd2"):
    """Replace the trailing ``old_suffix`` of ``filename`` with ``new_suffix``.

    Raises ValueError if the suffix is missing. ``str.replace`` would instead
    return ``filename`` unchanged, making the derived segmentation, initial
    and final paths all alias the input file.
    """
    if not filename.endswith(old_suffix):
        raise ValueError(f"{filename!r} does not end with {old_suffix!r}")
    return filename[: -len(old_suffix)] + new_suffix


def _add_graph(key, photobleaching_filename, **kwargs):
    """Store ``segmentation.process_file(...)`` for one file under ``key``.

    Raises ValueError on a duplicate key rather than silently overwriting an
    earlier file's graph.
    """
    if key in data_graph:
        raise ValueError(f"duplicate data_graph key {key!r}")
    data_graph[key] = segmentation.process_file(
        funcs,
        photobleaching_filename,
        dark_frame=dark_frame,
        flat_fields=flat_fields,
        **kwargs,
    )


for photobleaching_filename in fluorescence_filenames:
    _add_graph(photobleaching_filename, photobleaching_filename)

for photobleaching_filename in phase_filenames:
    segmentation_filename = _swap_suffix(photobleaching_filename, ".nd2")
    _add_graph(
        segmentation_filename,
        photobleaching_filename,
        segmentation_filename=segmentation_filename,
    )

for initial_filename in sandwich_filenames:
    segmentation_filename = _swap_suffix(initial_filename, ".nd2")
    _add_graph(
        segmentation_filename,
        _swap_suffix(initial_filename, "_0002.nd2"),
        segmentation_filename=segmentation_filename,
        initial_filename=initial_filename,
        final_filename=_swap_suffix(initial_filename, "_0003.nd2"),
    )

In [None]:
# split up computes so we can gather results from multiple workers
# (otherwise the single worker assembling the dict will run out of memory)
# TODO: use recursive_map(..., levels=?)
# One future per (file, inner key) entry of each file's graph dict.
# The calibration inputs go under "_calibration"; the leading underscore lets
# later cells skip this bookkeeping entry (`k[0] != "_"`).
data_futures = {
    k: {k2: client.compute(v2) for k2, v2 in v.items()} for k, v in data_graph.items()
}
data_futures["_calibration"] = client.compute(
    {"dark_frame": dark_frame, "flat_fields": flat_fields}
)

## Save data

In [None]:
# Block until all futures finish; `data` has the same nested dict structure with results in place of futures.
data = client.gather(data_futures)

In [None]:
# Save all gathered results (including the "_calibration" entry) to one pickle.
# The file handle is named `fh`, not `f`: `with ... as f` would rebind the
# notebook-level reduction function `f` to a closed file object.
filename = "/n/groups/paulsson/jqs1/molecule-counting/191221photobleaching.pickle"
with open(filename, "wb") as fh:
    pickle.dump(data, fh)

In [None]:
# Sanity check: the largest label value in each entry's "labels" image, per file
# (entries whose key starts with "_", such as "_calibration", are skipped).
# If labels are numbered consecutively from 1, this is the object count -- unverified here.
{
    k: {pos: np.asarray(d["labels"]).max() for pos, d in v.items()}
    for k, v in data.items()
    if k[0] != "_"
}