# Imports

In [None]:
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import zarr
import dask
from dask import delayed
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import matplotlib.pyplot as plt
import holoviews as hv
from tqdm.auto import tqdm
from functools import partial
import itertools as it
from collections import namedtuple
import nd2reader
import re
import os
from pathlib import Path
import skimage.measure

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from paulssonlab.image_analysis import *
import paulssonlab.image_analysis.new as new

In [None]:
#%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Test

In [None]:
nd2 = nd2reader.ND2Reader(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
)

In [None]:
img = nd2.get_frame_2D(v=100, c=1, t=0)

In [None]:
%%time
trenches = new.image.find_trench_bboxes(
    img, peak_func=trench_detection.peaks.find_peaks
)

In [None]:
%%time
trenches3 = new.image.find_trench_bboxes(
    img[:1000, :1000], peak_func=trench_detection.peaks.find_peaks
)

In [None]:
trenches3["widths"].median()

In [None]:
trenches3["widths"].plot.hist(bins=100)

In [None]:
%%time
diag2 = util.tree()
trenches2 = trench_detection.find_trenches(
    img, peak_func=trench_detection.peaks.find_peaks, diagnostics=diag2
)

In [None]:
%%time
diag2 = util.tree()
trenches2 = new.image.find_trench_bboxes(
    img, peak_func=trench_detection.peaks.find_peaks, diagnostics=diag2
)

In [None]:
ui.show_plot_browser(diag2["find_trenches"]["label_4"]);

In [None]:
diag2["find_trenches"]

In [None]:
diag2["bboxes"].opts(frame_width=700, frame_height=400)

In [None]:
trenches2

In [None]:
diag2.keys()

In [None]:
hv.Rectangles((0, 0, 1, 1)).opts(fill_color=None, line_color="red", line_width=1)

In [None]:
top_endpoints = np.vstack((trenches2["top_x"].values, trenches2["top_y"].values)).T
bottom_endpoints = np.vstack(
    (trenches2["bottom_x"].values, trenches2["bottom_y"].values)
).T
top_endpoints.shape

In [None]:
np.hstack((top_endpoints, bottom_endpoints)).shape

In [None]:
trenches2

In [None]:
diag2

In [None]:
ui.show_plot_browser(diag2["label_1"]);

In [None]:
ui.show_plot_browser(diag2["labeling"]["find_trench_lines"]);

In [None]:
trenches2

In [None]:
data = diag2["labeling"]["find_trench_lines"]["hough_2"]["trimmed_profile"].data

In [None]:
freqs, spectrum = scipy.signal.periodogram(
    data.y.values, window="hann", nfft=2**14, scaling="spectrum"
)

In [None]:
hv.Curve((freqs, spectrum))

In [None]:
f, t, Sxx = scipy.signal.spectrogram(
    data.y.values, nfft=2**12, window="hann", scaling="spectrum"
)

In [None]:
f, t, Sxx = scipy.signal.spectrogram(
    data.y.values, nfft=2**12, window="hann", scaling="spectrum", mode="complex"
)

In [None]:
np.abs(Sxx)

In [None]:
hv.QuadMesh((t, f, np.abs(Sxx)))

In [None]:
hv.QuadMesh((t, f, np.real(Sxx)))

In [None]:
plt.pcolormesh(t, f, Sxx, shading="gouraud")

In [None]:
ui.show_plot_browser(diag2["labeling"]);

In [None]:
diag2  # ["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
ui.show_plot_browser(diag2["label_10"]["find_trench_ends"]["image_with_trenches"]);

In [None]:
%%time
diag = util.tree()
trenches = trench_detection.find_trenches(img, diagnostics=diag)

In [None]:
ui.show_plot_browser(diag);

In [None]:
%%time
crops = {}
for i, crop in it.islice(new.image.iter_crops(img, trenches), 10):
    crops[i] = crop
    # mask = trench_segmentation.segment(crop)

In [None]:
plt.imshow(crops[0].T)

In [None]:
plt.imshow(crops[1].T)

In [None]:
plt.imshow(crops[2].T)

In [None]:
plt.imshow(crops[3].T)

In [None]:
plt.imshow(crops[4].T)

In [None]:
plt.imshow(mask)

# Loader

In [None]:
x = send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2",
    slices=dict(t=slice(None, 3), v=[14, 25]),
    delayed=False,
)
list(x)[1:10]

In [None]:
x = send_eaton_fish(
    "/home/jqs1/scratch/jqs1/microscopy/220718/FISH/real_run/",
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
)
list(x)[:20]

# Functions

# Handler

In [None]:
# Channel whose crops are handed to segmentation (see handle_image).
segmentation_channel = "RFP-Penta"
# channel for trench detection, almost always same as segmentation_channel
trench_detection_channel = segmentation_channel
# Channels whose crops are measured against the segmentation masks.
measure_channels = ["RFP-Penta", "YFP-DUAL"]
# NOTE(review): presumably the channels acquired during FISH imaging; not
# referenced by any handler visible in this notebook -- confirm intended use.
fish_channels = ["RFP-Penta", "Cy5-PENTA", "Cy7"]

In [None]:
import logging


class Pipeline:
    """Shared bookkeeping that the message handlers read from and write to.

    Three plain dicts are exposed: ``state`` for transient bookkeeping,
    ``array`` for image-like results and ``table`` for tabular results.
    """

    def __init__(self, output_dir):
        """Create empty stores and keep ``output_dir`` as a ``Path``."""
        self.logger = logging.getLogger("Pipeline")
        self.output_dir = Path(output_dir)
        # three independent dicts, one per store
        self.state, self.array, self.table = {}, {}, {}

    def delayed(self, func, *args, **kwargs):
        """Return ``func`` wrapped by ``dask.delayed``.

        Extra positional/keyword arguments are forwarded to ``dask.delayed``
        itself, not to ``func``; the arguments for ``func`` are supplied when
        the returned object is called.
        """
        # TODO:
        # log exceptions
        # log warnings (deduplicated, count instances)
        # optionally retry with diag if func takes "diagnostics" argument
        # log benchmarking/profiling? or collect stats, only log outliers (+ call arguments)
        lazy_func = dask.delayed(func, *args, **kwargs)
        return lazy_func

def crop_trenches(img, trenches, max_crops=10):
    """Collect the crops that ``new.image.iter_crops`` yields for ``img``.

    Args:
        img: Image passed through to ``new.image.iter_crops``.
        trenches: Trench description passed through to ``new.image.iter_crops``.
        max_crops: Maximum number of crops to take from the iterator. The
            default of 10 is the limit that used to be hard-coded; pass
            ``None`` to take every crop.

    Returns:
        dict mapping each key yielded by ``iter_crops`` to its crop.
    """
    # iter_crops yields (key, crop) pairs, so the truncated iterator can be
    # fed directly to dict(); islice treats a stop of None as "no limit".
    crop_pairs = new.image.iter_crops(img, trenches)
    return dict(it.islice(crop_pairs, max_crops))


def segment_trenches(crops):
    """Apply ``trench_segmentation.segment`` to every crop, keeping the keys.

    Args:
        crops: Mapping from crop key to crop image.

    Returns:
        dict with the same keys, each mapped to the result of segmenting the crop.
    """
    return {key: trench_segmentation.segment(image) for key, image in crops.items()}


def measure_crops(label_image, intensity_image):
    """Measure the mean intensity of every labelled region.

    Args:
        label_image: Label image handed to ``skimage.measure.regionprops_table``.
        intensity_image: Image whose values are averaged within each label.

    Returns:
        ``pandas.DataFrame`` indexed by ``"label"`` with an ``"intensity_mean"``
        column.
    """
    region_table = skimage.measure.regionprops_table(
        label_image,
        # Intensity properties can only be computed when the intensity image is
        # supplied; previously the argument was accepted but never forwarded.
        intensity_image=intensity_image,
        properties=("label", "intensity_mean"),
    )
    return pd.DataFrame(region_table).set_index("label")


def measure_mask_crops(label_image):
    """Tabulate shape descriptors for every labelled region in ``label_image``.

    Returns:
        ``pandas.DataFrame`` indexed by ``"label"`` holding the area, major and
        minor axis lengths, orientation and centroid reported by
        ``skimage.measure.regionprops_table``.
    """
    shape_properties = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    region_table = skimage.measure.regionprops_table(
        label_image, properties=shape_properties
    )
    return pd.DataFrame(region_table).set_index("label")


# TODO: use a namedtuple (or typing.NamedTuple, or dataclass) for keys so that fields are named
def handle_image(pipeline, msg):
    """Register one incoming frame and schedule the processing it unlocks.

    The frame is stored under a ``("raw", fov, t, channel)`` key. Trench
    detection is scheduled the first time a frame of
    ``trench_detection_channel`` arrives for a FOV; frames that arrive earlier
    are queued in ``pipeline.state`` and cropped once trenches are available.
    Crops of ``segmentation_channel`` trigger segmentation and mask
    measurements, and crops of ``measure_channels`` are measured as soon as the
    mask for their FOV and timepoint exists.

    Segmentation and pending-measurement keys are derived from each queued
    frame's own timepoint (not from the timepoint of the current message), so
    frames queued at an earlier timepoint are paired with the right mask.

    Args:
        pipeline: Object exposing ``state``, ``array`` and ``table`` dicts and
            a ``delayed`` method (see ``Pipeline``).
        msg: Mapping with an ``"image"`` entry and a ``"metadata"`` mapping
            providing ``"fov_num"``, ``"t"`` and ``"channel"``.

    NOTE(review): ``measure_crops`` / ``measure_mask_crops`` receive the delayed
    outputs of ``segment_trenches`` / ``crop_trenches``, which are dicts keyed
    by crop -- confirm the measurement functions are meant to accept those.
    """
    image = msg["image"]
    metadata = msg["metadata"]
    fov_num = metadata["fov_num"]
    t = metadata["t"]
    channel = metadata["channel"]
    raw_key = ("raw", fov_num, t, channel)
    # store raw image (in production, we won't do this, we will only store crops as we do below)
    pipeline.array[raw_key] = image
    # TODO: we need a way to store per-frame metadata and write it to disk
    trenches_key = ("trenches", fov_num)
    trenches = pipeline.table.get(trenches_key)
    # check if we have done trench detection for this FOV
    if trenches is None and channel == trench_detection_channel:
        # if not, find trenches and save the resulting table
        trenches = pipeline.delayed(new.image.find_trench_bboxes)(
            image, peak_func=trench_detection.peaks.find_peaks
        )
        pipeline.table[trenches_key] = trenches
    # this list keeps track of all the raw frames that need to be cropped
    # frames for multiple channels will accumulate in this list until we get a frame for trench_detection_channel
    # if we have already processed such a frame, then keys_to_crop will contain only the current frame (raw_key)
    keys_to_crop = pipeline.state.setdefault(("keys_to_crop", fov_num), [])
    keys_to_crop.append(raw_key)
    # we only can do further processing if we have already detected trenches for this FOV
    if trenches is None:
        return
    for raw_to_crop in keys_to_crop:
        # a queued frame may come from an earlier timepoint than the current
        # message, so take timepoint and channel from the frame's own key
        _, _, frame_t, frame_channel = raw_to_crop
        crop_key = ("crops", *raw_to_crop[1:])
        # save trench crops for every frame in keys_to_crop
        pipeline.array[crop_key] = pipeline.delayed(crop_trenches)(
            pipeline.array[raw_to_crop], trenches
        )
        segmentation_key = ("segmentation", fov_num, frame_t, segmentation_channel)
        pending_key = ("keys_to_measure", fov_num, frame_t)
        segmentation = pipeline.array.get(segmentation_key)
        if segmentation is not None:
            # if we have segmentation masks for this frame, we can immediately measure only this frame
            keys_to_measure = [crop_key] if frame_channel in measure_channels else []
        else:
            # we don't have a segmentation mask yet, so we need to add to the keys_to_measure list
            keys_to_measure = pipeline.state.setdefault(pending_key, [])
            if frame_channel in measure_channels:
                # we want to measure this frame
                keys_to_measure.append(crop_key)
            if frame_channel == segmentation_channel:
                # if this frame's channel is the segmentation channel, run segmentation
                segmentation = pipeline.delayed(segment_trenches)(
                    pipeline.array[crop_key]
                )
                pipeline.array[segmentation_key] = segmentation
                # once we have the segmentation mask, get measurements for the mask
                mask_measurements_key = ("mask_measurements", *crop_key[1:])
                pipeline.table[mask_measurements_key] = pipeline.delayed(
                    measure_mask_crops
                )(segmentation)
        # if we now have the segmentation mask, try measuring all frames in the keys_to_measure list
        if segmentation is not None:
            for crop_to_measure in keys_to_measure:
                measurements_key = ("measurements", *crop_to_measure[1:])
                pipeline.table[measurements_key] = pipeline.delayed(measure_crops)(
                    segmentation, pipeline.array[crop_to_measure]
                )
            pipeline.state.pop(pending_key, None)
    # every queued frame has been cropped, so the queue for this FOV is done
    pipeline.state.pop(("keys_to_crop", fov_num), None)


def handle_fish_barcode(pipeline, msg):
    """Placeholder for FISH barcode frames; the message is currently ignored."""
    return None


# we should pick a name that's better/more intuitive than handle_message
def handle_message(pipeline, msg):
    match msg:
        case {"type": "image", **info}:
            match info:
                case {"image_type": "fish_barcode"}:
                    handle_fish_barcode(pipeline, msg)
                case other:
                    handle_image(pipeline, msg)
        case {"type": "nd2_metadata"}:
            print("got metadata") # TODO
        case {"type": "event", **info}:
            print("event", info)
        case {"type": "done"}:
            print("DONE")
        case _:
            # this exception should be caught, we don't want malformed messages to crash the pipeline
            raise ValueError("cannot handle message", msg)

In [None]:
#n = nd2reader.ND2Reader("/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2")
n = nd2reader.ND2Reader("/home/jqs1/scratch/jqs1/microscopy/220523/220523_library_test_smallfile.nd2")

In [None]:
n.metadata["channels"]

In [None]:
%%time
#filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
filename = "/home/jqs1/scratch/jqs1/microscopy/220523/220523_library_test_smallfile.nd2"
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")
for msg in new.readers.send_nd2(
    filename,
    #slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)

In [None]:
pipeline.table

In [None]:
g = pipeline.array[('crops', 0, 0, 'RFP-Penta')]

In [None]:
%%time
g.compute(scheduler="synchronous")

In [None]:
g.visualize()

In [None]:
a = client.compute(g)

In [None]:
a

In [None]:
client.restart()

In [None]:
client.gather(a)

In [None]:
client.cancel(a)

# Run

In [None]:
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")

In [None]:
%%time
for msg in send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
):
    handle_message(pipeline, msg)

In [None]:
%%time
for msg in send_eaton_fish("/home/jqs1/scratch/jqs1/microscopy/220718/FISH/real_run/"):
    handle_message(pipeline, msg)

In [None]:
handle_message(pipeline, {"type": "done"})

# Config

In [None]:
# Describe the SLURM jobs that will host dask workers. No workers are requested
# here; scaling is done separately via cluster.scale / cluster.adapt.
cluster = SLURMCluster(
    queue="short",  # SLURM queue/partition the worker jobs are submitted to
    walltime="06:00:00",  # time limit requested for each job
    memory="2GB",  # memory requested for each job
    local_directory="/tmp",  # worker scratch directory
    log_directory="/home/jqs1/log",  # where job logs are written
    cores=1,
    processes=1,  # one process with a single core per job
)
# Client bound to the cluster; other cells use it for client.submit / client.compute.
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(1)

In [None]:
cluster.adapt(maximum=20)

## New trench detection+segmentation+analysis

#### Config

In [None]:
def filename_func(
    extension=None, kind=None, name=None, filename=None, position=None, extra="full"
):
    """Build the output path for one artifact derived from ``filename``.

    The base is ``"{filename}.{kind}"``, where ``kind`` is first prefixed with
    ``"{extra}."`` when both are truthy. ``name`` and ``extension`` are appended
    as dot-prefixed parts when they are not ``None``. Without ``position`` the
    result is a single path component; with it, the base becomes a directory
    and the file inside is named ``"pos{position}"`` plus the same parts.

    Examples:
        ``filename_func(kind="trenches", extension="parquet", filename="f")``
        gives ``"f.full.trenches.parquet"``.
    """
    if kind and extra:
        kind = f"{extra}.{kind}"
    # every present part carries its own leading dot; absent parts add nothing
    suffix = "".join(f".{part}" for part in (name, extension) if part is not None)
    base = f"{filename}.{kind}"
    if position is None:
        return os.path.join(base + suffix)
    return os.path.join(base, f"pos{position:d}" + suffix)

#### Execute

In [None]:
# Futures are collected per file or per (file, position) so that results and
# errors can be inspected after submission.
all_analysis_futures = {}
save_trenches_futures = {}
save_trench_err_futures = {}

all_trench_bboxes_futures = {}  # TODO: just for debugging

for filename, filename_frames in selected_frames.groupby("filename"):
    # futures belonging to the current file only
    trench_bboxes_futures = {}
    trench_err_futures = {}
    for position, frames in filename_frames.groupby("position"):
        key = (filename, position)
        # trench detection runs on a single frame selected from the
        # segmentation channel
        frame_to_segment = frames.loc[IDX[:, :, [segmentation_channel], 0], :]
        trenches_future = client.submit(
            do_find_trenches, *frame_to_segment.index[0], priority=10
        )
        trench_err_futures[key] = client.submit(do_get_trench_err, trenches_future)
        trench_bboxes_future = client.submit(
            do_trenches_to_bboxes, trenches_future, key, priority=10
        )
        trench_bboxes_futures[key] = trench_bboxes_future
        all_trench_bboxes_futures[key] = trench_bboxes_future
        analysis_future = client.submit(
            do_measure_and_write,
            trench_bboxes_future,
            frames,
            measurement_func=_measurement_func,
            # measurement_func=None,
            # segmentation_func=None,
            measure_channels=measure_channels,
            segmentation_channel=segmentation_channel,
            return_none=True,
            write=True,
            filename_func=filename_func,
        )
        all_analysis_futures[key] = analysis_future
    # save trenches, ordered by (filename, position) key
    trenches_filename = filename_func(
        kind="trenches", extension="parquet", filename=filename
    )
    save_trenches_futures[filename] = client.submit(
        do_save_trenches,
        [future for _, future in sorted(trench_bboxes_futures.items())],
        trenches_filename,
    )
    # save the trench-detection errors for this file
    trench_errs_filename = filename_func(
        kind="trench_errs", extension="pickle", filename=filename
    )
    save_trench_err_futures[filename] = client.submit(
        do_serialize_to_disk,
        trench_err_futures,
        trench_errs_filename,
    )