# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.image_analysis.new as new
from paulssonlab.image_analysis import *

In [None]:
# %load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Dask cluster whose workers are submitted as SLURM jobs (dask_jobqueue).
# Each job is configured with one core, one worker process and 2GB of memory,
# on the "short" queue with a 6-hour walltime.
# NOTE(review): log_directory is a user-specific absolute path — adjust it
# when running under a different account.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
# Client bound to the cluster above; later cells submit graphs with
# client.compute(..., sync=True).
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(1)

In [None]:
cluster.adapt(maximum=20)

# Handler

In [None]:
# Channel configuration read as module-level globals by handle_image.
# Frames of segmentation_channel are the ones passed to segment_trenches.
segmentation_channel = "RFP-Penta"
trench_detection_channel = segmentation_channel  # channel for trench detection, almost always same as segmentation_channel
# Channels for which per-label intensity measurements are computed.
measure_channels = ["RFP-Penta", "YFP-DUAL"]
# NOTE(review): not referenced by any handler visible in this notebook;
# presumably intended for handle_fish_barcode — confirm.
fish_channels = ["RFP-Penta", "Cy5-PENTA", "Cy7"]

In [None]:
import logging


class Pipeline:
    """Container for the state shared by the message handlers.

    Attributes are plain dicts keyed by tuples chosen by the handlers:
    ``state`` holds bookkeeping (e.g. pending work lists), ``array`` holds
    images or delayed image computations, and ``table`` holds delayed
    tabular results.
    """

    def __init__(self, output_dir):
        """Create an empty pipeline.

        Args:
            output_dir: directory associated with this pipeline; stored as a
                ``pathlib.Path``. Nothing is written to it here.
        """
        self.logger = logging.getLogger("Pipeline")
        self.output_dir = Path(output_dir)
        # Three independent, initially empty stores.
        self.state = dict()
        self.array = dict()
        self.table = dict()

    def delayed(self, func, *args, **kwargs):
        """Wrap ``func`` with ``dask.delayed``, forwarding any extra arguments.

        This is the single choke point through which handlers create lazy
        tasks, so cross-cutting behavior can be added here later.
        """
        # TODO:
        # log exceptions
        # log warnings (deduplicated, count instances)
        # optionally retry with diag if func takes "diagnostics" argument
        # log benchmarking/profiling? or collect stats, only log outliers (+ call arguments)
        make_lazy = dask.delayed
        return make_lazy(func, *args, **kwargs)


def crop_trenches(img, trenches, max_trenches=3):
    """Cut per-trench crops out of a frame.

    Args:
        img: frame to crop; forwarded to ``new.image.iter_crops``.
        trenches: trench description forwarded to ``new.image.iter_crops``
            (in this notebook, the result of trench detection).
        max_trenches: maximum number of crops to take from the iterator, or
            ``None`` to take all of them. The default of 3 keeps the
            original test-time behavior.

    Returns:
        dict mapping each key yielded by ``iter_crops`` to its crop.
    """
    # TODO: the default cap is just for testing (we only deal with three
    # trenches per FOV), otherwise every dask task takes a long time
    limited = it.islice(new.image.iter_crops(img, trenches), max_trenches)
    return {i: crop for i, crop in limited}


def segment_trenches(crops):
    """Segment every crop in ``crops``.

    Args:
        crops: dict mapping a trench key to its image crop.

    Returns:
        dict with the same keys, each mapped to the result of
        ``trench_segmentation.segment`` for that crop.
    """
    return {
        trench_id: trench_segmentation.segment(trench_crop)
        for trench_id, trench_crop in crops.items()
    }


# TODO: this is really boilerplatey, also we want finer task granularity than doing a whole FOV at once
def measure_crops(label_images, intensity_images):
    """Measure each crop that has both a label image and an intensity image.

    Args:
        label_images: dict mapping trench key to label image.
        intensity_images: dict mapping trench key to intensity image.

    Returns:
        dict mapping every key present in *both* inputs to the table
        returned by ``measure_crop``; keys present in only one input are
        skipped.
    """
    shared_keys = label_images.keys() & intensity_images.keys()
    tables = {}
    for key in shared_keys:
        tables[key] = measure_crop(label_images[key], intensity_images[key])
    return tables


def measure_crop(label_image, intensity_image):
    """Compute the mean intensity of every labeled region in one crop.

    Args:
        label_image: label image passed to ``regionprops_table``.
        intensity_image: intensity image passed alongside the labels.

    Returns:
        DataFrame indexed by ``label`` holding the ``intensity_mean``
        property reported by ``skimage.measure.regionprops_table``.
    """
    wanted = ("label", "intensity_mean")
    props = skimage.measure.regionprops_table(
        label_image, intensity_image, properties=wanted
    )
    table = pd.DataFrame(props)
    return table.set_index("label")


def measure_mask_crops(label_images):
    """Apply ``measure_mask_crop`` to every label image.

    Args:
        label_images: dict mapping trench key to label image.

    Returns:
        dict with the same keys, each mapped to its measurement table.
    """
    tables = {}
    for key, label_image in label_images.items():
        tables[key] = measure_mask_crop(label_image)
    return tables


def measure_mask_crop(label_image):
    """Compute shape properties of every labeled region in one mask.

    Args:
        label_image: label image passed to ``regionprops_table``.

    Returns:
        DataFrame indexed by ``label`` with the area, major/minor axis
        length, orientation and centroid properties reported by
        ``skimage.measure.regionprops_table``.
    """
    shape_properties = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    props = skimage.measure.regionprops_table(
        label_image, properties=shape_properties
    )
    table = pd.DataFrame(props)
    return table.set_index("label")


# TODO: use a namedtuple (or typing.NamedTuple, or dataclass) for keys so that fields are named
def handle_image(pipeline, msg):
    """Register one image frame and schedule every task it unblocks.

    The frame is stored under a ``("raw", fov, t, channel)`` key. Once
    trenches are known for the frame's FOV (they are detected from the first
    frame of ``trench_detection_channel``), every frame queued for that FOV
    is cropped, frames of ``segmentation_channel`` are segmented, and frames
    of ``measure_channels`` are measured against the segmentation mask of
    *their own* timepoint. All scheduled work is lazy (``pipeline.delayed``).

    Args:
        pipeline: object exposing ``state``, ``array`` and ``table`` dicts
            and a ``delayed`` method (see ``Pipeline``).
        msg: mapping with an ``"image"`` entry and a ``"metadata"`` mapping
            containing ``"fov_num"``, ``"t"`` and ``"channel"``.

    Returns:
        None. Results are recorded in ``pipeline.array``, ``pipeline.table``
        and ``pipeline.state``.
    """
    image = msg["image"]
    metadata = msg["metadata"]
    fov_num = metadata["fov_num"]
    t = metadata["t"]
    channel = metadata["channel"]
    raw_key = ("raw", fov_num, t, channel)
    # store raw image (in production, we won't do this, we will only store crops as we do below)
    pipeline.array[raw_key] = image
    # TODO: we need a way to store per-frame metadata and write it to disk
    trenches_key = ("trenches", fov_num)
    trenches = pipeline.table.get(trenches_key)
    # check if we have done trench detection for this FOV
    if trenches is None and channel == trench_detection_channel:
        # if not, find trenches and save the resulting table
        trenches = pipeline.delayed(new.image.find_trench_bboxes)(
            image, peak_func=trench_detection.peaks.find_peaks
        )
        pipeline.table[trenches_key] = trenches
    # this list keeps track of all the raw frames that need to be cropped
    # frames for multiple channels will accumulate in this list until we get a frame for trench_detection_channel
    # if we have already processed such a frame, then keys_to_crop will contain only the current frame (raw_key)
    keys_to_crop = pipeline.state.setdefault(("keys_to_crop", fov_num), [])
    keys_to_crop.append(raw_key)
    # we only can do further processing if we have already detected trenches for this FOV
    if trenches is None:
        return
    for raw_to_crop in keys_to_crop:
        # Queued frames may come from earlier timepoints than the current
        # message, so every key below is derived from the queued frame
        # itself rather than from the current message's metadata.
        _, crop_fov, crop_t, crop_channel = raw_to_crop
        crop_key = ("crops", crop_fov, crop_t, crop_channel)
        # save trench crops for every frame in keys_to_crop
        pipeline.array[crop_key] = pipeline.delayed(crop_trenches)(
            pipeline.array[raw_to_crop], trenches
        )
        segmentation_key = ("segmentation", crop_fov, crop_t, segmentation_channel)
        pending_key = ("keys_to_measure", crop_fov, crop_t)
        segmentation = pipeline.array.get(segmentation_key)
        if segmentation is not None:
            # we already have segmentation masks for this frame's timepoint,
            # so we can immediately measure only this frame (if requested)
            keys_to_measure = [crop_key] if crop_channel in measure_channels else []
        else:
            # we don't have a segmentation mask yet, so we need to add to the keys_to_measure list
            keys_to_measure = pipeline.state.setdefault(pending_key, [])
            if crop_channel in measure_channels:
                # we want to measure this frame
                keys_to_measure.append(crop_key)
            if crop_channel == segmentation_channel:
                # if this frame's channel is the segmentation channel, run segmentation
                segmentation = pipeline.delayed(segment_trenches)(
                    pipeline.array[crop_key]
                )
                pipeline.array[segmentation_key] = segmentation
                # once we have the segmentation mask, get measurements for the mask
                mask_measurements_key = (
                    "mask_measurements",
                    crop_fov,
                    crop_t,
                    crop_channel,
                )
                pipeline.table[mask_measurements_key] = pipeline.delayed(
                    measure_mask_crops
                )(segmentation)
        # if we now have the segmentation mask, try measuring all frames in the keys_to_measure list
        if segmentation is not None:
            for crop_to_measure in keys_to_measure:
                measurements_key = ("measurements", *crop_to_measure[1:])
                pipeline.table[measurements_key] = pipeline.delayed(measure_crops)(
                    segmentation, pipeline.array[crop_to_measure]
                )
            # everything pending for this (fov, t) has been scheduled
            pipeline.state.pop(pending_key, None)
    # every queued frame for this FOV has been cropped
    pipeline.state.pop(("keys_to_crop", fov_num), None)


def handle_fish_barcode(pipeline, msg):
    """Placeholder handler for FISH barcode image messages (currently a no-op).

    Args:
        pipeline: the pipeline the message belongs to (unused for now).
        msg: the image message (unused for now).
    """
    return None  # TODO


# we should pick a name that's better/more intuitive than handle_message
def handle_message(pipeline, msg):
    match msg:
        case {"type": "image", **info}:
            match info:
                case {"image_type": "fish_barcode"}:
                    handle_fish_barcode(pipeline, msg)
                case other:
                    handle_image(pipeline, msg)
        case {"type": "nd2_metadata"}:
            print("got metadata")  # TODO
        case {"type": "event", **info}:
            print("event", info)
        case {"type": "done"}:
            print("DONE")
        case _:
            # this exception should be caught, we don't want malformed messages to crash the pipeline
            raise ValueError("cannot handle message", msg)

In [None]:
%%time
# Smoke test: stream messages read from an ND2 file through handle_message.
# slice(1) selects only index 0 along "v" and "t"
# (presumably field of view and timepoint — confirm against new.readers.send_nd2).
# This only builds lazy task graphs in `pipeline`; the cells further down
# compute them explicitly with client.compute.
# filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
filename = "/home/jqs1/scratch/jqs1/microscopy/220523/220523_library_test_smallfile.nd2"
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")
for msg in new.readers.send_nd2(
    filename,
    slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)

In [None]:
pipeline.table

In [None]:
%%time
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")
for msg in new.readers.send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2",
    slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)

In [None]:
crops_graph = pipeline.array[("crops", 0, 0, "RFP-Penta")]
masks_graph = pipeline.array[("segmentation", 0, 0, "RFP-Penta")]

In [None]:
crops_graph.visualize()

In [None]:
%%time
# crops = crops_graph.compute(scheduler="synchronous")
crops = client.compute(crops_graph, sync=True)

In [None]:
plt.imshow(crops[0].T)

In [None]:
%%time
# masks = masks_graph.compute(scheduler="synchronous")
masks = client.compute(masks_graph, sync=True)

In [None]:
plt.imshow(masks[0].T)

In [None]:
%%time
trenches = client.compute(pipeline.table[("trenches", 0)], sync=True)

In [None]:
trenches.head()

In [None]:
pipeline.table.keys()

In [None]:
%%time
measurements = client.compute(
    pipeline.table[("measurements", 0, 0, "RFP-Penta")], sync=True
)

In [None]:
measurements[0]

In [None]:
%%time
mask_measurements = client.compute(
    pipeline.table[("mask_measurements", 0, 0, "RFP-Penta")], sync=True
)

In [None]:
mask_measurements[0]

# Run

This is how the full pipeline could be run for an experiment which has a first phase (stored in ND2) and a second FISH phase (stored in HDF5).

In [None]:
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")

In [None]:
%%time
for msg in send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
):
    handle_message(pipeline, msg)

In [None]:
%%time
for msg in send_eaton_fish("/home/jqs1/scratch/jqs1/microscopy/220718/FISH/real_run/"):
    handle_message(pipeline, msg)

In [None]:
handle_message(pipeline, {"type": "done"})