In [None]:
import numpy as np
import pandas as pd
import nd2reader
import matplotlib.pyplot as plt
import holoviews as hv
from holoviews.operation.datashader import regrid
import skimage.filters
import skimage.feature
import scipy.ndimage
import peakutils
from tqdm.autonotebook import tqdm
import dask
import dask.array as da
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from cytoolz import partial, compose, juxt
from itertools import repeat
from glob import glob
import cachetools
import numpy_indexed
import pickle
import pyarrow as pa
import warnings
import os
from numbers import Integral
from dask.delayed import Delayed

In [None]:
# Reload edited modules automatically before each cell runs, so changes to the
# locally imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# from segmentation import *
# from util import *
# from matriarch_stub import *
import segmentation
import matriarch_stub

In [None]:
# Notebook-wide plotting defaults: 8x8 inch matplotlib figures, and the bokeh
# backend for holoviews output.
plt.rcParams.update({"figure.figsize": (8, 8)})
hv.extension("bokeh")

# Run

In [None]:
# dask.config.config['distributed']['scheduler']['allowed-failures'] = 20
# dask.config.config['distributed']['worker']['memory'] = {'target': 0.4,
#                                                         'spill': 0.5,
#                                                         'pause': 0.9,
#                                                         'terminate': 0.95}

In [None]:
# Each SLURM job hosts one single-core, single-process worker with 8 GB of memory.
cluster = SLURMCluster(
    # job submission
    queue="short",
    walltime="03:00:00",
    # per-job resources
    cores=1,
    processes=1,
    memory="8GB",
    # filesystem locations
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
)
# Options that were tried previously and are currently disabled:
# diagnostics_port=('127.0.0.1', 8787),
# env_extra=['export PYTHONPATH=\"/home/jqs1/projects/matriarch\"'])

# Client through which all later cells submit work to this cluster.
client = Client(cluster)

In [None]:
# Display the cluster object (notebook rich repr).
cluster

In [None]:
# Ask the cluster to scale to 30 workers.
cluster.scale(30)

# Heterogeneous cluster

In [None]:
# TODO
# fix regionprops memory usage
# DEBUG: why is segmentation so slow?
# split photobleaching task into multiple sub-tasks (each is a single-threaded .compute() call)
# pin segmentation/regionprops tasks to high-RAM nodes (in heterogeneous dask cluster)

# filter by FOCUS (???)

# named_funcs_as_juxt: decorator to turn {'func1': func1, ('q0.5', 'q0.7'): partial(np.percentile, q=(0.5,0.7))} into a multiple-valued func
# zarrification of labels (skip??)

# pass in frame metadata to filter funcs (requires unified metadata representation)

## Run

In [None]:
def segmentation_frame_filter(img):
    """Frame-level filter hook passed to the photobleaching pipeline.

    Accepts every frame: ``img`` is ignored and ``True`` is always returned.
    """
    return True


def segmentation_labels_filter(labels, img, max_label=20000):
    """Label-level filter hook passed to the photobleaching pipeline.

    Accepts a segmentation only if its largest label value is strictly below
    ``max_label``. Presumably this rejects over-segmented frames (label values
    growing with the number of regions) -- confirm against ``segmentation``.

    Parameters
    ----------
    labels : array-like with a ``.max()`` method
        Label image produced by the segmentation step.
    img : any
        The frame that was segmented. Unused; kept so the hook signature
        matches what the pipeline calls.
    max_label : int, optional
        Exclusive upper bound on ``labels.max()``. Defaults to 20000, the
        value that was previously hard-coded.

    Returns
    -------
    The result of ``labels.max() < max_label`` (a boolean for NumPy input).
    """
    return labels.max() < max_label

In [None]:
# Root directory that the relative .nd2 glob patterns below are joined onto
# (despite the name, this is a directory, not a file).
base_filename = "/n/scratch2/jqs1"
# filenames = glob(os.path.join(base_filename, '190922/*/*photobleaching*.nd2'))

In [None]:
# TODO: try segmenting everything in phase (to reduce bias of segmenting in different channels)
# G_GR  G_RG  G-R_GR  G-R_RG  R-G_GR  R_GR  R-G_RG  R_RG
# Maps the channel used for segmentation to the glob patterns (relative to
# base_filename) of the .nd2 files that should be segmented in that channel.
# NOTE(review): "191312" is not a valid YYMMDD date -- confirm that this is the
# actual directory name on disk.
seg_channel_to_files = {
    "RFP-PENTA": ["191312/R_RG/*.nd2"],
    #'191312/R_GR/*.nd2'], # missing GFP-PENTA for all but one
    "GFP-PENTA": [
        "191312/G_GR/*.nd2",
        "191312/G_RG/*.nd2",
        "191312/G-R_GR/GR*.nd2",  # missing RFP-PENTA for all but two
        "191312/G-R_RG/*.nd2",
        "191312/R-G_GR/*.nd2",
        "191312/R-G_RG/*.nd2",
    ],
}

In [None]:
# Preview the grouping of one directory's .nd2 files with the "BF" channel
# ignored (presumably one group per stage position -- confirm in `segmentation`).
segmentation.cluster_nd2_by_positions(
    glob(os.path.join(base_filename, "191312/R-G_GR/*.nd2")), ignored_channels=["BF"]
)

In [None]:
# Placeholder handed to process_photobleaching_file as its first argument.
funcs = None  # not used

In [None]:
# Build the nested result graph:
#   data_graph[file_pattern][segmentation_filename] -> {
#       "segmentation_filename", "segmentation_channel",
#       "segmentation_frame", "labels", "traces": {channel: ...}
#   }
# Values come from segmentation.process_photobleaching_file and are computed
# later via client.compute in the next cell.
data_graph = {}
for segmentation_channel, file_patterns in seg_channel_to_files.items():
    for file_pattern in file_patterns:
        data_graph[file_pattern] = {}
        filenames = glob(os.path.join(base_filename, file_pattern))
        clustered_filenames = segmentation.cluster_nd2_by_positions(
            filenames, ignored_channels=["BF"]
        )
        # The loop variable must not be called `cluster`: that name is bound to
        # the SLURMCluster, and rebinding it here would break any later
        # `cluster.scale(...)` / `cluster` cell.
        for position_files in clustered_filenames.values():
            segmentation_filename = position_files[segmentation_channel]
            # Sorted so that the representative channel chosen below does not
            # depend on set iteration order (which varies between kernels).
            channels = sorted(set(position_files) - {"BF"})
            results_by_channel = {}
            for channel in channels:
                results_by_channel[channel] = segmentation.process_photobleaching_file(
                    funcs,
                    position_files[channel],
                    photobleaching_channel=channel,
                    segmentation_filename=segmentation_filename,
                    segmentation_channel=segmentation_channel,
                    time_slice=slice(None),
                    rechunk=True,
                    segmentation_frame_filter=segmentation_frame_filter,
                    segmentation_labels_filter=segmentation_labels_filter,
                )
            # Every channel is processed against the same segmentation file, so
            # the segmentation frame and labels are taken from one of them
            # (assumes they are identical across channels -- TODO confirm).
            rep = results_by_channel[channels[0]][0]
            seg_data = {
                "segmentation_filename": segmentation_filename,
                "segmentation_channel": segmentation_channel,
                "segmentation_frame": rep["segmentation_frame"],
                "labels": rep["labels"],
                "traces": {
                    channel: results_by_channel[channel][0]["traces"]
                    for channel in channels
                },
            }
            data_graph[file_pattern][segmentation_filename] = seg_data

In [None]:
# Submit one compute per position rather than a single compute for the whole
# graph, so that results are gathered from many workers (a single worker
# assembling the full dict would run out of memory).
# TODO: use recursive_map(..., levels=?)
data_futures = {
    file_pattern: {
        seg_filename: client.compute(seg_graph)
        for seg_filename, seg_graph in per_pattern.items()
    }
    for file_pattern, per_pattern in data_graph.items()
}

## Save data

In [None]:
# Wait for all futures and pull their results into this process. The result is
# expected to mirror the nesting of data_futures:
# data[file_pattern][segmentation_filename].
data = client.gather(data_futures)

In [None]:
# Persist the gathered results to disk.
# NOTE: pickle files should only ever be loaded from trusted sources.
filename = "/n/groups/paulsson/jqs1/molecule-counting/200101photobleaching3.pickle"
with open(filename, "wb") as f:
    pickle.dump(data, f)

In [None]:
# Largest label value per position, grouped by file pattern; top-level keys
# starting with "_" are skipped.
{
    file_pattern: {
        seg_filename: np.asarray(result["labels"]).max()
        for seg_filename, result in per_pattern.items()
    }
    for file_pattern, per_pattern in data.items()
    if file_pattern[0] != "_"
}

In [None]:
# Overlay log-scaled intensity histograms of the segmentation frames of the
# first 10 positions.
# `data` is nested as data[file_pattern][segmentation_filename] -> result dict,
# so flatten both levels before indexing into the results (indexing the outer
# values with [0] raises KeyError with this layout).
plt.figure(figsize=(30, 10))
seg_results = [
    result for per_pattern in data.values() for result in per_pattern.values()
]
hist_data = [
    np.asarray(result["segmentation_frame"].flat) for result in seg_results[:10]
]
plt.hist(hist_data, bins=200, log=True, stacked=False, fill=False, histtype="step")

In [None]:
# Pick one position to inspect in the cells below.
# The key used here previously,
#   "/n/scratch2/jqs1/190922/CFP_photobleaching/CFP_photobleaching_50pct_100ms.nd2_0027.nd2"
# followed by [0], belongs to an older dataset/layout and is never inserted
# into `data` by this notebook, so it raised KeyError on a fresh run.
# `data` is nested as data[file_pattern][segmentation_filename]; take the first
# available position (set example_pattern / example_filename by hand to look at
# a specific one).
example_pattern, example_filename = next(
    (file_pattern, seg_filename)
    for file_pattern, per_pattern in data.items()
    for seg_filename in per_pattern
)
d = data[example_pattern][example_filename]
img = d["segmentation_frame"]

In [None]:
# Segmentation frame on a log intensity scale (zero-valued pixels map to -inf).
plt.figure(figsize=(20, 20))
plt.imshow(np.log(img))

In [None]:
# Label image stored alongside the selected frame.
plt.figure(figsize=(20, 20))
plt.imshow(d["labels"])

In [None]:
# 1024-bin histogram of all pixel intensities; `thresh` is the left edge of the
# most populated bin (the mode of the intensity distribution).
hist, bin_edges = np.histogram(img.flat, bins=1024)
idx = np.argmax(hist)
thresh = bin_edges[idx]

In [None]:
# Log-log view of the histogram with the mode marked.
# Empty bins give log(0) = -inf and a RuntimeWarning from numpy.
plt.plot(np.log(bin_edges[:-1]), np.log(hist))
plt.axvline(np.log(thresh))

In [None]:
# Binary mask of pixels brighter than twice the modal intensity.
plt.figure(figsize=(20, 20))
plt.imshow(img > 2 * thresh)

In [None]:
# Alias for np.histogram; not referenced by any other visible cell.
h = np.histogram

In [None]:
# Blurred copy of the frame (presumably a box-filter approximation of a Gaussian
# with scale 50 -- confirm in matriarch_stub) and the residual after subtracting
# it. `img_highpass` is not used by any other visible cell.
img_blurred = matriarch_stub.gaussian_box_approximation(img, 50)
img_highpass = img - img_blurred
plt.figure(figsize=(20, 20))
plt.imshow(np.log(img_blurred))

In [None]:
%%time
# Time a segmentation of the full frame (the crop is currently disabled).
# `diag` is handed to segment() via `diagnostics=` and inspected in the cells below.
img_crop = img  # [500:1500,:500]
diag = matriarch_stub.tree()
seg = segmentation.segment(img_crop, diagnostics=diag)

In [None]:
# Inspect the entries stored in `diag` after the segmentation call above.
diag["img_blurred"]

In [None]:
diag["histogram"]

In [None]:
diag["mask"]

In [None]:
# Result returned by segmentation.segment for the inspected frame.
plt.figure(figsize=(50, 50))
plt.imshow(seg)

In [None]:
# Intensity histogram of the inspected frame (100 bins, log-scaled counts).
plt.hist(img.flat, bins=100, log=True)

In [None]:
# Same diagnostic as shown a few cells above.
diag["mask"]

In [None]:
# Plot the "mean" traces of the inspected position.
# NOTE(review): the graph-building cell stores `traces` as a dict keyed by
# channel name, so "mean" may need to be looked up under a channel first -- confirm.
plt.plot(d["traces"]["mean"].T)

In [None]:
# Maximum pixel value of the inspected frame.
img.max()

In [None]:
# Median pixel value of the inspected frame.
np.median(img)