In [None]:
import os
import pickle
import warnings
from glob import glob
from itertools import repeat
from numbers import Integral

import cachetools
import dask
import dask.array as da
import distributed
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import numpy_indexed
import pandas as pd
import peakutils
import pyarrow as pa
import scipy.ndimage
import scipy.stats
import skimage.feature
import skimage.filters
import skimage.measure
import skimage.morphology
from cytoolz import partial, compose, juxt
from dask.delayed import Delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.autonotebook import tqdm

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# from segmentation import *
# from util import *
# from matriarch_stub import *
import segmentation
import matriarch_stub

In [None]:
plt.rcParams["figure.figsize"] = (20, 10)
hv.extension("bokeh")

# Run

In [None]:
# dask.config.config['distributed']['scheduler']['allowed-failures'] = 20
# dask.config.config['distributed']['worker']['memory'] = {'target': 0.4,
#                                                         'spill': 0.5,
#                                                         'pause': 0.9,
#                                                         'terminate': 0.95}

In [None]:
# Dask cluster whose workers are submitted as SLURM batch jobs. Each job is
# requested with the queue / walltime / memory / cores listed here (see the
# dask-jobqueue SLURMCluster documentation for the exact meaning of each option).
# NOTE(review): log_directory is an absolute, user-specific path — adjust it
# when running under a different account.
cluster = SLURMCluster(
    queue="short",
    walltime="03:00:00",
    memory="8GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
    cores=1,
    processes=1,
)
# diagnostics_port=('127.0.0.1', 8787),
# env_extra=['export PYTHONPATH=\"/home/jqs1/projects/matriarch\"'])
# Client bound to the cluster above; it is reused later in the notebook for
# client.compute(...) and client.gather(...).
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(1)

# Test

In [None]:
filenames = [
    "/n/scratch2/jqs1/191312/R-G_RG/RG_100pct_100ms_100pct_100ms.nd2_0011.nd2",
    "/n/scratch2/jqs1/191312/R-G_RG/RG_100pct_100ms_100pct_100ms.nd2_0014.nd2",
    "/n/scratch2/jqs1/191312/R-G_RG/RG_100pct_100ms_100pct_100ms.nd2_0008.nd2",
    "/n/scratch2/jqs1/191312/R-G_GR/GR_100pct_100ms_100pct_100ms.nd2_0016.nd2",
    "/n/scratch2/jqs1/191312/G_GR/GR_100pct_100ms_100pct_100ms.nd2_0019.nd2",
    "/n/scratch2/jqs1/191312/G_GR/GR_100pct_100ms_100pct_100ms.nd2_0007.nd2",
    "/n/scratch2/jqs1/191312/R_RG/RG_100pct_100ms_100pct_100ms.nd2_0021.nd2",
    "/n/scratch2/jqs1/191312/R_RG/RG_100pct_100ms_100pct_100ms.nd2_0001.nd2",
    "/n/scratch2/jqs1/191312/R_RG/RG_100pct_100ms_100pct_100ms.nd2_0009.nd2",
    "/n/scratch2/jqs1/191312/R_RG/RG_100pct_100ms_100pct_100ms.nd2_0005.nd2",
]
# nd2 = matriarch_stub.get_nd2_reader(filename)

In [None]:
nd2 = matriarch_stub.get_nd2_reader(filenames[-1])
frame = nd2.get_frame_2D(v=0, t=0, c=0)

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(np.log(frame))

In [None]:
# Region of interest used by the Hessian / Frangi experiments further down.
# The crop is taken from `frame` and cast to int32 here (the same cast that
# later produces `frame2`), so this cell no longer depends on `frame2`, which
# is only defined in the Autothresholding section and would raise NameError
# on a fresh kernel.
frame_crop = frame[700:1300, 400:1500].astype(np.int32)

# New segmentation

In [None]:
%%time
diag = matriarch_stub.tree()
labels = segmentation.segment(frame, diagnostics=diag)

In [None]:
%%time
labels2 = segmentation.segment(frame, diagnostics=None, blur_sigma=1)
labels3 = segmentation.segment(frame, diagnostics=None, blur_sigma=0)

In [None]:
plt.figure(figsize=(60, 40))
plt.imshow(matriarch_stub.permute_labels(labels))

In [None]:
plt.figure(figsize=(60, 40))
plt.imshow(matriarch_stub.permute_labels(labels2))

In [None]:
plt.figure(figsize=(60, 40))
plt.imshow(matriarch_stub.permute_labels(labels3))

# Autothresholding

In [None]:
frame2 = frame.astype(np.int32)  # skimage.img_as_float32(frame)

In [None]:
frame_flattened = frame2 - matriarch_stub.gaussian_box_approximation(frame2, 100)

In [None]:
frame_flattened.max()

In [None]:
plt.imshow(frame2 > skimage.filters.threshold_li(frame2))

In [None]:
skimage.filters.try_all_threshold(frame_flattened)

In [None]:
plt.imshow(frame_flattened > skimage.filters.threshold_otsu(frame_flattened))

# Segment

In [None]:
%load_ext memory_profiler
%load_ext line_profiler

In [None]:
%lprun -f segmentation.segment segmentation.segment(frame)

In [None]:
%mprun -f segmentation.segment segmentation.segment(frame)

In [None]:
%%time
diag = matriarch_stub.tree()
labels = segmentation.segment(frame, diagnostics=diag)

In [None]:
diag[""]

In [None]:
plt.figure(figsize=(30, 10))
plt.imshow(matriarch_stub.permute_labels(labels))

# Old

In [None]:
diag.keys()

In [None]:
plt.imshow(frame2 > 450)

In [None]:
diag["histogram"]

In [None]:
# `f` is an alias of `frame2` (no copy is made here).
f = frame2
# Seed for reconstruction by erosion: identical to `f` along the one-pixel
# border and set to the global maximum of `f` everywhere inside it.
seed = np.copy(f)
seed[1:-1, 1:-1] = f.max()
# The mask is the image itself.
mask = f
# NOTE(review): this seed/mask setup looks like scikit-image's hole-filling
# recipe (reconstruction by erosion) — confirm that is the intent.
rec = skimage.morphology.reconstruction(seed, mask, method="erosion")

In [None]:
plt.imshow(rec)

In [None]:
plt.imshow(f > 350)

In [None]:
plt.imshow(matriarch_stub.gaussian_box_approximation(f, 50) > 350)

In [None]:
plt.imshow(rec > scipy.ndimage.gaussian_filter(rec, 50))

# Threshold tests

In [None]:
f_blurred_ndi = scipy.ndimage.gaussian_filter(f, 30)

In [None]:
plt.imshow(f_blurred_ndi)

In [None]:
f_blurred_box = matriarch_stub.gaussian_box_approximation(f, 30, n=10)

In [None]:
plt.imshow(f_blurred_box)

In [None]:
(f_blurred_box - f_blurred_ndi).min()

In [None]:
f_blurred_box.dtype

In [None]:
f_blurred_ndi.dtype

In [None]:
f_blurred_ndi.max()

In [None]:
f_blurred_box.max()

In [None]:
delta = f_blurred_box.astype(np.int_) - f_blurred_ndi.astype(np.int_)

In [None]:
delta.min()

In [None]:
f.astype(np.int_) > f_blurred_ndi.astype(np.int_)

In [None]:
plt.imshow(f.astype(np.int_) > f_blurred_ndi.astype(np.int_))

In [None]:
plt.imshow(delta)

In [None]:
f = frame2
f_blurred = matriarch_stub.gaussian_box_approximation(f, 50)
# f_blurred = np.exp(matriarch_stub.gaussian_box_approximation(np.log(f+1), 50))-1
# f_blurred = scipy.ndimage.gaussian_filter(f, 30)
# f_blurred = skimage.filters.threshold_local(f, 201)
m = matriarch_stub.gaussian_box_approximation(f, 2) > f_blurred
l = scipy.ndimage.label(m)[0]
plt.imshow(matriarch_stub.permute_labels(l))

In [None]:
plt.imshow(matriarch_stub.permute_labels(l))

# Old

In [None]:
diag["threshold"]

In [None]:
diag["threshold_metrics"]

In [None]:
diag["threshold_metrics"].Curve.I.data

In [None]:
img_k1 = matriarch_stub.hessian_eigenvalues(frame_crop)[0]

In [None]:
img_hess = skimage.feature.hessian_matrix_det(frame_crop)

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(img_hess)

In [None]:
plt.imshow(img_k1)

In [None]:
img_k1_frangi = skimage.filters.frangi(
    skimage.img_as_float64(img_k1), sigmas=np.arange(1, 6, 2)
)

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(img_k1_frangi)

In [None]:
plt.imshow(frame_crop)

In [None]:
pd.DataFrame(
    skimage.measure.regionprops_table(
        labels,
        frame,
        properties=(
            "label",
            "area",
            "centroid",
            "max_intensity",
            "solidity",
            "major_axis_length",
            "minor_axis_length",
            "orientation",
        ),
    )
)

In [None]:
from sys import getsizeof

In [None]:
labels.nbytes

In [None]:
frame.nbytes

In [None]:
%load_ext memory_profiler

In [None]:
%mprun -f segmentation.segment segmentation.segment(frame, diagnostics=None)

In [None]:
%mprun -f segmentation.segment segmentation.segment(frame, diagnostics=None)

In [None]:
%prun segmentation.segment(frame, diagnostics=None)

In [None]:
%prun segmentation.segment(frame, diagnostics=None)

In [None]:
%%time
diag = matriarch_stub.tree()
labels = segmentation.segment(frame, diagnostics=diag)

In [None]:
plt.figure(figsize=(30, 10))
plt.imshow(labels)

In [None]:
%%time
diag2 = matriarch_stub.tree()
labels2 = segmentation.segment(frame2, diagnostics=diag2)

In [None]:
diag["threshold"]

In [None]:
seg, _ = scipy.ndimage.label(frame > 400)

In [None]:
%%time
seg = skimage.filters.threshold_local(frame, 201)

In [None]:
np.clip?

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(matriarch_stub.permute_labels(scipy.ndimage.label(frame > seg / 1.01)[0]))

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(matriarch_stub.permute_labels(segs[20]))

In [None]:
%%time
# Sweep 50 log-spaced global thresholds (10**2.5 .. 10**3.2) over `frame`.
# For each threshold keep the connected-component label image (`segs`) and the
# pixel count of every label, background label 0 included (`sizes`).
# Use 10 instead of 50 steps for a quick, coarse sweep.
thresholds = 10 ** np.linspace(2.5, 3.2, 50)
segs = []
sizes = []
for thresh in thresholds:
    # Indexing with [0] keeps only the label image; this avoids rebinding
    # `seg` (used elsewhere as the local-threshold image) and IPython's `_`.
    thresh_labels = scipy.ndimage.label(frame > thresh)[0]
    segs.append(thresh_labels)
    sizes.append(np.bincount(thresh_labels.ravel()))
# Largest non-background component over all thresholds. Thresholds that leave
# no component (bincount holds only the background entry) are skipped, and the
# floor of 10 keeps the log-spaced edges non-decreasing; without this an empty
# `s[1:]` made `.max()` raise ValueError.
largest_component = max([10] + [s[1:].max() for s in sizes if s.size > 1])
bin_edges = 10 ** np.linspace(np.log10(10), np.log10(largest_component), 50)
hists = [np.histogram(s, bins=bin_edges)[0] for s in sizes]

In [None]:
plt.imshow(segs[20])

In [None]:
def norm(x):
    """Rescale ``x`` by its largest non-NaN value.

    ``x`` is converted to a NumPy array and divided element-wise by
    ``np.nanmax`` of that array; NaN entries stay NaN.
    """
    values = np.array(x)
    peak = np.nanmax(values)
    return values / peak

In [None]:
plt.plot(np.diff([s[1:].sum() for s in sizes]))

In [None]:
plt.plot(norm([s[1:].sum() for s in sizes]))

In [None]:
plt.plot(norm([s[s > 100].sum() for s in sizes2]))

In [None]:
plt.plot(norm([s[s > 100].sum() for s in sizes2]))

In [None]:
sizes2 = [s[1:] for s in sizes]

In [None]:
s = sizes2[19]

In [None]:
s[s > 10]

In [None]:
# Median label size per threshold, restricted to labels smaller than a tenth
# of the frame. The legend label is derived from the cutoff itself; it used to
# read `n`, which is only bound by the loop in the following cell and raised
# NameError when the notebook was run top to bottom.
size_cutoff = frame.size / 10
plt.plot(
    norm([np.median(s[s < size_cutoff]) for s in sizes]), label=str(size_cutoff)
)

In [None]:
for n in 10 ** np.linspace(1, 5, 10):
    plt.plot(norm([np.median(s[s < n]) for s in sizes]), label=str(n))
plt.legend()

In [None]:
plt.plot(norm([np.median(s[s < 100]) for s in sizes]))

In [None]:
sizes

In [None]:
[scipy.stats.median_absolute_deviation(s[10:]) for s in sizes]

In [None]:
plt.plot([np.std(s[s > 100]) for s in sizes2[10:]])

In [None]:
scipy.stats.median_absolute_deviation(sizes2[20])

In [None]:
[np.std(s) for s in sizes2]

In [None]:
# plt.plot(norm([scipy.stats.median_absolute_deviation(s[1:], center=np.mean) for s in sizes[3:]]))
plt.plot(norm([scipy.stats.median_absolute_deviation(s) for s in sizes2]))

In [None]:
# plt.plot(norm([np.median(s[s < 100]) for s in sizes]));
plt.plot(norm([(s > 10).sum() for s in sizes]))
plt.plot(norm([s[1:].max() for s in sizes]))
plt.plot(norm([s[1:].sum() for s in sizes]))

In [None]:
plt.figure(figsize=(30, 10))
for i, hist in list(enumerate(hists))[10:30]:
    plt.plot(hist, label=str(i))
plt.legend()

In [None]:
plt.plot(hists[29])
plt.plot(hists[19])
plt.plot(hists[20])
plt.plot(hists[21])

In [None]:
for i, h in list(enumerate(hists))[18:25]:
    plt.plot(h, label=str(i))
plt.legend()

In [None]:
plt.plot([h.argmax() for h in hists])

In [None]:
sizes2 = [s[2:] for s in sizes]

In [None]:
sizes[3]

In [None]:
plt.plot([np.median(s[s > 10]) for s in sizes2])

In [None]:
for p in (10, 30, 60, 90):
    plt.plot([np.log10(np.percentile(s, p)) for s in sizes2[3:]], label="{}%".format(p))
plt.legend()

In [None]:
plt.plot([(s > 10).sum() for s in sizes2])

In [None]:
# scipy.stats.median_absolute_deviation(h[10:]) * np.median(h[10:])
plt.plot([scipy.stats.median_absolute_deviation(s[s < 1000]) for s in sizes2])

In [None]:
plt.plot([(1 / (np.abs(s[1:] - np.median(s[1:])) + 1) ** 2).sum() for s in sizes])

In [None]:
# Coarser (20-edge) log-spaced histogram of the per-threshold label sizes.
# The result is stored under its own name: the previous version assigned it
# back to `sizes`, which made this cell fail when re-run and broke every
# earlier cell that indexes `sizes` as bincount arrays.
bin_edges = 10 ** np.linspace(
    np.log10(10), np.log10(max([s[1:].max() for s in sizes])), 20
)
# Each entry is the (counts, edges) tuple returned by np.histogram.
size_hists = [np.histogram(s, bins=bin_edges) for s in sizes]

In [None]:
thresholds[9]

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(matriarch_stub.permute_labels(segs[29]))

In [None]:
plt.figure(figsize=(30, 10))
for hist in hists[10:30]:
    plt.plot(np.log(hist[1:]))
plt.xlim(0, 50)

In [None]:
hists[4]

In [None]:
# plt.plot([np.log10(np.median(h)) for h in hists], label='50%')
for p in (10, 30, 60, 90):
    plt.plot(
        [
            np.log10(np.percentile(np.concatenate((h, [0] * 100))[100:], p))
            for h in hists
        ],
        label="{}%".format(p),
    )
plt.legend()

In [None]:
plt.plot([np.log(len(h)) for h in hists])

In [None]:
plt.plot(
    [scipy.stats.median_absolute_deviation(h[10:]) * np.median(h[10:]) for h in hists]
)

In [None]:
plt.plot([scipy.stats.median_absolute_deviation(h[10:]) for h in hists])

In [None]:
plt.plot([h[1:400].sum() / h.sum() for h in hists])

In [None]:
plt.plot([h[10] if len(h) > 10 else 0 for h in hists])

In [None]:
plt.plot([np.std(h) for h in hists])

In [None]:
plt.plot([np.median(h[100:]) for h in hists])

In [None]:
plt.plot([np.median(h[10:]) for h in hists])

In [None]:
plt.plot([1 / np.std(h[10:]) for h in hists])

In [None]:
plt.plot([np.median(h[10:]) / np.std(h[10:]) for h in hists])

In [None]:
plt.figure(figsize=(30, 10))
plt.imshow(segs[17])

In [None]:
hists[1]

In [None]:
hists[2]

In [None]:
hv.Histogram(hists[2])

In [None]:
diag["histogram"]

In [None]:
diag2["histogram"]

In [None]:
diag.keys()

In [None]:
diag["img_k1"].data.dtype

In [None]:
diag["img_k1_frangi"].data.dtype

In [None]:
plt.imshow(diag["watershed_labels_permuted"].data)

In [None]:
(
    regrid(diag["img"])
    + regrid(diag["img_k1"])
    + regrid(diag["img_k1_frangi"])
    + regrid(diag["watershed_labels_permuted"])
)

In [None]:
diag["clean_seeds"].data.dtype

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(matriarch_stub.permute_labels(labels))

In [None]:
img_normalized = diag["img_normalized"].data

In [None]:
img_k1, img_k2 = matriarch_stub.hessian_eigenvalues(frame_crop)

In [None]:
%%time
img_k1_frangi = skimage.filters.frangi(img_k1, sigmas=np.arange(0.1, 1.5, 0.2))

In [None]:
img_k1_frangi.max()

In [None]:
img_k1_frangi.min()

In [None]:
(img_k1.min(), img_k1.max())

In [None]:
(img_k2.min(), img_k2.max())

In [None]:
plt.figure(figsize=(40, 40))
plt.imshow(img_k2)

In [None]:
img_k1_finput = img_k1  # img_k1 - img_k1.min()
img_k1_frangi = skimage.filters.frangi(img_k1_finput, sigmas=np.arange(1, 3, 0.5))
plt.figure(figsize=(40, 40))
plt.imshow(img_k1_frangi)

In [None]:
img_k1_finput = img_k1  # img_k1 - img_k1.min()
img_k1_frangi = skimage.filters.frangi(img_k1_finput, sigmas=np.arange(1, 3, 0.1))
plt.figure(figsize=(40, 40))
plt.imshow(img_k1_frangi)

In [None]:
plt.figure(figsize=(40, 40))
plt.imshow(img_normalized)

In [None]:
diag.keys()

In [None]:
from holoviews.operation.datashader import regrid

In [None]:
# TRY different hessian sigma
# TRY segmenting in phase? (and trench-detection in phase)

In [None]:
%%output size=100
(
    regrid(diag["img"])
    + regrid(diag["img_k1"])
    + regrid(diag["img_k1_frangi"])
    + regrid(diag["clean_seeds"])
)

In [None]:
# measure clean_seeds, size filter, use top 1000 (?) to pick threshold

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(matriarch_stub.permute_labels(labels))

## Run

In [None]:
def segmentation_frame_filter(img):
    """Accept every frame: always returns ``True`` without inspecting ``img``."""
    return True


def segmentation_labels_filter(labels, img, max_label=20000):
    """Accept a label image only while its highest label is below ``max_label``.

    Parameters
    ----------
    labels : array-like exposing ``.max()``
        Label image to check.
    img : any
        Unused; kept so the two-positional-argument call shape stays valid.
    max_label : int, optional
        Exclusive upper bound on ``labels.max()``. Defaults to 20000, the
        value that used to be hard-coded.

    Returns
    -------
    bool-like
        Result of ``labels.max() < max_label``.
    """
    return labels.max() < max_label

In [None]:
base_filename = "/n/scratch2/jqs1"
# filenames = glob(os.path.join(base_filename, '190922/*/*photobleaching*.nd2'))

In [None]:
# TODO: try segmenting everything in phase (to reduce bias of segmenting in different channels)
# G_GR  G_RG  G-R_GR  G-R_RG  R-G_GR  R_GR  R-G_RG  R_RG
seg_channel_to_files = {
    "RFP-PENTA": ["191312/R_RG/*.nd2"],
    #'191312/R_GR/*.nd2'], # missing GFP-PENTA for all but one
    "GFP-PENTA": [
        "191312/G_GR/*.nd2",
        "191312/G_RG/*.nd2",
        "191312/G-R_GR/GR*.nd2",  # missing RFP-PENTA for all but two
        "191312/G-R_RG/*.nd2",
        "191312/R-G_GR/*.nd2",
        "191312/R-G_RG/*.nd2",
    ],
}

In [None]:
segmentation.cluster_nd2_by_positions(
    glob(os.path.join(base_filename, "191312/R-G_GR/*.nd2")), ignored_channels=["BF"]
)

In [None]:
funcs = None  # not used

In [None]:
# Build the processing graph: data_graph[file_pattern][segmentation_filename]
# holds the segmentation frame, the labels and the per-channel traces for one
# position. The values are handed to client.compute(...) in the next cell.
#
# Loop variables are deliberately NOT called `cluster` or `filenames`: those
# names are already bound to the SLURMCluster handle and to the test file list
# earlier in the notebook, and rebinding them here silently broke those cells.
data_graph = {}
for segmentation_channel, file_patterns in seg_channel_to_files.items():
    for file_pattern in file_patterns:
        data_graph[file_pattern] = {}
        pattern_filenames = glob(os.path.join(base_filename, file_pattern))
        clustered_filenames = segmentation.cluster_nd2_by_positions(
            pattern_filenames, ignored_channels=["BF"]
        )
        # Each value is indexed by channel name to get that channel's file.
        for position_files in clustered_filenames.values():
            segmentation_filename = position_files[segmentation_channel]
            # Sorted so the representative channel below is deterministic
            # (iteration order of a set of strings is not stable across runs).
            channels = sorted(set(position_files) - {"BF"})
            channel_results = {
                channel: segmentation.process_photobleaching_file(
                    funcs,
                    position_files[channel],
                    photobleaching_channel=channel,
                    segmentation_filename=segmentation_filename,
                    segmentation_channel=segmentation_channel,
                    time_slice=slice(None),
                    rechunk=True,
                    segmentation_frame_filter=segmentation_frame_filter,
                    segmentation_labels_filter=segmentation_labels_filter,
                )
                for channel in channels
            }
            # Every channel is processed against the same segmentation file and
            # channel, so the first channel's result is taken as representative
            # for the segmentation frame and labels.
            rep = channel_results[channels[0]][0]
            seg_data = {
                "segmentation_filename": segmentation_filename,
                "segmentation_channel": segmentation_channel,
                "segmentation_frame": rep["segmentation_frame"],
                "labels": rep["labels"],
                "traces": {
                    channel: channel_results[channel][0]["traces"]
                    for channel in channels
                },
            }
            data_graph[file_pattern][segmentation_filename] = seg_data

In [None]:
# split up computes so we can gather results from multiple workers
# (otherwise the single worker assembling the dict will run out of memory)
# TODO: use recursive_map(..., levels=?)
data_futures = {
    k: {k2: client.compute(v2) for k2, v2 in v.items()} for k, v in data_graph.items()
}

## Save data

In [None]:
data = client.gather(data_futures)

In [None]:
# Load previously saved results from disk.
# NOTE(review): this rebinds `data`, replacing whatever client.gather(...)
# produced in the cell above — run either that cell or this one, not both.
# NOTE(security): pickle.load can execute arbitrary code; only load files you
# produced yourself.
filename = "/n/groups/paulsson/jqs1/molecule-counting/200103photobleaching.pickle"
# The handle is not named `f`, because `f` is used as an image alias in the
# thresholding cells earlier in the notebook.
with open(filename, "rb") as pickle_file:
    data = pickle.load(pickle_file)

In [None]:
# Persist `data` to disk.
# NOTE(review): the path (200102...) differs from the one read by the load cell
# (200103...) — confirm which file is meant before overwriting anything.
filename = "/n/groups/paulsson/jqs1/molecule-counting/200102photobleaching.pickle"
# The handle is not named `f`, because `f` is used as an image alias in the
# thresholding cells earlier in the notebook.
with open(filename, "wb") as pickle_file:
    pickle.dump(data, pickle_file)

In [None]:
{
    k: {pos: np.asarray(d["labels"]).max() for pos, d in v.items()}
    for k, v in data.items()
    if k[0] != "_"
}

In [None]:
d = data["191312/R_RG/*.nd2"][
    "/n/scratch2/jqs1/191312/R_RG/RG_100pct_100ms_100pct_100ms.nd2_0005.nd2"
]

In [None]:
plt.figure(figsize=(30, 10))
plt.imshow(d["segmentation_frame"])

In [None]:
plt.figure(figsize=(30, 10))
plt.imshow(d["labels"])

In [None]:
plt.figure(figsize=(30, 10))
# Flattened segmentation frames of the first ten entries of `data`, drawn as
# overlaid step histograms with a log-scaled count axis.
# NOTE(review): each top-level value of `data` is indexed with [0] here, whereas
# the graph built in the "Run" section nests dicts keyed by segmentation
# filename — this cell looks like it targets an older layout of `data`;
# confirm before re-running.
hist_data = [
    np.asarray(d[0]["segmentation_frame"].flat) for d in list(data.values())[:10]
]
plt.hist(hist_data, bins=200, log=True, stacked=False, fill=False, histtype="step")

In [None]:
d = data[
    "/n/scratch2/jqs1/190922/CFP_photobleaching/CFP_photobleaching_50pct_100ms.nd2_0027.nd2"
][0]
img = d["segmentation_frame"]

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(np.log(img))

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(d["labels"])

In [None]:
hist, bin_edges = np.histogram(img.flat, bins=1024)
idx = np.argmax(hist)
thresh = bin_edges[idx]

In [None]:
plt.plot(np.log(bin_edges[:-1]), np.log(hist))
plt.axvline(np.log(thresh))

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(img > 2 * thresh)

In [None]:
h = np.histogram

In [None]:
img_blurred = matriarch_stub.gaussian_box_approximation(img, 50)
img_highpass = img - img_blurred
plt.figure(figsize=(20, 20))
plt.imshow(np.log(img_blurred))

In [None]:
%%time
img_crop = img  # [500:1500,:500]
diag = matriarch_stub.tree()
seg = segmentation.segment(img_crop, diagnostics=diag)

In [None]:
diag["img_blurred"]

In [None]:
diag["histogram"]

In [None]:
diag["mask"]

In [None]:
plt.figure(figsize=(50, 50))
plt.imshow(seg)

In [None]:
plt.hist(img.flat, bins=100, log=True)

In [None]:
diag["mask"]

In [None]:
plt.plot(d["traces"]["mean"].T)

In [None]:
img.max()

In [None]:
np.median(img)