In [None]:
import numpy as np
import pandas as pd
import nd2reader.exceptions
from nd2reader import ND2Reader
import matplotlib.pyplot as plt
import holoviews as hv
import skimage.filters
import skimage.feature
import scipy.ndimage
import peakutils
from tqdm import tnrange, tqdm_notebook
import dask
import dask.array as da
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from cytoolz import partial, compose
from itertools import repeat
from glob import glob
import cachetools
import numpy_indexed
import pickle
import warnings

In [None]:
# Load the Bokeh plotting backend for HoloViews.
hv.extension("bokeh")

In [None]:
def RevImage(img, **kwargs):
    """Build an ``hv.Image`` from ``img`` via ``_RevImage``.

    Thin convenience wrapper: fixes the element class to ``hv.Image`` and
    passes ``img`` and any keyword arguments straight through.
    """
    element_cls = hv.Image
    return _RevImage(element_cls, img, **kwargs)


def _RevImage(cls, img, **kwargs):
    """Wrap a 2-D array in a HoloViews-style element with a flipped y-axis.

    The rows of ``img`` are reversed before being handed to ``cls`` and the
    resulting element gets ``invert_yaxis=True``; ``bounds`` defaults to
    ``(0, 0, img.shape[1], img.shape[0])``.

    Parameters
    ----------
    cls : callable
        Element constructor (e.g. ``hv.Image``). Its result must provide
        ``.options``.
    img : array-like
        Object supporting ``img[::-1]`` and a ``shape`` with at least two entries.
    **kwargs
        Extra keyword arguments forwarded to ``cls``. They used to be accepted
        but silently dropped. An explicit ``bounds`` overrides the default.

    Returns
    -------
    Whatever ``cls(...).options(invert_yaxis=True)`` returns.
    """
    # kwargs is a fresh dict per call, so filling in the default is safe.
    kwargs.setdefault("bounds", (0, 0, img.shape[1], img.shape[0]))
    return cls(img[::-1], **kwargs).options(invert_yaxis=True)

In [None]:
def hessian_eigenvalues(img):
    """Return the two eigenvalue images of a Sobel-based Hessian of ``img``.

    ``img`` is Gaussian-smoothed (sigma 1.5); first and second derivatives are
    then approximated by applying ``sobel_h`` / ``sobel_v`` twice. The
    eigenvalues use the closed form for a 2x2 matrix,
    ``trace / 2 +- sqrt(trace**2 - 4 * det) / 2``. Entries that come out as NaN
    (e.g. from a negative discriminant) are set to 0 in both outputs.

    Returns
    -------
    (k1, k2) : tuple of arrays
        ``k1`` uses the ``+`` root, ``k2`` the ``-`` root.
    """
    sobel_h = skimage.filters.sobel_h
    sobel_v = skimage.filters.sobel_v

    smoothed = skimage.filters.gaussian(img, 1.5)
    d_h = sobel_h(smoothed)
    d_v = sobel_v(smoothed)
    d_hh, d_hv = sobel_h(d_h), sobel_v(d_h)
    d_vh, d_vv = sobel_h(d_v), sobel_v(d_v)

    trace = d_hh + d_vv
    det = d_hh * d_vv - d_hv * d_vh
    half_trace = trace / 2
    # sqrt of a negative discriminant emits a RuntimeWarning and yields NaN;
    # the warning is silenced here and the NaNs are zeroed below.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        half_root = np.sqrt(trace ** 2 - 4 * det) / 2

    eigenvalues = []
    for ev in (half_trace + half_root, half_trace - half_root):
        ev[np.isnan(ev)] = 0
        eigenvalues.append(ev)
    return eigenvalues[0], eigenvalues[1]


from cytoolz import reduce


def repeat_apply(func, n):
    """Return a function that applies ``func`` to its argument ``n`` times.

    Parameters
    ----------
    func : callable
        Single-argument function whose output can be fed back in as input.
    n : int
        Number of applications. ``n <= 0`` yields the identity function.

    Returns
    -------
    callable
        ``x -> func(func(...func(x)))`` with ``n`` nested applications.

    Notes
    -----
    Implemented as a plain loop rather than ``reduce(compose, ...)``: the
    composed form nests ``n`` wrapper objects, so its call depth grows with
    ``n``, and it relied on an import placed mid-notebook.
    """

    def applied(x):
        for _ in range(n):  # range() is empty for n <= 0 -> identity
            x = func(x)
        return x

    return applied

In [None]:
# Region attributes collected by default by get_regionprops(). The order is
# significant: it determines the column order of the resulting DataFrame.
DEFAULT_REGIONPROPS = [
    "area", "centroid", "eccentricity",
    "min_intensity", "mean_intensity", "max_intensity",
    "major_axis_length", "minor_axis_length",
    "orientation", "perimeter", "solidity",
    "weighted_centroid",
]


def get_regionprops(label_image, intensity_image, properties=DEFAULT_REGIONPROPS):
    """Tabulate ``skimage.measure.regionprops`` results, one row per region.

    Parameters
    ----------
    label_image : array
        Labelled image passed to ``regionprops``.
    intensity_image : array
        Intensity image passed to ``regionprops``.
    properties : sequence of str
        Region attribute names to collect (default ``DEFAULT_REGIONPROPS``).

    Returns
    -------
    pandas.DataFrame or None
        ``None`` when no regions are found. Otherwise a frame indexed by each
        region's own ``label`` (index name ``"label"``). Scalar properties come
        first, in the order given; each tuple-valued property is split into
        ``<prop>_x`` (element 0) and ``<prop>_y`` (element 1), appended after them.
    """
    rps = skimage.measure.regionprops(
        label_image, intensity_image, coordinates="rc", cache=False
    )
    if not len(rps):
        return None
    cols = {prop: [getattr(rp, prop) for rp in rps] for prop in properties}
    for col, values in list(cols.items()):
        if isinstance(values[0], tuple):
            del cols[col]
            # TODO: store coordinates as multiindex?
            # NOTE(review): with coordinates="rc" the tuples are presumably
            # (row, col), so "_x" would hold the row - confirm before relying
            # on these suffixes.
            cols[col + "_x"] = [v[0] for v in values]
            cols[col + "_y"] = [v[1] for v in values]
    # Index by each region's actual label rather than 1..N, so that rows stay
    # correctly labelled when the label image has gaps in its numbering.
    df = pd.DataFrame(cols, index=[rp.label for rp in rps])
    df.index.name = "label"
    return df

In [None]:
def segment(img):
    """Label connected regions where the Frangi response of ``img`` is low.

    Steps: compute the Frangi filter (scales 0.1 to 1.5, step 0.1), keep pixels
    below the 90th percentile of that response, drop regions touching the image
    border, and label what remains.

    Returns
    -------
    Integer label image as produced by ``skimage.measure.label``.
    """
    ridge_response = skimage.filters.frangi(
        img, scale_range=(0.1, 1.5), scale_step=0.1
    )
    cutoff = np.percentile(ridge_response, 90)
    below_cutoff = skimage.segmentation.clear_border(ridge_response < cutoff)
    return skimage.measure.label(below_cutoff)

In [None]:
# LFU cache (48 entries) for ND2 readers. It is only used if the cached variant
# of get_nd2_reader (commented out below) is enabled.
ND2READER_CACHE = cachetools.LFUCache(maxsize=48)


def _get_nd2_reader(filename, memmap=False, **kwargs):
    """Open ``filename`` with ``nd2reader.ND2Reader``, forwarding ``memmap`` and ``kwargs``."""
    reader_kwargs = dict(kwargs, memmap=memmap)
    return nd2reader.ND2Reader(filename, **reader_kwargs)


# Caching is currently disabled, so every call opens a new reader. The cached
# alternative is:
# get_nd2_reader = cachetools.cached(cache=ND2READER_CACHE)(_get_nd2_reader)
get_nd2_reader = _get_nd2_reader


def get_nd2_frame(filename, position, channel, t, memmap=False):
    """Read a single 2-D frame from an ND2 file.

    Parameters
    ----------
    filename : str
        Path handed to ``get_nd2_reader``.
    position : int
        Passed to ``get_frame_2D`` as ``v``.
    channel : str
        Channel name; converted to an index by looking it up in the reader's
        ``metadata["channels"]``.
    t : int
        Passed to ``get_frame_2D`` as ``t``.
    memmap : bool
        Forwarded both to the reader and to ``get_frame_2D``.

    Returns
    -------
    The object returned by ``reader.get_frame_2D``.
    """
    nd2 = get_nd2_reader(filename, memmap=memmap)
    channel_names = nd2.metadata["channels"]
    return nd2.get_frame_2D(
        v=position, c=channel_names.index(channel), t=t, memmap=memmap
    )


def nd2_to_futures(client, filename, position=0, channel=None):
    """Submit one ``get_nd2_frame`` task per time point of an ND2 file.

    Parameters
    ----------
    client : distributed.Client
        Client whose ``submit`` is used to schedule the frame reads.
    filename : str
        ND2 file to read.
    position : int, optional
        Position forwarded to ``get_nd2_frame`` (default 0).
    channel : str, optional
        Channel name forwarded to ``get_nd2_frame``. ``None`` (default) selects
        the first entry of the reader's ``metadata["channels"]``.

    Returns
    -------
    list
        One future per ``t`` in ``range(nd2.sizes["t"])``, in time order.

    Notes
    -----
    The tasks were previously submitted as ``get_nd2_frame(filename, t)``, which
    omits that function's required position/channel arguments and fails on the
    worker. ``position`` and ``channel`` are new optional parameters, so existing
    two-argument calls keep working.
    """
    nd2 = get_nd2_reader(filename)
    if channel is None:
        channel = nd2.metadata["channels"][0]
    return [
        client.submit(get_nd2_frame, filename, position, channel, t)
        for t in range(nd2.sizes["t"])
    ]


def map_over_labels(label_image, intensity_image, func):
    """Reduce the intensities of each label with ``func``.

    Both images are flattened, intensities are grouped by label value with
    ``numpy_indexed.group_by`` and ``func`` is used as the reduction.

    Returns
    -------
    list
        Element 1 of every entry returned by ``group_by`` (presumably the
        reduced value of a ``(label, value)`` pair), in the order ``group_by``
        returns them.

    Notes
    -----
    Original author's note: assumes labels are consecutive integers 1,...,N.
    NOTE(review): every distinct value in ``label_image`` forms a group, so a
    background value of 0 would presumably contribute an entry too - confirm.
    """
    keys = label_image.ravel()
    values = intensity_image.ravel()
    reduced = numpy_indexed.group_by(keys, values, reduction=func)
    return [entry[1] for entry in reduced]


def process_file(client, filename, col_to_funcs):
    """Schedule segmentation and per-label traces for one ND2 file.

    The first frame is segmented (``segment``) and measured
    (``get_regionprops``). Then, for every ``(name, func)`` in ``col_to_funcs``,
    ``map_over_labels`` is submitted for each frame against those labels and the
    per-frame results are transposed with ``np.transpose``.

    Returns
    -------
    tuple
        ``(traces, regionprops, labels, first_frame)`` where ``traces`` maps
        each name in ``col_to_funcs`` to a future, and the other three items
        are futures as well.
    """
    frames = nd2_to_futures(client, filename)
    first_frame = frames[0]
    labels = client.submit(segment, first_frame)
    regionprops = client.submit(get_regionprops, labels, first_frame)

    traces = {}
    for col, func in col_to_funcs.items():
        per_frame = [
            client.submit(map_over_labels, labels, frame, func) for frame in frames
        ]
        traces[col] = client.submit(np.transpose, per_frame)
    return traces, regionprops, labels, first_frame

# Run

In [None]:
# Set the scheduler's "allowed-failures" option to 20 by writing directly into
# dask's nested config dict.
# NOTE(review): dask.config.set({...}) is the public way to change settings -
# consider using it instead of mutating dask.config.config.
dask.config.config["distributed"]["scheduler"]["allowed-failures"] = 20
# Worker memory thresholds that were tried and then disabled:
# dask.config.config['distributed']['worker']['memory'] = {'target': 0.4,
#                                                         'spill': 0.5,
#                                                         'pause': 0.9,
#                                                         'terminate': 0.95}

In [None]:
# SLURM-backed dask cluster: each job is configured with 1 core, 1 process and
# 8GB, submitted to the "short" queue with a 3 hour walltime.
cluster = SLURMCluster(
    cores=1,
    processes=1,
    memory="8GB",
    queue="short",
    walltime="03:00:00",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
)
# Options that were tried and are currently unused:
# diagnostics_port=('127.0.0.1', 8787),
# env_extra=['export PYTHONPATH=\"/home/jqs1/projects/matriarch\"'])

# Client through which the functions above submit their work.
client = Client(cluster)

In [None]:
# Echo the cluster object so the notebook renders its representation.
cluster

In [None]:
# Ask the cluster to scale to 0 workers.
# NOTE(review): presumably left over from shutting workers down interactively;
# with zero workers, tasks submitted by later cells cannot run - confirm the
# intended worker count before a top-to-bottom run.
cluster.scale(0)

# Test

In [None]:
# Read one frame (get_frame_2D defaults) into `bf`. The reader was previously
# opened and never closed; the context manager releases the file afterwards.
# NOTE(review): assumes the returned frame stays valid after the reader is
# closed - confirm for the nd2reader build in use.
with ND2Reader(
    "/home/jqs1/paulsson/190313_SB1_in_GFP_bleaching_in_CFP_100ms_100pct_laser_again_again.nd2"
) as bf_reader:
    bf = bf_reader.get_frame_2D()

In [None]:
# Stack of the "FITC" channel at position 0, built by nd2_to_dask.
# NOTE(review): nd2_to_dask is defined in a later cell of this notebook, so this
# cell fails on a fresh top-to-bottom run unless that cell is executed first.
fitc = nd2_to_dask(
    "/home/jqs1/paulsson/190313_SB1_in_GFP_bleaching_in_CFP_100ms_100pct_laser_again_again_0001.nd2",
    0,
    "FITC",
)

In [None]:
# Average the FITC stack over axis 0 and compute the result eagerly.
fitc_frame = fitc.mean(axis=0).compute()

In [None]:
# Work on the top-left 600x600 crop of `bf`.
img = bf[:600, :600]
# Preprocessing variants that were tried and disabled:
# img = img - img.min()
# img = img.max() - img
# img = img.astype(np.float32)
# img = img - np.percentile(img, 30)
# img = skimage.transform.pyramid_expand(img, upscale=2, multichannel=False)
# Pixels below the Otsu threshold form the mask; label its connected components.
mask = img < skimage.filters.threshold_otsu(img)
img_labels = skimage.morphology.label(mask)

In [None]:
def normalize_componentwise(
    img, img_labels, label_index=None, dilation=5, in_place=False, dtype=np.float32
):
    """Scale ``img`` by the median of the per-label maxima.

    Parameters
    ----------
    img : ndarray
        Image to normalise.
    img_labels : ndarray
        Label image used to compute per-label maxima of ``img``.
    label_index : sequence, optional
        Labels to measure. Defaults to ``np.unique(img_labels)``.
    dilation : int
        Number of times ``skimage.morphology.dilation`` is applied to
        ``img_labels`` after the maxima have been measured.
    in_place : bool
        If False (default) work on a ``dtype`` copy of ``img``; if True modify
        ``img`` itself (it must then support in-place true division).
    dtype : numpy dtype
        Type of the working copy when ``in_place`` is False.

    Returns
    -------
    ndarray
        The normalised image (the input object itself when ``in_place`` is True).

    Notes
    -----
    Changes from the previous version: the undefined alias ``ndi`` is replaced
    by the imported ``scipy.ndimage``, a leftover debugging ``print`` of the
    maxima is removed, the redundant ``.copy()`` after ``astype`` is dropped and
    the median is computed once. Disabled per-label experiments that lived here
    as commented-out code were removed.
    """
    if not in_place:
        img = img.astype(dtype)  # astype already returns a new array
    if label_index is None:
        label_index = np.unique(img_labels)
    maxes = scipy.ndimage.maximum(img, labels=img_labels, index=label_index)
    img_labels = repeat_apply(skimage.morphology.dilation, dilation)(img_labels)
    # The first entry is skipped - presumably the background label 0 when
    # label_index comes from np.unique; confirm for custom label_index values.
    scale = np.median(maxes[1:])
    img /= scale
    # NOTE(review): pixels of (dilated) label 1 are divided by the same scale a
    # second time. Kept as-is from the original; confirm this is intended.
    img[img_labels == 1] /= scale
    return img

In [None]:
# Normalised copy of `img`; in_place defaults to False, so `img` is not modified.
img2 = normalize_componentwise(img, img_labels)

In [None]:
# Bare reference to np.clip: the cell only displays the function object.
# NOTE(review): looks like leftover exploration - safe to delete.
np.clip

In [None]:
# Invert the image around its median and clamp negatives to zero before
# normalising. Note that this rebinds `img2` from the previous cell.
img2 = img.astype(np.float32).copy()
img2 = np.median(img2) - img2
img2[img2 < 0] = 0
img3 = normalize_componentwise(img2, img_labels)

In [None]:
# Display the inverted-and-normalised image.
RevImage(img3)

In [None]:
# Display `img` inverted around its maximum, for comparison.
RevImage(img.max() - img)

In [None]:
# TODO: improved
def normalize_componentwise2(
    img, img_labels, label_index=None, dilation=5, in_place=False, dtype=np.float32
):
    """Rescale each labelled component of ``img`` using its own min and max.

    Per-label minima and maxima are measured on the original labels. The labels
    are then dilated ``dilation`` times and, for every label, the pixels of the
    dilated region have the label's minimum subtracted and are divided by
    ``max - min`` (the division is skipped when that range is 0).

    Parameters
    ----------
    img : ndarray
        Image to normalise.
    img_labels : ndarray
        Label image.
    label_index : sequence, optional
        Labels to process. Defaults to ``np.unique(img_labels)``.
    dilation : int
        Number of ``skimage.morphology.dilation`` passes applied to the labels.
    in_place : bool
        If False (default) work on a ``dtype`` copy; if True modify ``img``.
    dtype : numpy dtype
        Type of the working copy when ``in_place`` is False.

    Returns
    -------
    ndarray
        The normalised image.

    Notes
    -----
    Uses the imported ``scipy.ndimage`` instead of the alias ``ndi``, which is
    never defined in this notebook, and drops the redundant ``.copy()`` after
    ``astype``.
    """
    if not in_place:
        img = img.astype(dtype)  # astype already returns a new array
    if label_index is None:
        label_index = np.unique(img_labels)
    mins, maxes, _, _ = scipy.ndimage.extrema(
        img, labels=img_labels, index=label_index
    )
    img_labels = repeat_apply(skimage.morphology.dilation, dilation)(img_labels)
    for idx, label in enumerate(label_index):
        mask = img_labels == label
        img[mask] -= mins[idx]
        scale = maxes[idx] - mins[idx]
        if scale != 0:  # constant component: leave it at zero after the shift
            img[mask] /= scale
        # TODO: does this save memory compared to the following?
        # img[mask] = (img[mask] - mins[idx]) / (maxes[idx] - mins[idx])
    return img

In [None]:
# Mean of `img` after subtracting its minimum.
(img - img.min()).mean()

In [None]:
# Summary of the intensity distribution: (min, mean, median, max).
(img.min(), img.mean(), np.median(img), img.max())

In [None]:
# Quick look at `img` before normalisation.
plt.imshow(img)

In [None]:
# Rebind `img` to its normalised copy; the following cells use this version.
# NOTE(review): not idempotent - re-running this cell normalises `img` again.
img = normalize_componentwise(img, img_labels)

In [None]:
%%time
# Frangi filter response of `img` (same parameters as in segment()).
img_frangi = skimage.filters.frangi(img, scale_range=(0.1, 1.5), scale_step=0.1)

In [None]:
%%time
# Local maxima of the Frangi response that reach at least 10% of its global
# maximum, found with min_distance=5.
coords = skimage.feature.peak_local_max(
    img_frangi, threshold_abs=img_frangi.max() * 0.1, min_distance=5
)

In [None]:
%%time
# Same threshold as the previous cell but without min_distance; indices=False
# presumably returns an image-shaped boolean mask instead of coordinates.
maxes2 = skimage.feature.peak_local_max(
    img_frangi, threshold_abs=img_frangi.max() * 0.1, indices=False
)

In [None]:
# 5x5 maximum filter of the Frangi response. Uses the imported scipy.ndimage:
# the alias `ndi` is never defined in this notebook, and the `ndimage.filters`
# namespace is deprecated in favour of `scipy.ndimage.maximum_filter`.
temp = scipy.ndimage.maximum_filter(img_frangi, size=5, mode="constant")

In [None]:
# Hessian eigenvalue images of the normalised image and of its Frangi response.
k1, k2 = hessian_eigenvalues(img)
frangi_k1, frangi_k2 = hessian_eigenvalues(img_frangi)

In [None]:
%%time
# Elementwise product of the local-maxima image and `mask`: non-zero only
# where both are non-zero.
# Alternative that was tried:
# skimage.feature.peak_local_max(img_frangi, footprint=mask, indices=False)
maxes = maxes2 * mask

In [None]:
# Show the normalised image on a 10x10 inch figure.
plt.figure(figsize=(10, 10))
plt.imshow(img)

In [None]:
%%output size=100
#%%opts Image {+axiswise}
# Two-column overview: image, Frangi response with the detected peaks overlaid,
# masked maxima, maximum-filtered response, and the four eigenvalue images.
# Each panel's "z" dimension is renamed (z2..z8), presumably so the panels do
# not share a colour range - confirm.
(
    RevImage(img)
    + (
        RevImage(img_frangi).redim(z="z2")
        * hv.Points(coords).options(color="g", marker="+", size=5)
    )
    + RevImage(maxes).redim(z="z3")
    + RevImage(temp).redim(z="z4")
    + RevImage(k1).redim(z="z5")
    + RevImage(k2).redim(z="z6")
    + RevImage(frangi_k1).redim(z="z7")
    + RevImage(frangi_k2).redim(z="z8")
).cols(2)

In [None]:
# Top-left 600x600 region of `img` on a 10x10 inch figure.
plt.figure(figsize=(10, 10))
plt.imshow(img[:600, :600])

In [None]:
%%time
# Recompute the Frangi response with the same parameters as the earlier cell.
img_frangi = skimage.filters.frangi(img, scale_range=(0.1, 1.5), scale_step=0.1)

In [None]:
%%time
# Peaks of the negated response, i.e. local minima of `img_frangi`.
# Rebinds `coords` from the earlier peak search.
coords = skimage.feature.peak_local_max(-img_frangi, min_distance=5)

In [None]:
%%time
# Local maxima of the Frangi response via skimage.morphology.
# Rebinds `maxes` from the earlier masked-peak cell.
maxes = skimage.morphology.local_maxima(img_frangi)

In [None]:
# Frangi response with the detected points overlaid in red; column 1 of
# `coords` is plotted as x and column 0 as y.
plt.figure(figsize=(10, 10))
plt.imshow(img_frangi[:600, :600])
plt.plot(coords[:, 1], coords[:, 0], "r.")

In [None]:
# Pixels below the 99th percentile of the Frangi response.
# Rebinds `mask`, which previously held the Otsu mask.
mask = img_frangi < np.percentile(img_frangi, 99)

In [None]:
# Show the percentile mask.
plt.figure(figsize=(10, 10))
plt.imshow(mask[:600, :600])

In [None]:
# Show the top-left 600x600 region of the raw `bf` frame.
plt.figure(figsize=(10, 10))
plt.imshow(bf[:600, :600])

In [None]:
# First frame of the FITC stack, top-left 300x300 region (computed on demand).
plt.figure(figsize=(10, 10))
plt.imshow(fitc[0, :300, :300].compute())

In [None]:
# Mean of the FITC stack over axis 0, top-left 300x300 region.
plt.figure(figsize=(10, 10))
plt.imshow(fitc_frame[:300, :300])

## Run

In [None]:
def nd2_to_dask(filename, position=0, channel=None):
    """Expose the time series of one position/channel as a lazy dask array.

    Parameters
    ----------
    filename : str
        ND2 file to read.
    position : int, optional
        Position forwarded to ``get_nd2_frame`` (default 0).
    channel : str, optional
        Channel name forwarded to ``get_nd2_frame``. ``None`` (default) selects
        the first entry of the reader's ``metadata["channels"]``.

    Returns
    -------
    dask.array.Array
        Frames stacked along axis 0, one per ``t`` in ``range(nd2.sizes["t"])``.
        Each frame is read by a delayed ``get_nd2_frame`` call.

    Notes
    -----
    Frame 0 is read eagerly once, only to learn the dtype and shape that
    ``da.from_delayed`` requires. ``position`` and ``channel`` used to be
    mandatory; the defaults let filename-only calls work while leaving existing
    three-argument calls unchanged.
    """
    nd2 = get_nd2_reader(filename)
    if channel is None:
        channel = nd2.metadata["channels"][0]
    frame0 = get_nd2_frame(filename, position, channel, 0)
    frames = [
        dask.delayed(get_nd2_frame)(filename, position, channel, t)
        for t in range(nd2.sizes["t"])
    ]
    arrays = [
        da.from_delayed(frame, dtype=frame0.dtype, shape=frame0.shape)
        for frame in frames
    ]
    return da.stack(arrays, axis=0)

In [None]:
# Dark calibration stack (channel "Dark", position 0) with its mean and
# standard deviation over axis 0. No .compute() is called in this cell.
dark_frames = nd2_to_dask(
    "/n/scratch2/jqs1/fidelity/190313/calibration/dark_100ms.nd2", 0, "Dark"
)
dark_frame = dark_frames.mean(axis=0)
dark_std = dark_frames.std(axis=0)

In [None]:
# Flat-field references: median over axis 0 of each calibration stack.
# NOTE(review): only the filename is passed to nd2_to_dask here; confirm which
# position/channel these stacks are meant to use.
gfp_flat = nd2_to_dask(
    "/n/scratch2/jqs1/fidelity/190313/calibration/190313_flatfield_FITC_cfpslide_100ms_20pct.nd2"
).median(axis=0)
mcherry_flat = nd2_to_dask(
    "/n/scratch2/jqs1/fidelity/190313/calibration/190313_flatfield_mcherry_gfpslide_100ms_000.nd2"
).median(axis=0)
mcherry_flat_lowpower = nd2_to_dask(
    "/n/scratch2/jqs1/fidelity/190313/calibration/190313_flatfield_mcherry_30ms_0.35pct_laser.nd2"
).median(axis=0)

In [None]:
# Flat-field images keyed by name; `mcherry_flat_lowpower` is not included.
flats = {"GFP": gfp_flat, "MCHERRY": mcherry_flat}

In [None]:
# glob() returns matches in arbitrary order; sort them so that filenames[0] and
# the processing order used below are reproducible between runs.
filenames = sorted(glob("/n/scratch2/jqs1/fidelity/190313/*.nd2"))  # [:2]

In [None]:
# TODO: the body of this per-file loop was never written. `pass` keeps the cell
# syntactically valid until it is filled in or the cell is removed.
for filename in filenames:
    pass

In [None]:
# Reductions applied to the pixels of every label, keyed by the name of the
# resulting trace. The "pX.XX" keys are written as fractions, while the `q`
# passed to np.percentile is the matching percentage (e.g. "p0.05" -> q=5).
funcs = {
    "mean": np.mean,
    "median": np.median,
    "p0.05": partial(np.percentile, q=5),
    "p0.20": partial(np.percentile, q=20),
    "p0.70": partial(np.percentile, q=70),
    "p0.95": partial(np.percentile, q=95),
}

In [None]:
# Trial run on the first file only; `a` is (traces, regionprops, labels, frame).
a = process_file(client, filenames[0], funcs)

In [None]:
# Gather the trial results for display (the gathered value is not stored).
client.gather(a)

In [None]:
# Progress display for the trial futures.
# NOTE(review): this cell follows client.gather(a); it is presumably meant to
# be run while the tasks are still in flight.
progress(a)

In [None]:
%%time
# Schedule processing for every file; values are process_file's result tuples.
data = {filename: process_file(client, filename, funcs) for filename in filenames}

In [None]:
%%time
# Collect the results for all files from the cluster.
trace_res = client.gather(data)

In [None]:
# Persist the gathered results so the plotting section can run without the cluster.
# NOTE(review): the file name says 190311 while `filenames` is globbed from a
# 190313 directory - confirm the intended name.
with open("190311photobleaching.pickle", "wb") as f:
    pickle.dump(trace_res, f)

# Plotting

In [None]:
# Reload previously saved results.
# Security: pickle.load can execute arbitrary code, so only load files that
# this notebook (or another trusted source) wrote.
with open("190311photobleaching.pickle", "rb") as f:
    trace_res = pickle.load(f)

In [None]:
# Keys of the loaded result mapping.
list(trace_res.keys())

In [None]:
# Results for a single file; presumably the (traces, regionprops, labels, frame)
# tuple produced by process_file - confirm against the run that wrote the pickle.
b = trace_res["/n/scratch2/jqs1/fidelity/190311/190311_205a_10ms_laser10pct_001.nd2"]

In [None]:
# Log of the "mean" trace at index 10.
plt.plot(np.log(b[0]["mean"][10]))

In [None]:
# Histogram (100 bins) of the "area" column of the second result item.
b[1].hist("area", bins=100)

In [None]:
# "mean" traces; indexed below as traces[row, column] with columns along the x-axis.
traces = b[0]["mean"]

In [None]:
# Random permutation of the row indices of `traces`.
# NOTE(review): no seed is set, so this differs between runs.
idxs = np.random.permutation(traces.shape[0])

In [None]:
%%output size=250
# Log of each trace relative to its first value, for the first tenth of the rows,
# coloured by the "i" value.
# NOTE(review): the data come from traces[i] while "i" is idxs[i]. If a random
# sample of traces was intended, traces[idxs[i]] would be needed - confirm.
curves = [
    {
        "x": np.arange(traces.shape[1]),
        "y": np.log(traces[i] / traces[i, 0]),
        "i": idxs[i],
    }
    for i in range(traces.shape[0] // 10)
]
hv.Contours(curves, vdims=["i"]).options(color_index="i", cmap="Category20")

In [None]:
%%output size=250
# All traces drawn as one hv.Path: shared x values plus the transposed trace matrix.
hv.Path((np.arange(traces.shape[1]), traces.T))

In [None]:
%%output size=250
# One hv.Curve per trace, combined into a single overlay.
hv.Overlay.from_values([hv.Curve(t) for t in traces])

In [None]:
%%output size=250
# Side-by-side view of `img` and the label image produced by segment(img).
hv.Image(img) + hv.Image(segment(img))