In [None]:
import pickle
import warnings
from functools import lru_cache
from glob import glob
from itertools import repeat

import dask
import dask.array as da
import distributed
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader.exceptions
import numpy as np
import numpy_indexed
import pandas as pd
import peakutils
import skimage.feature
import skimage.filters
import skimage.measure
import skimage.segmentation
from cytoolz import compose, partial
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from nd2reader import ND2Reader
from scipy.ndimage.filters import percentile_filter
from tqdm import tnrange, tqdm_notebook

In [None]:
# Activate the bokeh plotting backend for holoviews in this notebook.
hv.extension("bokeh")

In [None]:
def hessian_eigenvalues(img, sigma=1.5):
    """Return the two per-pixel eigenvalue images of the smoothed image's Hessian.

    The image is Gaussian-smoothed with ``sigma``. Second derivatives are then
    estimated by applying skimage's Sobel filters twice.

    Parameters
    ----------
    img : array
        2-D input image.
    sigma : float, optional
        Standard deviation of the Gaussian pre-smoothing (default 1.5).

    Returns
    -------
    k1, k2 : ndarray
        Larger and smaller Hessian eigenvalue at each pixel (``k1 >= k2``).
        Any NaNs (e.g. propagated from NaNs in ``img``) are replaced by 0.
    """
    I = skimage.filters.gaussian(img, sigma)
    I_x = skimage.filters.sobel_h(I)
    I_y = skimage.filters.sobel_v(I)
    I_xx = skimage.filters.sobel_h(I_x)
    I_yy = skimage.filters.sobel_v(I_y)
    # The two mixed partials are equal analytically but not numerically (the
    # filters are applied in a different order). Averaging them makes the 2x2
    # Hessian symmetric, which guarantees real eigenvalues. With the
    # unsymmetrized matrix a negative discriminant turned into NaN, and both
    # eigenvalues were then zeroed.
    I_xy = (skimage.filters.sobel_v(I_x) + skimage.filters.sobel_h(I_y)) / 2
    half_trace = (I_xx + I_yy) / 2
    # Symmetric [[a, b], [b, d]]: sqrt(((a + d)^2 - 4(ad - b^2)) / 4)
    #   == sqrt(((a - d) / 2)^2 + b^2), which is never the sqrt of a negative.
    radius = np.sqrt(((I_xx - I_yy) / 2) ** 2 + I_xy ** 2)
    k1 = half_trace + radius
    k2 = half_trace - radius
    k1[np.isnan(k1)] = 0
    k2[np.isnan(k2)] = 0
    return k1, k2

In [None]:
DEFAULT_REGIONPROPS = [
    "area",
    "centroid",
    "eccentricity",
    "min_intensity",
    "mean_intensity",
    "max_intensity",
    "major_axis_length",
    "minor_axis_length",
    "orientation",
    "perimeter",
    "solidity",
    "weighted_centroid",
]


def get_regionprops(label_image, intensity_image, properties=DEFAULT_REGIONPROPS):
    """Measure per-region properties of a label image and return them as a DataFrame.

    Parameters
    ----------
    label_image : ndarray of int
        Label image passed to ``skimage.measure.regionprops``.
    intensity_image : ndarray
        Image on which the intensity-based properties are measured.
    properties : sequence of str
        Names of region-property attributes to collect.

    Returns
    -------
    pandas.DataFrame or None
        One row per region, indexed by the region's own label (index name
        ``"label"``). Returns ``None`` when there are no regions.

        Columns follow the order of ``properties``. Tuple-valued properties
        are split into one column per component: ``<prop>_x`` for component 0,
        ``<prop>_y`` for component 1, and ``<prop>_2``, ``<prop>_3``, ... for
        any further components. With ``coordinates="rc"``, component 0 of a
        coordinate such as ``centroid`` is the row, so ``_x`` holds the row
        despite its name.
    """
    rps = skimage.measure.regionprops(
        label_image, intensity_image, coordinates="rc", cache=False
    )
    if not len(rps):
        return None
    cols = {}
    for prop in properties:
        values = [getattr(rp, prop) for rp in rps]
        if isinstance(values[0], tuple):
            # TODO: store coordinates as multiindex?
            for i in range(len(values[0])):
                suffix = ("_x", "_y")[i] if i < 2 else "_" + str(i)
                cols[prop + suffix] = [v[i] for v in values]
        else:
            cols[prop] = values
    # Index rows by each region's actual label instead of assuming labels are
    # exactly 1..N, so rows always match the label image.
    index = pd.Index([rp.label for rp in rps], name="label")
    return pd.DataFrame(cols, index=index)

In [None]:
def segment(img, percentile=90):
    """Label connected regions of ``img`` that lie away from strong Frangi ridges.

    Parameters
    ----------
    img : ndarray
        2-D image to segment.
    percentile : float, optional
        Pixels whose Frangi response is below this percentile of the response
        image become foreground (default 90).

    Returns
    -------
    ndarray of int
        Label image from ``skimage.measure.label``. Foreground components that
        touch the image border have been removed.
    """
    # NOTE(review): newer scikit-image releases replaced frangi's
    # scale_range/scale_step with `sigmas`; sigmas=np.arange(0.1, 1.5, 0.1)
    # should be the equivalent there. Confirm against the installed version.
    img_frangi = skimage.filters.frangi(img, scale_range=(0.1, 1.5), scale_step=0.1)
    # Keep everything except the strongest ridge responses, so those ridges
    # separate the foreground regions.
    mask = img_frangi < np.percentile(img_frangi, percentile)
    mask = skimage.segmentation.clear_border(mask)
    labels = skimage.measure.label(mask)
    return labels

In [None]:
# Factory used to open ND2 files: called once per frame read (get_nd2_frame) and
# once per file to read its metadata (nd2_to_futures).
# Disabled alternative that would cache one reader per filename:
# get_nd2_reader = lru_cache()(ND2Reader)
get_nd2_reader = ND2Reader


def get_nd2_frame(filename, t):
    """Read the 2-D image at time index ``t`` from the ND2 file ``filename``.

    A reader is opened for the call and closed before returning, so the file
    handle is not leaked.
    """
    # NOTE: this closes the reader. If get_nd2_reader is switched to the cached
    # (lru_cache) variant, closing a shared reader here would break later calls.
    with get_nd2_reader(filename) as nd2:
        return nd2.get_frame_2D(t=t)


def nd2_to_futures(client, filename):
    """Submit one ``get_nd2_frame`` task per time point of ``filename``.

    Returns
    -------
    list
        Futures from ``client.submit`` for frames ``t = 0 .. sizes["t"] - 1``,
        in time order.
    """
    # Only the frame count is needed locally, so close the reader straight away.
    with get_nd2_reader(filename) as nd2:
        n_frames = nd2.sizes["t"]
    return [client.submit(get_nd2_frame, filename, t) for t in range(n_frames)]


def map_over_labels(label_image, intensity_image, func):
    """Apply ``func`` to the intensities of each label, aligned by label value.

    Parameters
    ----------
    label_image : ndarray of int
        Label image with non-negative labels. Label 0 is grouped like any
        other label, so its reduction appears at position 0.
    intensity_image : ndarray
        Image of the same shape as ``label_image``.
    func : callable
        Reduction applied to the 1-D array of intensities of each label.

    Returns
    -------
    list
        ``result[k]`` is ``func`` applied to the pixels labelled ``k``, for
        ``k = 0 .. max label``. Labels that do not occur get ``nan``, so
        positions match label values even when labels are not consecutive.
        An empty input gives ``[]``.
    """
    groups = numpy_indexed.group_by(
        label_image.ravel(), intensity_image.ravel(), reduction=func
    )
    if not groups:
        return []
    # group_by with a reduction yields (label, reduced value) pairs for the
    # labels that are present; place each one at its own label's position.
    result = [np.nan] * (max(int(label) for label, _ in groups) + 1)
    for label, value in groups:
        result[int(label)] = value
    return result


def process_file(client, filename, col_to_funcs):
    """Submit the whole per-file pipeline to ``client`` and return its futures.

    The first frame is segmented, region properties are measured on it, and
    every reduction in ``col_to_funcs`` is applied per label to every frame.

    Returns
    -------
    tuple
        ``(traces, regionprops, labels, first_frame)``. ``traces`` maps each
        key of ``col_to_funcs`` to a future of the transposed per-frame,
        per-label results: one row per label, one column per frame. The other
        three are futures of the region-property table, the label image and
        the first frame.
    """
    frame_futures = nd2_to_futures(client, filename)
    first_frame = frame_futures[0]
    labels = client.submit(segment, first_frame)
    regionprops = client.submit(get_regionprops, labels, first_frame)
    traces = {}
    for col, func in col_to_funcs.items():
        per_frame = [
            client.submit(map_over_labels, labels, frame, func)
            for frame in frame_futures
        ]
        traces[col] = client.submit(np.transpose, per_frame)
    return traces, regionprops, labels, first_frame

# Run: cluster setup

In [None]:
# dask.config.set updates only the listed keys. The previous form assigned a new
# dict to dask.config.config[...]["memory"], which replaced the whole
# "memory" section and dropped any other keys it held.
# NOTE(review): this changes the config of this process only; presumably SLURM
# workers started as separate jobs read their own config. Confirm.
dask.config.set(
    {
        "distributed.scheduler.allowed-failures": 20,
        # Fractions of the worker memory limit (see distributed's worker memory docs).
        "distributed.worker.memory.target": 0.4,
        "distributed.worker.memory.spill": 0.5,
        "distributed.worker.memory.pause": 0.9,
        "distributed.worker.memory.terminate": 0.95,
    }
);

In [None]:
# Per SLURM job: 1 core, 1 worker process, 8 GB of memory, on the "short" queue.
slurm_options = dict(
    queue="short",
    walltime="02:00:00",
    memory="8GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
    cores=1,
    processes=1,
    # Disabled options kept for reference:
    # diagnostics_port=('127.0.0.1', 8787),
    # env_extra=['export PYTHONPATH=\"/home/jqs1/projects/matriarch\"'],
)
cluster = SLURMCluster(**slurm_options)
client = Client(cluster)

In [None]:
# Display the cluster object (its notebook repr).
cluster

In [None]:
# Teardown. This cell sits before every cell that submits work, so closing the
# client unconditionally would break Restart & Run All. It only runs when
# requested: set CLOSE_CLIENT = True (or call client.close() by hand) once done.
CLOSE_CLIENT = False
if CLOSE_CLIENT:
    client.close()

## Process all files

In [None]:
# Input ND2 files. glob's order depends on the filesystem, so sort it to make the
# file order (and filenames[0], used for the smoke test) reproducible.
# Append [:2] for a quick test run on two files.
filenames = sorted(glob("/n/scratch2/jqs1/fidelity/190311/*.nd2"))

In [None]:
# Per-label intensity reductions, keyed by result column name.
# Percentile keys are "p<fraction>", e.g. q=5 -> "p0.05".
funcs = {
    "mean": np.mean,
    "median": np.median,
    **{
        "p{:.2f}".format(q / 100): partial(np.percentile, q=q)
        for q in (5, 20, 70, 95)
    },
}

In [None]:
# Smoke test: submit the pipeline for the first file only.
a = process_file(client, filenames[0], col_to_funcs=funcs)

In [None]:
# Wait for the smoke-test futures and display the gathered results.
client.gather(a)

In [None]:
# Progress bar for the smoke-test futures.
progress(a)

In [None]:
%%time
data = {filename: process_file(client, filename, funcs) for filename in filenames}

In [None]:
%%time
# Wait for all submitted work and pull the results (nested futures resolved) into this process.
trace_res = client.gather(data)

In [None]:
# Save the gathered results so the plotting section can reload them without re-processing.
with open("190311photobleaching.pickle", "wb") as f:
    pickle.dump(trace_res, f)

# Plotting

In [None]:
# Reload the saved results so plotting works in a fresh kernel.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open("190311photobleaching.pickle", "rb") as f:
    trace_res = pickle.load(f)

In [None]:
# Processed file paths (the keys of the results dict).
list(trace_res)

In [None]:
# Results for one file: (traces dict, regionprops table, label image, first frame).
b = trace_res["/n/scratch2/jqs1/fidelity/190311/190311_205a_10ms_laser10pct_001.nd2"]

In [None]:
# Log of row 10 of the "mean" traces over time (label 10 when labels run 0..N).
fig, ax = plt.subplots()
ax.plot(np.log(b[0]["mean"][10]))
ax.set(xlabel="frame", ylabel="log(mean intensity)", title="mean trace, row 10");

In [None]:
# Distribution of region areas from the regionprops table.
b[1].hist("area", bins=100)

In [None]:
# Mean-intensity traces for this file: one row per label, one column per frame.
traces = b[0]["mean"]

In [None]:
# Shuffled trace indices, used to give curves colour indices that are not
# consecutive. Seeded (0) so the figure is identical on every run.
idxs = np.random.RandomState(0).permutation(traces.shape[0])

In [None]:
%%output size=250
curves = [
    {
        "x": np.arange(traces.shape[1]),
        "y": np.log(traces[i] / traces[i, 0]),
        "i": idxs[i],
    }
    for i in range(traces.shape[0] // 10)
]
hv.Contours(curves, vdims=["i"]).options(color_index="i", cmap="Category20")

In [None]:
%%output size=250
# All traces against frame index in one Path element (presumably one path per label; confirm).
hv.Path((np.arange(traces.shape[1]), traces.T))

In [None]:
%%output size=250
# One Curve per trace (row of `traces`), overlaid.
hv.Overlay.from_values([hv.Curve(t) for t in traces])

In [None]:
%%output size=250
hv.Image(img) + hv.Image(segment(img))