In [None]:
import os
import pickle
import warnings
from glob import glob
from itertools import repeat
from numbers import Integral

import cachetools
import dask
import dask.array as da
import distributed
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import numpy_indexed
import pandas as pd
import peakutils
import scipy.ndimage
import skimage.feature
import skimage.filters
import skimage.measure
from cytoolz import partial, compose
from dask.delayed import Delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm import tnrange, tqdm_notebook

In [None]:
# Reload edited modules (e.g. the project modules segmentation and
# matriarch_stub) automatically before each cell runs, so edits to them are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# from segmentation import *
# from util import *
# from matriarch_stub import *
import segmentation
import matriarch_stub

In [None]:
# Notebook-wide plotting defaults: 8x8-inch matplotlib figures and the bokeh
# backend for holoviews.
plt.rcParams.update({"figure.figsize": (8, 8)})
hv.extension("bokeh")

# Run

In [None]:
# Worker memory thresholds can be tuned the same way if workers run out of memory:
# dask.config.set({"distributed.worker.memory.target": 0.4,
#                  "distributed.worker.memory.spill": 0.5,
#                  "distributed.worker.memory.pause": 0.9,
#                  "distributed.worker.memory.terminate": 0.95})
#
# Let the scheduler retry a task up to 20 times before marking it as failed
# (presumably because SLURM workers can die mid-task). Called without `with`,
# dask.config.set changes the value globally. It must run before the scheduler
# is created in the cluster cell below, because the scheduler reads it at start-up.
dask.config.set({"distributed.scheduler.allowed-failures": 20});

In [None]:
# Each SLURM job on the "short" partition runs one single-core worker process
# with 8 GB of memory and a 3 h walltime. Worker scratch space goes to /tmp
# and job logs to the project log directory.
# Options used at some point but currently disabled:
#   diagnostics_port=('127.0.0.1', 8787),
#   env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"']
cluster = SLURMCluster(
    cores=1,
    processes=1,
    memory="8GB",
    queue="short",
    walltime="03:00:00",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
)
client = Client(cluster)

In [None]:
# Display the cluster object (in Jupyter this renders its status/scaling widget).
cluster

In [None]:
# NOTE(review): scale(0) requests zero workers. If this runs as part of
# "Run All", the client.compute / client.gather cells below will wait with no
# workers to run on. Scale up first (widget above, or cluster.scale(n)).
cluster.scale(0)

## Build and submit the dask computation

In [None]:
# Reductions passed to segmentation.process_file (presumably applied per
# segmented label). Only the mean is enabled; the other candidates are kept
# commented out so they can be switched back on quickly.
funcs = {
    "mean": np.mean,
    # "median": np.median,
    # "p0.05": partial(np.percentile, q=5),
    # "p0.20": partial(np.percentile, q=20),
    # "p0.70": partial(np.percentile, q=70),
    # "p0.95": partial(np.percentile, q=95),
}

In [None]:
# Input movies. base_filename is the 190313 data directory; the other dates'
# directories are spelled out in full. glob() returns matches in arbitrary,
# filesystem-dependent order, so each pattern's matches are sorted. This keeps
# the processing order, and anything that depends on it, reproducible.
base_filename = "/n/scratch2/jqs1/fidelity/190313"


def sorted_glob(*patterns):
    """Return the sorted matches of each glob pattern, concatenated in pattern order."""
    return [path for pattern in patterns for path in sorted(glob(pattern))]


fluorescence_filenames = sorted_glob(
    os.path.join(base_filename, "fluorescence/*.nd2"),
    "/n/scratch2/jqs1/fidelity/190325/fluorescence/*/*.nd2",
    "/n/scratch2/jqs1/fidelity/190326/*.nd2",
)
phase_filenames = sorted_glob(
    os.path.join(base_filename, "phase/*_0001.nd2"),
    "/n/scratch2/jqs1/fidelity/190325/phase/*/*_0001.nd2",
)
sandwich_filenames = sorted_glob(os.path.join(base_filename, "sandwich/*_0001.nd2"))

In [None]:
# Movies flagged as problematic; each path is listed exactly once.
# NOTE(review): this list is not referenced elsewhere in this notebook. If these
# files should be skipped, filter them out of fluorescence_filenames before
# data_graph is built.
problem_filenames = [
    "/n/scratch2/jqs1/fidelity/190326/190326_mkate_nochlor_wateragarpad_100ms_50pct_laser.nd2",
    "/n/scratch2/jqs1/fidelity/190313/fluorescence/190313_mkate_100ms_50pct_laser.nd2",
    "/n/scratch2/jqs1/fidelity/190313/fluorescence/190313_mkate.nd2",
    "/n/scratch2/jqs1/fidelity/190325/fluorescence/gfp_wateragarpad_100ms_50pct_laser/190325_gfp_wateragarpad_100ms_50pct_laser.nd2",
    "/n/scratch2/jqs1/fidelity/190325/fluorescence/mkate_LBagarpad_100ms_100pct_laser/190325_mkate_lbagarpad_100ms_100pct_laser.nd2",
    "/n/scratch2/jqs1/fidelity/190326/190326_mkate_chlor_wateragarpad_100ms_50pct_laser.nd2",
    "/n/scratch2/jqs1/fidelity/190326/190326_GFP_chlor_wateragarpad_100ms_50pct_laser.nd2",
    "/n/scratch2/jqs1/fidelity/190326/190326_mkate_chlor_wateragarpad_100ms_100pct_laser.nd2",
    "/n/scratch2/jqs1/fidelity/190326/190326_mkate_nochlor_wateragarpad_100ms_100pct_laser.nd2",
]

In [None]:
# Dark-frame reference: the mean over axis 0 of the 100 ms dark calibration
# movie. The two 0 arguments to nd2_to_dask are presumably position/channel
# indices (TODO confirm).
dark_frames = segmentation.nd2_to_dask(
    os.path.join(base_filename, "calibration/dark_100ms.nd2"), 0, 0
)
dark_frame = dark_frames.mean(axis=0)
# TODO: hack. The frame is handed to the per-file graphs as a single Delayed.
# Alternatives tried: dark_frame.compute(), client.persist(dark_frame),
# client.scatter(dark_frame, broadcast=True).
# to_delayed() returns one Delayed per chunk. Rechunking to a single chunk
# first guarantees that element [0, 0] is the whole frame, not just its
# top-left chunk, whatever chunking nd2_to_dask produced.
dark_frame = dark_frame.rechunk(dark_frame.shape).to_delayed()[0, 0]

In [None]:
# Flat-field references keyed by channel name. Each entry is the mean over
# axis 0 of a 100 ms flat-field calibration movie, keyed by the first entry
# of that file's channel metadata.
flat_fields = {}
# Sorted so that, if several files report the same channel, which one is kept
# (the last in sorted order) does not depend on the filesystem's glob() order.
for filename in sorted(glob(os.path.join(base_filename, "calibration/*flat*100ms*.nd2"))):
    channel = segmentation.get_nd2_reader(filename).metadata["channels"][0]
    if channel in flat_fields:
        warnings.warn(
            "multiple flat-field files for channel {!r}; keeping {}".format(
                channel, filename
            )
        )
    flat_field = segmentation.nd2_to_dask(filename, 0, 0).mean(axis=0)
    # TODO: hack. Passed on as a single Delayed. Alternatives tried:
    # flat_field.compute(), client.scatter(flat_field, broadcast=True),
    # client.persist(flat_field).
    # Rechunk to one chunk first so to_delayed()[0, 0] covers the whole frame
    # rather than only its top-left chunk.
    flat_field = flat_field.rechunk(flat_field.shape).to_delayed()[0, 0]
    flat_fields[channel] = flat_field

In [None]:
# Build (but do not yet compute) one dask graph per movie; the cells below
# submit and gather them. All three acquisition types share the same
# calibration inputs and time window.
# NOTE(review): slice(7) presumably keeps the first 7 time points. Confirm.
TIME_SLICE = slice(7)
common_kwargs = dict(
    dark_frame=dark_frame, flat_fields=flat_fields, time_slice=TIME_SLICE
)

data_graph = {}

# Fluorescence movies: no separate segmentation file is passed.
for photobleaching_filename in fluorescence_filenames:
    data_graph[photobleaching_filename] = segmentation.process_file(
        funcs, photobleaching_filename, **common_kwargs
    )

# Phase movies: "<name>_0001.nd2" is the photobleaching movie and "<name>.nd2"
# is the segmentation file. Results are keyed by the segmentation file.
for photobleaching_filename in phase_filenames:
    segmentation_filename = photobleaching_filename.replace("_0001.nd2", ".nd2")
    data_graph[segmentation_filename] = segmentation.process_file(
        funcs,
        photobleaching_filename,
        segmentation_filename=segmentation_filename,
        **common_kwargs
    )

# Sandwich series file roles:
#   "_0001"        -> initial_filename
#   "_0002"        -> photobleaching movie
#   "_0003"        -> final_filename
#   "<name>.nd2"   -> segmentation file (also the result key)
for initial_filename in sandwich_filenames:
    segmentation_filename = initial_filename.replace("_0001.nd2", ".nd2")
    photobleaching_filename = initial_filename.replace("_0001.nd2", "_0002.nd2")
    final_filename = initial_filename.replace("_0001.nd2", "_0003.nd2")
    data_graph[segmentation_filename] = segmentation.process_file(
        funcs,
        photobleaching_filename,
        segmentation_filename=segmentation_filename,
        initial_filename=initial_filename,
        final_filename=final_filename,
        **common_kwargs
    )

In [None]:
# Superseded alternative, kept for reference: a single recursive compute over
# the whole data_graph. The next cell submits one compute per file instead.
# data_futures = matriarch_stub.recursive_map(lambda x: client.compute(x), data_graph, shortcircuit=Delayed)

In [None]:
# Submit one compute per file rather than one for the whole dict, so results
# can be gathered from many workers. A single worker assembling the full dict
# runs out of memory. Calibration frames are submitted under "_calibration".
data_futures = dict(
    (name, client.compute(graph)) for name, graph in data_graph.items()
)
data_futures["_calibration"] = client.compute(
    dict(dark_frame=dark_frame, flat_fields=flat_fields)
)

## Gather results

In [None]:
# Block until every future has finished and pull the results into this process.
# `data` has the same keys as data_futures: movie paths plus "_calibration".
data = client.gather(data_futures)

## Test

In [None]:
# Largest label id for each movie and position key. This is presumably the
# number of segmented objects, if labels are numbered consecutively from 1
# (TODO confirm).
{
    movie: {position: results["labels"].max() for position, results in per_position.items()}
    for movie, per_position in data.items()
    if movie != "_calibration"
}

In [None]:
# Label image for position key 3 of one sandwich movie.
plt.imshow(
    data[
        "/n/scratch2/jqs1/fidelity/190313/sandwich/190313_SB1_in_GFP_bleaching_in_CFP_100ms_100pct_laser_again_again.nd2"
    ][3]["labels"]
)

In [None]:
# Label image for position key 0 of the 190325 GFP LB-agar-pad movie, passed
# through permute_labels (presumably shuffles label ids so neighbouring
# objects get distinguishable colours).
segmentation.RevImage(
    segmentation.permute_labels(
        data[
            "/n/scratch2/jqs1/fidelity/190325/fluorescence/gfp_LBagarpad_100ms_50pct_laser/190325_gfp_lbagarpad_100ms_50pct_laser.nd2"
        ][0]["labels"]
    )
)

In [None]:
# Load one raw frame directly from an nd2 file. The positional arguments
# 8, 0, 0 are presumably position/channel/time indices (TODO confirm in
# segmentation). NOTE(review): this movie is listed in problem_filenames.
frame = segmentation.get_nd2_frame(
    "/n/scratch2/jqs1/fidelity/190325/fluorescence/gfp_wateragarpad_100ms_50pct_laser/190325_gfp_wateragarpad_100ms_50pct_laser.nd2",
    8,
    0,
    0,
)

In [None]:
# 2x downsample in each direction (every other row and column).
frame2 = frame[::2, ::2]

In [None]:
# NOTE(review): overwrites `frame` with the stored segmentation frame of a
# different movie (LB agar pad). frame2 still holds the water-agar frame.
frame = data[
    "/n/scratch2/jqs1/fidelity/190325/fluorescence/gfp_LBagarpad_100ms_50pct_laser/190325_gfp_lbagarpad_100ms_50pct_laser.nd2"
][0]["segmentation_frame"]

In [None]:
# Segment the top-left 600x600 crop and collect intermediate images in diag.
diag = {}
labels = segmentation.segment(frame[:600, :600], diagnostics=diag)

In [None]:
# Diagnostic images recorded by segmentation.segment (600x600 run).
plt.imshow(diag["img"].data)

In [None]:
plt.imshow(diag["img_k1_frangi"].data)

In [None]:
plt.imshow(diag["img_thresh"].data)

In [None]:
# Zoom into the top-left 300x300 of the img_k1_frangi diagnostic.
plt.imshow(diag["img_k1_frangi"].data[:300, :300])

In [None]:
# Downsampled frame from the water-agar movie loaded earlier.
plt.imshow(frame2)

In [None]:
# Re-run on a 300x300 crop. This replaces diag and labels from the 600x600 run.
diag = {}
labels = segmentation.segment(frame[:300, :300], diagnostics=diag)

In [None]:
# Local minimum over 10x10 windows (edges extended with mode="nearest").
plt.imshow(scipy.ndimage.minimum_filter(frame, size=10, mode="nearest"))

In [None]:
def normalize_locally(
    img,
    min_size=100,
    max_size=50,
    min_blur=50,
    max_blur=10,
    in_place=False,
    mode="nearest",
    cval=0,
    diagnostics=None,
):
    """Normalize an image against its smoothed local minimum and maximum.

    Steps:

    1. Convert to float.
    2. Subtract a Gaussian-blurred local-minimum image.
    3. Divide the result by a Gaussian-blurred local-maximum image of itself.

    Local background is pushed towards 0 and local peaks towards roughly 1.

    Parameters
    ----------
    img : array_like
        Input image; converted with ``skimage.img_as_float``.
    min_size : int
        Window size of the local-minimum filter.
    max_size : int
        Window size of the local-maximum filter.
    min_blur : float
        Gaussian sigma applied to the local-minimum image.
    max_blur : float
        Gaussian sigma applied to the local-maximum image.
    in_place : bool
        If False (default), a copy of the input is always made. If True, no
        copy is forced, so an ``img`` that is already floating point may be
        modified in place.
    mode, cval
        Boundary handling passed to ``scipy.ndimage.minimum_filter`` and
        ``scipy.ndimage.maximum_filter``.
    diagnostics : dict, optional
        If given, each intermediate stage is stored in it as a
        ``segmentation.RevImage``. The keys are "img", "img_min",
        "img_min_blurred", "img_min_subtracted", "img_max",
        "img_max_blurred" and "img_normalized".

    Returns
    -------
    numpy.ndarray
        The normalized float image.

    Notes
    -----
    Pixels where the blurred local maximum is 0 yield inf/nan.
    """
    img = skimage.img_as_float(img, force_copy=(not in_place))

    def record(key, stage, copy=False):
        # Store a diagnostic image; no-op when diagnostics were not requested.
        # RevImage is reached through the segmentation module because the
        # `from segmentation import *` line in the imports cell is commented out.
        if diagnostics is not None:
            diagnostics[key] = segmentation.RevImage(stage.copy() if copy else stage)

    # `img` is updated in place below (-= and /=), so snapshots of it are copied.
    # Otherwise (unless RevImage copies internally) they would all show the
    # final normalized image.
    record("img", img, copy=True)
    img_min = scipy.ndimage.minimum_filter(img, size=min_size, mode=mode, cval=cval)
    record("img_min", img_min)
    img_min_blurred = skimage.filters.gaussian(img_min, min_blur)
    record("img_min_blurred", img_min_blurred)
    img -= img_min_blurred
    record("img_min_subtracted", img, copy=True)
    img_max = scipy.ndimage.maximum_filter(img, size=max_size, mode=mode, cval=cval)
    record("img_max", img_max)
    img_max_blurred = skimage.filters.gaussian(img_max, max_blur)
    record("img_max_blurred", img_max_blurred)
    img /= img_max_blurred
    record("img_normalized", img)
    return img

In [None]:
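# TODO: get rid of the small number of very bright pixels before normalizing
# (presumably because they skew the local maximum used by normalize_locally).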

In [None]:
# Normalize the current frame with the default window and blur sizes, then
# show the top-left 300x300 crop.
frame_normed = normalize_locally(frame)
plt.imshow(frame_normed[:300, :300])

In [None]:
# Full normalized frame.
plt.imshow(frame_normed)

In [None]:
plt.imshow(
    skimage.filters.gaussian(
        scipy.ndimage.minimum_filter(frame, size=100, mode="nearest"), 100
    )
)

In [None]:
plt.imshow(
    skimage.filters.gaussian(
        scipy.ndimage.minimum_filter(frame, size=100, mode="nearest"), 10
    )
)

In [None]:
plt.imshow(
    skimage.filters.gaussian(
        scipy.ndimage.minimum_filter(frame, size=100, mode="nearest"), 50
    )
)

In [None]:
plt.imshow(scipy.ndimage.minimum_filter(frame, size=100, mode="nearest"))

In [None]:
plt.imshow(scipy.ndimage.maximum_filter(frame, size=10, mode="nearest"))

In [None]:
plt.imshow(
    skimage.filters.gaussian(
        scipy.ndimage.minimum_filter(frame, size=10, mode="nearest"), 10
    )
)

In [None]:
plt.imshow(
    skimage.filters.gaussian(
        scipy.ndimage.minimum_filter(frame, size=10, mode="nearest"), 20
    )
)

In [None]:
plt.imshow(
    skimage.filters.gaussian(
        scipy.ndimage.minimum_filter(frame, size=10, mode="nearest"), 5
    )
)

In [None]:
# TODO(review): this cell held only the unfinished expression `frame - `. That
# is a SyntaxError which stops "Run All". The intended right-hand side is
# unknown (perhaps one of the background estimates from the cells above).
# Restore it here once known.

In [None]:
# Raw top-left 300x300 crop, for comparison with the normalized version.
plt.imshow(frame[:300, :300])

In [None]:
# NOTE(review): repeats the earlier normalize_locally preview cell (same
# `frame`, same defaults).
frame_normed = normalize_locally(frame)
plt.imshow(frame_normed[:300, :300])

In [None]:
# NOTE(review): repeats the earlier full-frame preview.
plt.imshow(frame_normed)

In [None]:
# Normalized image stored in diag by the last segmentation.segment call.
plt.imshow(diag["img_normalized"].data)

In [None]:
# NOTE(review): identical to the previous cell.
plt.imshow(diag["img_normalized"].data)

In [None]:
# Highest label id in the labels from the last segmentation.segment call.
labels.max()

In [None]:
%%output size=150
# Three diagnostic images from the last segmentation.segment run, side by side.
# The value dimension "z" is renamed (z2, z3) on the second and third panels,
# presumably so each panel gets its own colour range instead of a shared one.
(
    diag["img_k1_frangi"]
    + diag["img_normalized"].redim(z="z2")
    + diag["watershed_labels_permuted"].redim(z="z3")
)

In [None]:
# Names of all diagnostic images recorded in diag.
diag.keys()

In [None]:
# Work on the top-left 300x300 crop of diag's input image.
# NOTE(review): overwrites `frame` (previously the stored segmentation frame).
frame = diag["img"].data[:300, :300]

In [None]:
# NOTE(review): a no-op right after the previous cell, since frame is already
# at most 300x300. It only has an effect if run on a larger `frame`.
frame = frame[:300, :300]

In [None]:
# Foreground mask from a global Otsu threshold, split into connected-component
# labels (background = 0). skimage.measure.label is the canonical location of
# the labelling function (skimage.morphology.label re-exports it).
# skimage.measure is imported explicitly in the imports cell, rather than
# relying on another skimage import having loaded the submodule.
mask = frame > skimage.filters.threshold_otsu(frame)
mask_labels = skimage.measure.label(mask)

In [None]:
# Presumably normalizes `frame` separately within each connected component of
# mask_labels (matriarch_stub helper; its exact rule is not visible here).
img_normed = matriarch_stub.normalize_componentwise(frame, mask_labels)

In [None]:
# Sanity check: minimum of the componentwise-normalized crop.
img_normed.min()

In [None]:
# Minimum of diag's input image.
diag["img"].data.min()

In [None]:
# Minimum of diag's normalized image.
diag["img_normalized"].data.min()

In [None]:
# Top-left (up to) 600x600 of diag's normalized image. NOTE(review): diag comes
# from the 300x300 segmentation run, so this presumably shows the whole image.
plt.imshow(diag["img_normalized"].data[:600, :600])

In [None]:
# NOTE(review): identical to the previous cell.
plt.imshow(diag["img_normalized"].data[:600, :600])

In [None]:
%%output size=200
# Stored segmentation frame for position key 0 of the 190325 GFP LB-agar-pad movie.
segmentation.RevImage(
    data[
        "/n/scratch2/jqs1/fidelity/190325/fluorescence/gfp_LBagarpad_100ms_50pct_laser/190325_gfp_lbagarpad_100ms_50pct_laser.nd2"
    ][0]["segmentation_frame"]
)

## Old

In [None]:
# Keys of the gathered results: movie paths plus "_calibration".
data.keys()

In [None]:
# Save all gathered results (including "_calibration") as a pickle in the
# current working directory.
results_path = "190326photobleaching_flatcorr.pickle"
with open(results_path, "wb") as results_file:
    pickle.dump(data, results_file)

In [None]:
# Recorded run times (seconds) and the share of time spent in each function:
# - 96593 s: 65% in map_over_labels, 30% in get_nd2_frame
# - 109885 s (3/22): 56% in map_over_labels, 43% in get_nd2_frame
# - 83334 s (3/26, fluorescence movies only): 88% in map_over_labels, 9% in get_nd2_frame

In [None]:
%%output size=200
hv.HoloMap(
    {
        filename: RevImage(permute_labels(data[filename][0]["labels"]))
        for filename in data.keys()
    }
)

# Plotting

In [None]:
# Load results of an earlier (190311) run for the plotting cells below.
# NOTE(review): this replaces `data` from the run above. The following cells
# read `trace_res`, which is never assigned in this notebook. It is presumably
# this pickle's contents under an older name; confirm and bind it explicitly.
# pickle.load can execute arbitrary code, so only load trusted files.
with open("190311photobleaching.pickle", "rb") as f:
    data = pickle.load(f)

In [None]:
# NOTE(review): `trace_res` is never assigned in this notebook, so this cell
# and the next raise NameError on a fresh kernel.
list(trace_res.keys())

In [None]:
# Results for one 190311 movie.
b = trace_res["/n/scratch2/jqs1/fidelity/190311/190311_205a_10ms_laser10pct_001.nd2"]

In [None]:
# Log of entry 10 of the "mean" results in b[0].
plt.plot(np.log(b[0]["mean"][10]))

In [None]:
# Histogram of the "area" column of b[1] (b[1] supports DataFrame-style .hist).
b[1].hist("area", bins=100)

In [None]:
# NOTE(review): `z` is undefined in this notebook, and the next cell
# overwrites `traces` anyway.
traces = z["/home/jqs1/paulsson/fluorescence/190313_GFP_100ms_50pct_laser.nd2"][0][
    "traces"
]["mean"]

In [None]:
# "mean" traces of b. The plotting cells below index rows as traces[i] and use
# traces.shape[1] columns as x (presumably one row per label, one column per
# time point).
traces = b[0]["mean"]

In [None]:
# Random permutation of row indices. It is used below only as a colour index,
# so that adjacent traces get different colours. Seeded so the colouring is
# reproducible across runs. RandomState keeps this compatible with older
# numpy versions.
idxs = np.random.RandomState(0).permutation(traces.shape[0])

In [None]:
%%output size=250
curves = [
    {
        "x": np.arange(traces.shape[1]),
        "y": np.log(traces[i] / traces[i, 0]),
        "i": idxs[i],
    }
    for i in range(traces.shape[0] // 10)
]
hv.Contours(curves, vdims=["i"]).options(color_index="i", cmap="Category20")

In [None]:
%%output size=250
# All traces drawn as paths against x = 0..traces.shape[1]-1
# (presumably one path per column of traces.T, i.e. per row of traces).
hv.Path((np.arange(traces.shape[1]), traces.T))

In [None]:
%%output size=250
# The same traces as individually overlaid curves, one hv.Curve per row.
hv.Overlay.from_values([hv.Curve(t) for t in traces])

In [None]:
%%output size=250
hv.Image(img) + hv.Image(segment(img))