# Imports

In [None]:
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import zarr
import dask
from dask import delayed
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import matplotlib.pyplot as plt
import holoviews as hv
from tqdm.auto import tqdm
from functools import partial
import itertools as it
from collections import namedtuple
import nd2reader
import re
import os
from pathlib import Path

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from paulssonlab.image_analysis import *
import paulssonlab.image_analysis.new as new

In [None]:
%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Loader

In [None]:
# Smoke-test the ND2 loader on a small slice (first 3 timepoints of v=14 and v=25),
# reading eagerly (delayed=False), and peek at a few of the messages it yields.
x = send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2",
    slices=dict(t=slice(None, 3), v=[14, 25]),
    delayed=False,
)
list(x)[1:10]

In [None]:
# Smoke-test the FISH loader; the regex pulls the FOV (v), config (c) and
# timepoint (t) fields out of each filename.
x = send_eaton_fish(
    "/home/jqs1/scratch/jqs1/microscopy/220718/FISH/real_run/",
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
)
list(x)[:20]

# Pipeline

In [None]:
import logging


class Pipeline:
    """Per-run state for the streaming image-analysis pipeline (work in progress).

    Parameters
    ----------
    output_dir
        Directory for pipeline outputs; stored as a :class:`pathlib.Path`.
    """

    def __init__(self, output_dir):
        # Instances have no ``__name__``; name the logger after the class instead.
        self.logger = logging.getLogger(type(self).__name__)
        # ``basicConfig`` is a module-level function, not a Logger method. It is a
        # no-op when the root logger already has handlers (e.g. in some notebook
        # kernels), so also lower this logger's own level so DEBUG records are kept.
        logging.basicConfig(level=logging.DEBUG)
        self.logger.setLevel(logging.DEBUG)
        self.output_dir = Path(output_dir)
        # Free-form bookkeeping shared by the message handlers
        # (handle_image keeps its per-FOV "need_cropping" lists here).
        self.state = {}

    def delayed(self, func, *args, **kwargs):
        """Schedule ``func(*args, **kwargs)``; not implemented yet, returns None.

        TODO:
        - log exceptions
        - log warnings (deduplicated, count instances)
        - optionally retry with diag if func takes "diagnostics" argument
        - log benchmarking/profiling? or collect stats, only log outliers
          (+ call arguments)
        """
        return None

# Functions

In [None]:
pixelwise_funcs = {"mean": np.mean, "sum": np.sum}
# Trench-wise (whole-crop) measurements are currently disabled; to re-enable,
# define e.g. trenchwise_funcs = {"sharpness": image.sharpness, **pixelwise_funcs}


def _measurement_func(label_image, intensity_image):
    """Measure one trench crop.

    Parameters
    ----------
    label_image
        Label mask for the crop, or None when no segmentation is available.
    intensity_image
        Intensity crop for one channel, or None to measure only the mask itself.

    Returns
    -------
    None
        When both inputs are None (nothing to measure).
    dict
        ``{"mask_labelwise": df}`` when only ``label_image`` is given: regionprops
        shape features, indexed by label.
        ``{}`` when only ``intensity_image`` is given (trench-wise measurements
        are disabled).
        ``{"labelwise": df}`` when both are given: the ``pixelwise_funcs``
        statistics, presumably one row per label, from
        ``workflow.map_frame_over_labels``.
    """
    if intensity_image is not None:
        if label_image is None:
            # Only trench-wise measurements would apply, and those are disabled.
            return {}
        per_label = workflow.map_frame_over_labels(
            pixelwise_funcs, label_image, intensity_image
        )
        return {"labelwise": per_label}
    if label_image is None:
        # Neither a mask nor an image: nothing to measure.
        return None
    shape_props = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    table = skimage.measure.regionprops_table(label_image, properties=shape_props)
    return {"mask_labelwise": pd.DataFrame(table).set_index("label")}

In [None]:
def _measure(
    trenches,
    frames,
    measurement_func,
    segmentation_channel=segmentation_channel,
    measure_channels=None,
    segmentation_func=trench_segmentation.watershed.segment,
    include_frame=True,
    frame_bits=8,
    frame_downsample=4,
    filename=None,
    position=None,
):
    """Crop trenches out of ``frames``, optionally segment them, and measure the crops.

    NOTE(review): the defaults ``segmentation_channel=segmentation_channel`` and
    ``segmentation_func=trench_segmentation.watershed.segment`` are evaluated when
    this cell runs, so the module-level ``segmentation_channel`` (assigned in the
    "Handler" section further down the notebook) must already exist by then.

    Parameters
    ----------
    trenches, frames, filename, position
        Forwarded to ``processing._get_trench_crops``.
    measurement_func
        Called as ``measurement_func(mask, None)`` once per newly computed
        segmentation mask (stored under the pseudo-channel ``"mask"``), and as
        ``measurement_func(mask, crop)`` for every measured crop. Presumably it
        returns a dict of measurement name -> DataFrame/Series, like
        ``_measurement_func``. ``None`` disables all measurements.
    segmentation_channel
        Channel whose crops are segmented; one mask per (trench_set, trench, t)
        is reused for every measured channel.
    measure_channels
        If not None, only crops from these channels are measured.
    segmentation_func
        Maps a crop to a label mask; ``None`` disables segmentation (the mask
        passed to ``measurement_func`` is then None).
    include_frame, frame_bits, frame_downsample
        Forwarded to ``processing._get_trench_crops`` as ``include_frame`` and a
        ``frame_transformation`` built from these settings; presumably controls
        the ``"_frame"`` entry of the crops, which is skipped here.

    Returns
    -------
    dict
        ``{"images": {"raw": trench_crops[, "segmentation": masks]}[,
        "measurements": {name: DataFrame}]}``. ``"segmentation"`` is present only
        when ``segmentation_func`` is not None; ``"measurements"`` only when
        ``measurement_func`` is not None.
    """
    # Transformation for the full-frame copy: zarrify, quantize to frame_bits,
    # and downsample by frame_downsample.
    frame_transformation = compose(
        processing.zarrify,
        partial(image.quantize, bits=frame_bits),
        partial(image.downsample, factor=frame_downsample),
    )
    # Nested as trench_crops[trench_set][trench_idx][channel][t] -> crop.
    trench_crops = processing._get_trench_crops(
        trenches,
        frames,
        include_frame=include_frame,
        frame_transformation=frame_transformation,
        filename=filename,
        position=position,
    )
    res = {}
    segmentation_masks = {}
    measurements = {}
    # segment
    for trench_set, crops_trench_channel_t in trench_crops.items():
        # The "_frame" entry is neither segmented nor measured.
        if trench_set == "_frame":
            continue
        for trench_idx, crops_channel_t in crops_trench_channel_t.items():
            for channel, crops_t in crops_channel_t.items():
                for t, crop in crops_t.items():
                    if measure_channels is not None and channel not in measure_channels:
                        continue
                    # Keyed on segmentation_channel (not channel), so each
                    # (trench_set, trench, t) is segmented once and the mask is
                    # reused for every measured channel.
                    segmentation_key = (trench_set, trench_idx, segmentation_channel, t)
                    segmentation_mask = segmentation_masks.get(segmentation_key, None)
                    if segmentation_mask is None and segmentation_func is not None:
                        segmentation_mask = segmentation_func(
                            trench_crops[trench_set][trench_idx][segmentation_channel][
                                t
                            ]
                        )
                        segmentation_masks[segmentation_key] = segmentation_mask
                        # measure mask
                        if measurement_func is not None:
                            measurements[
                                ("mask", (trench_set, trench_idx, t))
                            ] = measurement_func(segmentation_mask, None)
                    # measure
                    if measurement_func is not None:
                        measurements[
                            (channel, (trench_set, trench_idx, t))
                        ] = measurement_func(segmentation_mask, crop)
    if measurement_func is not None:
        # measurements maps (channel, (trench_set, trench_idx, t)) -> {name: result};
        # presumably map_dict_levels swaps the outer key with the measurement name so
        # results are grouped by name first, then by channel.
        measurement_dfs = util.map_dict_levels(
            lambda k: (k[1], k[0], *k[2:]), measurements
        )
        for name, dfs in measurement_dfs.items():
            dfs = util.unflatten_dict(dfs)
            # Series results (one per crop) become rows; DataFrame results are
            # stacked along the index. Columns are grouped by channel either way.
            if isinstance(util.get_one(dfs, level=2), pd.Series):
                df = pd.concat(
                    {
                        channel: pd.concat(channel_dfs, axis=1).T
                        for channel, channel_dfs in dfs.items()
                    },
                    axis=1,
                )
            else:
                df = pd.concat(
                    {
                        channel: pd.concat(channel_dfs, axis=0)
                        for channel, channel_dfs in dfs.items()
                    },
                    axis=1,
                )
            # The leading index levels come from the (trench_set, trench_idx, t) keys.
            df.index.names = ["trench_set", "trench", "t", *df.index.names[3:]]
            measurement_dfs[name] = df
        res["measurements"] = measurement_dfs
    images = dict(raw=trench_crops)
    if segmentation_func is not None:
        images["segmentation"] = util.unflatten_dict(segmentation_masks)
    res["images"] = images
    return res

# Test

In [None]:
# Open the ND2 file used by the exploratory cells below.
nd2 = nd2reader.ND2Reader("/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2")

In [None]:
# One 2D frame: FOV v=100, channel index c=1, first timepoint.
img = nd2.get_frame_2D(v=100, c=1, t=0)

In [None]:
%%time
# Time the new bbox-based trench detector on the full frame.
trenches = new.image.find_trench_bboxes(img, peak_func=trench_detection.peaks.find_peaks)

In [None]:
%%time
# Same detector on a 1000x1000 corner of the frame.
trenches3 = new.image.find_trench_bboxes(img[:1000,:1000], peak_func=trench_detection.peaks.find_peaks)

In [None]:
trenches3["widths"].median()

In [None]:
trenches3["widths"].plot.hist(bins=100)

In [None]:
%%time
# Old detector, collecting diagnostics into diag2.
diag2 = util.tree()
trenches2 = trench_detection.find_trenches(img, peak_func=trench_detection.peaks.find_peaks, diagnostics=diag2)

In [None]:
%%time
# New detector with diagnostics; overwrites diag2/trenches2 from the previous cell.
diag2 = util.tree()
trenches2 = new.image.find_trench_bboxes(img, peak_func=trench_detection.peaks.find_peaks, diagnostics=diag2)

In [None]:
# NOTE(review): diag2's keys depend on which detector cell ran last; re-run the
# matching cell before browsing a specific key.
ui.show_plot_browser(diag2["find_trenches"]["label_4"]);

In [None]:
diag2["find_trenches"]

In [None]:
diag2["bboxes"].opts(frame_width=700,frame_height=400)

In [None]:
trenches2

In [None]:
diag2.keys()

In [None]:
# Scratch check of hv.Rectangles styling (unfilled, red outline).
hv.Rectangles((0,0,1,1)).opts(fill_color=None, line_color="red", line_width=1)

In [None]:
# Per-trench endpoints as (n_trenches, 2) arrays of (x, y).
top_endpoints = np.vstack((trenches2["top_x"].values, trenches2["top_y"].values)).T
bottom_endpoints = np.vstack((trenches2["bottom_x"].values, trenches2["bottom_y"].values)).T
top_endpoints.shape

In [None]:
# (n_trenches, 4): [top_x, top_y, bottom_x, bottom_y] per row.
np.hstack((top_endpoints, bottom_endpoints)).shape

In [None]:
trenches2

In [None]:
diag2

In [None]:
ui.show_plot_browser(diag2["label_1"]);

In [None]:
ui.show_plot_browser(diag2["labeling"]["find_trench_lines"]);

In [None]:
trenches2

In [None]:
# Presumably the trimmed intensity profile along one Hough-detected line, used
# below to inspect trench periodicity.
data = diag2["labeling"]["find_trench_lines"]["hough_2"]["trimmed_profile"].data

In [None]:
# NOTE(review): scipy is never imported explicitly in this notebook; this relies on
# it arriving via `from paulssonlab.image_analysis import *`. Add
# `import scipy.signal` if it does not.
freqs, spectrum = scipy.signal.periodogram(
    data.y.values, window="hann", nfft=2**14, scaling="spectrum"
)

In [None]:
hv.Curve((freqs,spectrum))

In [None]:
f, t, Sxx = scipy.signal.spectrogram(data.y.values, nfft=2**12, window="hann", scaling="spectrum")

In [None]:
# Complex-mode spectrogram: overwrites f, t, Sxx from the previous cell, and Sxx
# is now complex, so take np.abs/np.real before plotting.
f, t, Sxx = scipy.signal.spectrogram(data.y.values, nfft=2**12, window="hann", scaling="spectrum", mode="complex")

In [None]:
np.abs(Sxx)

In [None]:
hv.QuadMesh((t, f, np.abs(Sxx)))

In [None]:
hv.QuadMesh((t, f, np.real(Sxx)))

In [None]:
# Sxx is complex if the mode="complex" spectrogram cell ran last, and pcolormesh
# needs real values; np.abs leaves the nonnegative default-mode spectrum unchanged.
plt.pcolormesh(t, f, np.abs(Sxx), shading='gouraud')

In [None]:
ui.show_plot_browser(diag2["labeling"]);

In [None]:
diag2#["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
ui.show_plot_browser(diag2["label_10"]["find_trench_ends"]["image_with_trenches"]);

In [None]:
%%time
# Old detector with its own diagnostics tree; note this rebinds `trenches`.
diag = util.tree()
trenches = trench_detection.find_trenches(img, diagnostics=diag)

In [None]:
ui.show_plot_browser(diag);

In [None]:
%%time
# Collect the first 10 trench crops, keyed by the index yielded by iter_crops.
# NOTE(review): `trenches` is whichever detector result was assigned last (the
# find_trenches call just above, or find_trench_bboxes earlier); confirm which
# format iter_crops expects.
crops = {}
for i, crop in it.islice(new.image.iter_crops(img, trenches), 10):
    crops[i] = crop
    #mask = trench_segmentation.segment(crop)

In [None]:
# Crops are transposed for display.
plt.imshow(crops[0].T)

In [None]:
plt.imshow(crops[1].T)

In [None]:
plt.imshow(crops[2].T)

In [None]:
plt.imshow(crops[3].T)

In [None]:
plt.imshow(crops[4].T)

In [None]:
# NOTE(review): `mask` is never assigned (the segmentation line in the crop loop
# above is commented out), so this cell raises NameError as written.
plt.imshow(mask)

# Handler

In [None]:
# Channel names shared by the handler and analysis cells.
segmentation_channel = "RFP-Penta"
# Channel used for trench detection; almost always the same as segmentation_channel.
trench_channel = segmentation_channel
# Channels whose intensities are measured per trench/label.
measure_channels = [
    "RFP-Penta",
    "GFP-PENTA",
]
fish_channels = [
    "RFP-Penta",
    "Cy5-PENTA",
    "Cy7",
]

In [None]:
def foo():
    """Placeholder with no behavior; always returns None."""
    return None

def handle_image(pipeline, msg):
    """Store one incoming image and queue it for trench cropping (work in progress).

    The image is written to ``pipeline.array[("raw", fov_num, channel)][t]``.
    If no trenches are stored for the FOV yet and the image comes from
    ``trench_channel``, trench detection is requested via ``pipeline.delayed``
    and its result is stored in ``pipeline.table``. Every raw key is appended to
    ``pipeline.state[("need_cropping", fov_num, channel)]`` so it can be cropped
    once trenches for that FOV are available.

    Parameters
    ----------
    pipeline
        Object exposing ``array``, ``table``, ``state`` and ``delayed``.
        NOTE(review): ``Pipeline`` does not define ``array`` or ``table`` yet.
    msg
        Image message with an ``"image"`` entry and a ``"metadata"`` mapping that
        contains at least ``"channel"``, ``"fov_num"`` and ``"t"``.
        NOTE(review): the ``"metadata"`` location is assumed from the test
        message used in this notebook; confirm against the loaders.
    """
    image = msg["image"]
    # The original dict-destructuring assignment was a SyntaxError and read an
    # undefined `metadata`; pull the fields out explicitly instead.
    metadata = msg["metadata"]
    channel = metadata["channel"]
    fov_num = metadata["fov_num"]
    t = metadata["t"]
    raw_key = ("raw", fov_num, channel)
    pipeline.array[raw_key][t] = image
    # do we have trenches?
    trenches_key = ("trenches", fov_num)
    trenches = pipeline.table.get(trenches_key)
    # TODO: use a namedtuple (or typing.NamedTuple, or dataclass) as the key so that fields are named
    need_cropping_key = ("need_cropping", fov_num, channel)
    # list.append returns None, so keep a reference to the list itself.
    keys_to_crop = pipeline.state.setdefault(need_cropping_key, [])
    keys_to_crop.append(raw_key)
    if trenches is None and channel == trench_channel:
        # Trench detection needs the image it should run on.
        trenches = pipeline.delayed(new.image.find_trench_bboxes, image)
        pipeline.table[trenches_key] = trenches
    if trenches is not None:
        for key in keys_to_crop:
            # TODO: crop the image stored under `key`, then segment and measure it.
            # NOTE(review): `include_frame` and `frame_transformation` are not
            # defined in this scope (and `frames` only exists as a leftover global
            # from the analysis loop), so this call raises NameError until the
            # cropping step is wired up to `pipeline.array[key]`.
            trench_crops = processing._get_trench_crops(
                trenches,
                frames,
                include_frame=include_frame,
                frame_transformation=frame_transformation,
            )
            # add to keys_to_measure
            # segment
            # measure


def handle_fish_barcode(pipeline, msg):
    """Handle a FISH barcode image message.

    Not implemented yet: this accepts the pipeline and message and does nothing
    (returns None).
    """
    return None


# we should pick a name that's better/more intuitive than handle_message
def handle_message(pipeline, msg):
    match msg:
        case {"type": "image", **_}:
            match info:
                case {"image_type": "fish_barcode", **_}:
                    handle_fish_barcode(pipeline, msg)
                case other:
                    handle_image(pipeline, msg)
        case {"type": "event", **info}:
            print("event", info)
        case {"type": "done"}:
            print("DONE")
        case _:
            # this exception should be caught, we don't want malformed messages to crash the pipeline
            raise ValueError("cannot handle message", msg)

In [None]:
# Malformed message ("img" is not a known type): handle_message should raise
# ValueError. The pipeline comes first; none is needed for this check.
handle_message(None, {"type": "img", "imgs": 0, "metadata": 1})

# Test

In [None]:
%%time
# Time the ND2 loader alone on the first 10 FOVs by draining its messages.
# `arch` is not defined in this notebook; every other cell calls send_nd2 directly.
for msg in send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2",
    slices=dict(v=slice(10)),
):
    pass

# Run

In [None]:
# Pipeline instance whose outputs go under new_architecture/test1.
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")

In [None]:
%%time
# Feed every message from the full ND2 file through the handler.
for msg in send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
):
    handle_message(pipeline, msg)

In [None]:
%%time
# NOTE(review): the loader smoke-test cell passes a filename regex as the second
# argument to send_eaton_fish; confirm it is optional before relying on this call.
for msg in send_eaton_fish("/home/jqs1/scratch/jqs1/microscopy/220718/FISH/real_run/"):
    handle_message(pipeline, msg)

In [None]:
# Signal end of input.
handle_message(pipeline, {"type": "done"})

# Loading data

In [None]:
# Earlier datasets are kept commented out for reference; only the last assignment is used.
# nd2_filenames = ["/home/jqs1/scratch/jqs1/microscopy/211117/211117_long_oscillator.nd2"]
# nd2_filenames = ["/n/standby/hms/sysbio/paulsson/collaborations/Personal_Folders/!!Jacob Quinn Shenker/Standby/180928/CapturedRFP_giant snake.nd2"]
nd2_filenames = ["/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"]

In [None]:
# Frame index over all input files, plus per-file inclusive image bounds.
all_frames, metadata = workflow.get_nd2_frame_list(nd2_filenames)
image_limits = workflow.get_filename_image_limits(metadata)

`all_frames` lists each exposure (keyed by filename/position/channel/timepoint). `image_limits` is a dict giving *inclusive* image bounds `((x_min, x_max), (y_min, y_max))` for each input image filename. Both outputs are keyed by filename, and `workflow.get_nd2_frame_list` takes a list of filenames rather than a single file. This supports the case where image acquisition is stopped and restarted one or more times, so a single experiment is split across several files.

In [None]:
image_limits

In [None]:
# Inspect the frame index built above.
all_frames

# Config

In [None]:
# Dask cluster on SLURM: each job runs one single-core worker process with 20 GB
# of memory on the "short" queue with a 6 h walltime. Worker scratch space is /tmp
# and job logs go to /home/jqs1/log.
# NOTE(review): no cluster.scale(...)/adapt(...) call is visible in this notebook;
# no SLURM jobs are submitted until the cluster is scaled.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="20GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
# The analysis cells submit their work through this client.
client = Client(cluster)

## New trench detection+segmentation+analysis

#### Config

In [None]:
def filter_trenches(trenches):
    """Return ``trenches`` unchanged; trench filtering is currently disabled.

    The intended filter is sketched in the comments below. It keeps only
    positions whose detected pitch (diagnostic column
    ``("diag", "find_trench_lines.hough_2.peak_func.pitch")``) is within 1 pixel
    of a hard-coded 32 px pitch, and whose ``("upper_left", "x")`` is not null.
    Using the median pitch across all positions would be a better reference
    than the hard-coded value.
    """
    filtered = trenches  # identity: nothing is filtered out yet
    return filtered
    # Disabled implementation, kept for reference:
    # pitch = 32  # pixels
    # if trenches is None:
    #     return None
    # pitch_ok = (
    #     trenches[("diag", "find_trench_lines.hough_2.peak_func.pitch")] - pitch
    # ).abs() <= 1
    # has_corner = ~trenches[("upper_left", "x")].isnull()
    # return trenches[pitch_ok & has_corner]
    # TODO: filter based on minimum trench length

In [None]:
def filename_func(
    extension=None, kind=None, name=None, filename=None, position=None, extra="full"
):
    """Build the output path for an analysis artifact derived from ``filename``.

    The stem is ``f"{filename}.{kind}"``. When both ``kind`` and ``extra`` are
    truthy, ``kind`` is first prefixed with ``extra`` (e.g. ``"full.trenches"``).
    ``name`` and ``extension`` are appended as ``.``-separated suffixes, skipping
    any that are None. When ``position`` is given, the stem becomes a directory
    and the file inside it is ``f"pos{position}"`` plus the suffixes.

    NOTE(review): with ``kind=None`` the stem contains a literal ``".None"``.

    >>> filename_func(kind="trenches", extension="parquet", filename="a.nd2")
    'a.nd2.full.trenches.parquet'
    """
    if kind and extra:
        kind = f"{extra}.{kind}"
    stem = f"{filename}.{kind}"
    suffix = "".join(f".{part}" for part in (name, extension) if part is not None)
    if position is None:
        return stem + suffix
    return os.path.join(stem, f"pos{position:d}{suffix}")

#### Execute

In [None]:
# Build and submit the Dask task graph. For each (filename, position): detect
# trenches on the t=0 segmentation-channel frame, convert them to bounding boxes,
# and run do_measure_and_write on all frames at that position (write=True).
# For each filename, also save the trench tables and the trench-detection errors.
save_trench_err_futures = {}
all_analysis_futures = {}
save_trenches_futures = {}

all_trench_bboxes_futures = {}  # TODO: just for debugging

# NOTE(review): `selected_frames` is not defined in this notebook (only `all_frames`
# is loaded under "Loading data"); presumably it is a subset of `all_frames`
# chosen elsewhere. Define it before running this cell.
for filename, filename_frames in selected_frames.groupby("filename"):
    trench_bboxes_futures = {}
    trench_err_futures = {}
    for position, frames in filename_frames.groupby("position"):
        key = (filename, position)
        # Index of the first t=0 frame of the segmentation channel at this
        # position, unpacked as positional arguments to do_find_trenches.
        frame_to_segment = frames.loc[IDX[:, :, [segmentation_channel], 0], :]
        # priority=10 lets trench detection be scheduled ahead of
        # default-priority tasks that are ready at the same time.
        trenches_future = client.submit(
            do_find_trenches, *frame_to_segment.index[0], priority=10
        )
        trench_err_futures[key] = client.submit(do_get_trench_err, trenches_future)
        trench_bboxes_future = client.submit(
            do_trenches_to_bboxes, trenches_future, key, priority=10
        )
        trench_bboxes_futures[key] = trench_bboxes_future
        all_trench_bboxes_futures[key] = trench_bboxes_future
        analysis_future = client.submit(
            do_measure_and_write,
            trench_bboxes_future,
            frames,
            measurement_func=_measurement_func,
            # measurement_func=None,
            # segmentation_func=None,
            measure_channels=measure_channels,
            segmentation_channel=segmentation_channel,
            return_none=True,
            write=True,
            filename_func=filename_func,
        )
        all_analysis_futures[key] = analysis_future
    # Save this file's trench bounding boxes, ordered by position.
    trenches_filename = filename_func(
        kind="trenches", extension="parquet", filename=filename
    )
    save_trenches_futures[filename] = client.submit(
        do_save_trenches,
        [future for _, future in sorted(trench_bboxes_futures.items())],
        trenches_filename,
    )
    trench_errs_filename = filename_func(
        kind="trench_errs", extension="pickle", filename=filename
    )
    save_trench_err_futures[filename] = client.submit(
        do_serialize_to_disk,
        trench_err_futures,
        trench_errs_filename,
    )