In [None]:
import numpy as np
import pandas as pd
import nd2reader.exceptions
from nd2reader import ND2Reader
import matplotlib.pyplot as plt
import holoviews as hv
import skimage.filters
import skimage.feature
from scipy.ndimage.filters import percentile_filter
import peakutils
from tqdm import tnrange, tqdm_notebook
import dask
import dask.array as da
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from cytoolz import sliding_window, juxt, concat, partial
from itertools import repeat
from glob import glob
from functools import lru_cache
import numpy_indexed
import warnings

In [None]:
# Activate the bokeh plotting backend for HoloViews.
hv.extension("bokeh")

In [None]:
# Create a dask cluster backed by SLURM jobs with the resource settings below,
# then connect a distributed client to it.
# NOTE(review): log_directory is an absolute, user-specific path; adjust it
# when running under another account or on another machine.
cluster = SLURMCluster(
    queue="short",
    walltime="00:30:00",
    memory="8GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
    cores=1,
    processes=1,
)
# Constructor arguments that are currently disabled:
# diagnostics_port=('127.0.0.1', 8787),
# env_extra=['export PYTHONPATH=\"/home/jqs1/projects/matriarch\"'])
client = Client(cluster)

In [None]:
# Display the client (bare expression, rendered as the cell output).
client

In [None]:
# Display the cluster (bare expression, rendered as the cell output).
cluster

In [None]:
# Enable adaptive scaling with the bounds minimum=0 and maximum=200.
cluster.adapt(minimum=0, maximum=200)

In [None]:
# Manually request a fixed scale of 100.
# NOTE(review): the notebook also contains a cell calling cluster.adapt(...);
# confirm which of the two scaling modes is meant to be active.
cluster.scale(100)

In [None]:
# Factory used by the helper functions to open ND2 files. The memoized variant
# (lru_cache) is kept for reference but disabled, so every call constructs a
# new ND2Reader.
# get_nd2_reader = lru_cache()(ND2Reader)
get_nd2_reader = ND2Reader


def get_nd2_frame(filename, t):
    """Read a single 2D frame from an ND2 file.

    Parameters
    ----------
    filename : str
        Path of the ND2 file to open.
    t : int
        Time index forwarded to ``get_frame_2D``.

    Returns
    -------
    The frame object returned by ``get_frame_2D(t=t)``.
    """
    # A new reader is opened on every call (this function runs once per frame
    # on the workers), so close it deterministically instead of leaking one
    # open file handle per call.
    # NOTE(review): if get_nd2_reader is ever switched to a memoized factory,
    # the shared reader must not be closed here.
    with get_nd2_reader(filename) as nd2:
        return nd2.get_frame_2D(t=t)


def nd2_to_dask_array(filename):
    """Build a lazy dask array containing every time point of an ND2 file.

    Parameters
    ----------
    filename : str
        Path of the ND2 file.

    Returns
    -------
    dask array
        One delayed ``get_nd2_frame`` call per time index, stacked along a
        new leading axis (axis 0).
    """
    # The reader is needed only for metadata; read it inside a ``with`` block
    # so the file handle is released before returning.
    with get_nd2_reader(filename) as nd2:
        num_frames = nd2.sizes["t"]
        pixel_type = nd2.pixel_type
        frame_shape = nd2.frame_shape
    frames = [
        dask.delayed(get_nd2_frame, pure=True)(filename, t) for t in range(num_frames)
    ]
    arrays = [
        da.from_delayed(frame, dtype=pixel_type, shape=frame_shape) for frame in frames
    ]
    return da.stack(arrays, axis=0)


def nd2_to_futures(filename):
    """Submit one frame-reading task per time point of an ND2 file.

    Uses the notebook-level ``client`` to submit ``get_nd2_frame(filename, t)``
    for every time index ``t``.

    Parameters
    ----------
    filename : str
        Path of the ND2 file.

    Returns
    -------
    list
        The objects returned by ``client.submit``, in time order.
    """
    # Only the number of time points is needed here; close the reader right
    # away instead of leaving one open handle per processed file.
    with get_nd2_reader(filename) as nd2:
        num_frames = nd2.sizes["t"]
    return [client.submit(get_nd2_frame, filename, t) for t in range(num_frames)]


def process_file(frames, lag=1):
    """Submit the per-frame metrics for a sequence of frames to the cluster.

    The frames are grouped into sliding windows of ``lag + 1`` consecutive
    entries. The sequence is padded in front with ``lag`` ``np.nan``
    placeholders, so there is one window per input frame and the earliest
    windows contain placeholders instead of real frames.

    For each window, ``sharpness`` is applied to the last entry, while
    ``rms_diff``, ``rms_diff_normed`` and ``shift`` receive the whole window.
    NOTE(review): these four functions are not defined in this part of the
    notebook; they must already exist in the kernel.

    Parameters
    ----------
    frames : iterable
        Frames (or futures of frames) in time order.
    lag : int, optional
        Number of preceding entries included in each window.

    Returns
    -------
    The result of ``client.map`` over the windows.
    """
    # Alternatives kept from earlier experiments:
    # funcs = juxt(lambda x: sharpness(x[-1]), rms_diff, rms_diff_normed)
    # funcs = rms_diff
    # funcs = partial(client.submit, funcs)
    metrics = juxt(
        lambda window: sharpness(window[-1]),
        rms_diff,
        rms_diff_normed,
        shift,
    )
    padding = repeat(np.nan, lag)
    padded_frames = concat((padding, frames))
    windows = list(sliding_window(lag + 1, padded_frames))
    return client.map(metrics, windows)

# Prototyping

In [None]:
def hessian_eigenvalues(img):
    """Per-pixel eigenvalues of a Sobel-based second-derivative matrix.

    The image is smoothed with ``skimage.filters.gaussian`` (second argument
    1.5). Second derivatives are then approximated by applying the
    ``sobel_h`` / ``sobel_v`` filters twice. The eigenvalues of the resulting
    2x2 matrix are computed in closed form as
    ``trace / 2 +/- sqrt(trace**2 - 4 * det) / 2``.

    Where the discriminant is negative the square root is NaN; both
    eigenvalues are set to 0 at those pixels.

    Parameters
    ----------
    img : 2D array
        Input image.

    Returns
    -------
    (k1, k2) : tuple of arrays
        ``k1`` uses the ``+`` root and ``k2`` the ``-`` root.
    """
    smoothed = skimage.filters.gaussian(img, 1.5)

    # First-order responses, then each filter applied to each of them.
    d_h = skimage.filters.sobel_h(smoothed)
    d_v = skimage.filters.sobel_v(smoothed)
    d_hh = skimage.filters.sobel_h(d_h)
    d_hv = skimage.filters.sobel_v(d_h)
    d_vh = skimage.filters.sobel_h(d_v)
    d_vv = skimage.filters.sobel_v(d_v)

    trace = d_hh + d_vv
    half_trace = trace / 2
    with warnings.catch_warnings():
        # sqrt of a negative discriminant emits a RuntimeWarning; silence it,
        # the resulting NaNs are replaced below.
        warnings.simplefilter("ignore", RuntimeWarning)
        half_root = (np.sqrt(trace ** 2 - 4 * (d_hh * d_vv - d_hv * d_vh))) / 2

    eigen_plus = half_trace + half_root
    eigen_minus = half_trace - half_root
    for eigen in (eigen_plus, eigen_minus):
        eigen[np.isnan(eigen)] = 0
    return eigen_plus, eigen_minus

In [None]:
# Open one file (index 15 of `filenames`) for prototyping.
# NOTE(review): within this notebook `filenames` is assigned only in the "Run"
# section further down, so on a fresh kernel that cell has to be executed
# before this one.
nd2 = ND2Reader(filenames[15])

In [None]:
# Frame at time index 0 of the prototyping file; segmented further down.
img = nd2.get_frame_2D(t=0)

In [None]:
def segment(img):
    """Segment an image into labelled regions using a Frangi filter.

    Steps:

    1. Compute the Frangi response of ``img``.
    2. Keep the pixels whose response is below its 90th percentile.
    3. Pass the mask through ``skimage.segmentation.clear_border``.
    4. Label the remaining mask with ``skimage.measure.label``.

    Parameters
    ----------
    img : 2D array
        Image to segment.

    Returns
    -------
    Label image as returned by ``skimage.measure.label``.
    """
    # The notebook imports only skimage.filters and skimage.feature at the
    # top. Import every submodule used here explicitly, so the function does
    # not depend on skimage.measure / skimage.segmentation having been loaded
    # as a side effect of some other import.
    import skimage.filters
    import skimage.measure
    import skimage.segmentation

    img_frangi = skimage.filters.frangi(img, scale_range=(0.1, 1.5), scale_step=0.1)
    mask = img_frangi < np.percentile(img_frangi, 90)
    mask = skimage.segmentation.clear_border(mask)
    labels = skimage.measure.label(mask)
    return labels

In [None]:
%%time
# Segment the first frame; the cell magic above reports how long it takes.
labels = segment(img)

In [None]:
# Region properties for the regions in `labels`, with `img` passed as the
# intensity image.
regionprops = skimage.measure.regionprops(labels, img)

In [None]:
# Mean, min and max pixel value of `img` for each label in `labels`
# (one row per distinct label value).
df = (
    pd.DataFrame({"label": labels.ravel(), "value": img.ravel()})
    .groupby("label")
    .agg(["mean", "min", "max"])
)

In [None]:
# def map_over_labels(label_image, intensity_image, func):
#     labels = range(0, np.max(np.asarray(label_image))+1)
#     return np.array([func(intensity_image[label_image == label]) for label in labels])

# pd.DataFrame({'label': labels.ravel(), 'value': img.ravel()}).groupby('label').agg(['median'])


def map_over_labels(label_image, intensity_image, func):
    """Apply a reduction to the intensity values of every label.

    Both images are flattened. The intensity values are grouped by their
    label value with ``numpy_indexed.group_by`` and ``func`` is applied to
    each group.

    Parameters
    ----------
    label_image : array
        Integer label per pixel.
    intensity_image : array
        Intensity per pixel, with the same number of elements as
        ``label_image``.
    func : callable
        Reduction applied to the values of one label (e.g. ``np.median``).

    Returns
    -------
    list
        One reduced value per group, in the order produced by ``group_by``.
    """
    # Only the reduced values are returned, not the label keys, so matching a
    # list position to a label assumes the labels are consecutive integers.
    flat_labels = label_image.ravel()
    flat_values = intensity_image.ravel()
    reduced = numpy_indexed.group_by(flat_labels, flat_values, reduction=func)
    return [value for _key, value in reduced]

In [None]:
%%time
# Read every time point of the prototyping file into a list of 2D frames.
frames = [nd2.get_frame_2D(t=t) for t in range(nd2.sizes["t"])]

In [None]:
%%time
# Median intensity per label for every frame, using the labels from the first
# frame throughout. Transposed so that each row is one label's trace over time.
traces = np.array([map_over_labels(labels, frame, np.median) for frame in frames]).T

In [None]:
# Expected layout: (number of labels, number of frames).
traces.shape

In [None]:
# Random permutation of the trace indices.
# NOTE(review): no random seed is set, so the permutation differs between runs.
idxs = np.random.permutation(traces.shape[0])

In [None]:
%%output size=250
# Plot the first tenth of the traces, each divided by its value in the first
# frame. The permuted index idxs[i] is attached only as the value dimension "i"
# that drives the colour; the traces themselves are taken in original order.
curves = [
    {"x": np.arange(traces.shape[1]), "y": traces[i] / traces[i, 0], "i": idxs[i]}
    for i in range(traces.shape[0] // 10)
]
hv.Contours(curves, vdims=["i"]).options(color_index="i", cmap="Category20")

In [None]:
%%output size=250
# Draw the (unnormalized) traces against frame index as one Path element.
hv.Path((np.arange(traces.shape[1]), traces.T))

In [None]:
%%output size=250
# The same traces as an overlay with one Curve per row of `traces`.
hv.Overlay.from_values([hv.Curve(t) for t in traces])

In [None]:
%%output size=250
# First frame next to its label image.
# Note: segment(img) is recomputed here rather than reusing `labels`.
hv.Image(img) + hv.Image(segment(img))

# Run

In [None]:
# Display the client again before launching the full run.
client

In [None]:
# Collect the ND2 files matching the pattern below.
# NOTE(review): absolute, user-specific scratch path; adjust per environment.
# glob() does not guarantee an ordering, so positional indexing into
# `filenames` may select a different file on another run or machine.
filenames = glob("/n/scratch2/jqs1/fidelity/190301/jqs_photobleach*.nd2")

In [None]:
# Display the matched file paths.
filenames

In [None]:
%%time
# For every file, submit the per-frame reads and the per-window metric tasks;
# `data` maps each filename to whatever process_file returns (the result of
# client.map).
data = {filename: process_file(nd2_to_futures(filename)) for filename in filenames}

In [None]:
%%time
# Gather the results of all submitted tasks from the cluster.
res = client.gather(data)
# Convert each file's per-frame results to an array and transpose it.
# Presumably this yields one row per metric and one column per frame, assuming
# every metric returns a scalar — TODO confirm.
res = {filename: np.array(d).T for filename, d in res.items()}