# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm

IDX = pd.IndexSlice

In [None]:
from dask.diagnostics import ProgressBar

# Show a progress bar for computations run with dask's local (non-distributed) schedulers.
ProgressBar().register()

In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): the star import presumably supplies the project modules used below
# (new, trench_detection, trench_segmentation, util, image, geometry, readers) — confirm.
from paulssonlab.image_analysis import *
from paulssonlab.util.ui import display_image

In [None]:
# Enables pyinstrument's profiling magics in this notebook.
%load_ext pyinstrument

In [None]:
# Render HoloViews plots with the Bokeh backend.
hv.extension("bokeh")

# Config

In [None]:
# Display color (hex RGB) for each FISH imaging channel.
# An earlier palette used "#ffeb3b" (yellow) for Cy7.
fish_colors = dict(
    BF="#ffffff",
    GFP="#f44336",
    Cy5="#03a9f4",
    Cy7="#8bc34a",
)

In [None]:
# Dask cluster backed by SLURM jobs on the "short" queue: each job runs a single
# single-core worker process with 2GB of memory and a 6h walltime.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
# Display the cluster's notebook summary.
cluster

In [None]:
# Request one worker job up front.
cluster.scale(1)

In [None]:
# Let dask grow/shrink the number of worker jobs as needed, up to 20.
cluster.adapt(maximum=20)

# Handler

In [None]:
# Channel names must match metadata["channel"] of incoming image messages.
segmentation_channel = "RFP-PENTA"
# Trench detection almost always runs on the same channel as segmentation.
trench_detection_channel = segmentation_channel
# Channels whose intensities are measured against the segmentation masks.
measure_channels = [
    "RFP-PENTA",
    "YFP-DUAL",
]
# Channels presumably used for FISH readout (not referenced by the handlers above).
fish_channels = [
    "RFP-PENTA",
    "Cy5-PENTA",
    "Cy7",
]

In [None]:
import logging


class Pipeline:
    """Holds the lazily-built (dask-delayed) results of one processing run.

    Results live in plain dicts keyed by tuples such as
    ``("raw", fov_num, t, channel)``; this class itself writes nothing to
    ``output_dir``.

    Attributes:
        logger: the ``"Pipeline"`` logger.
        output_dir: destination directory for results, as a ``Path``.
        state: bookkeeping for frames waiting on upstream results.
        array: image-like outputs (raw frames, crops, segmentation masks).
        table: tabular outputs (trench tables, measurements).
    """

    def __init__(self, output_dir):
        self.logger = logging.getLogger("Pipeline")
        self.output_dir = Path(output_dir)
        self.state, self.array, self.table = {}, {}, {}

    def delayed(self, func, *args, **kwargs):
        """Wrap ``func`` with ``dask.delayed``, forwarding any extra options.

        Every pipeline task goes through this hook so that logging, retries or
        profiling can later be added in one place.
        """
        # TODO:
        # - log exceptions
        # - log warnings (deduplicated, count instances)
        # - optionally retry with diag if func takes "diagnostics" argument
        # - log benchmarking/profiling? or collect stats, only log outliers (+ call arguments)
        wrapped = dask.delayed(func, *args, **kwargs)
        return wrapped


def crop_trenches(img, trenches, limit=3):
    """Crop ``img`` to each trench and return ``{trench_index: crop}``.

    Args:
        img: the image to crop.
        trenches: trench table passed through to ``new.image.iter_crops``.
        limit: maximum number of trenches to crop (the first ``limit`` pairs
            yielded by ``iter_crops``), or ``None`` for all of them. Defaults
            to 3 so that each dask task stays short while testing.

    Returns:
        dict mapping trench index to its cropped image.
    """
    pairs = new.image.iter_crops(img, trenches)
    if limit is not None:
        pairs = it.islice(pairs, limit)
    return dict(pairs)


def segment_trenches(crops):
    """Run ``trench_segmentation.segment`` on every trench crop.

    Args:
        crops: mapping of trench index to cropped image.

    Returns:
        dict mapping each trench index to its segmentation result.
    """
    return {idx: trench_segmentation.segment(crop) for idx, crop in crops.items()}


# TODO: this is really boilerplatey, also we want finer task granularity than doing a whole FOV at once
def measure_crops(label_images, intensity_images):
    """Measure every trench present in both ``label_images`` and ``intensity_images``.

    Returns:
        dict mapping trench index to the ``measure_crop`` DataFrame for that
        trench; trenches missing from either mapping are skipped.
    """
    result = {}
    for key in label_images.keys() & intensity_images.keys():
        result[key] = measure_crop(label_images[key], intensity_images[key])
    return result


def measure_crop(label_image, intensity_image):
    """Mean intensity of ``intensity_image`` within each region of ``label_image``.

    Returns:
        DataFrame indexed by ``label`` with an ``intensity_mean`` column.
    """
    props = skimage.measure.regionprops_table(
        label_image, intensity_image, properties=("label", "intensity_mean")
    )
    return pd.DataFrame(props).set_index("label")


def measure_mask_crops(label_images):
    """Apply ``measure_mask_crop`` to each trench's label image.

    Returns:
        dict mapping trench index to its shape-measurement DataFrame.
    """
    measurements = {}
    for trench_idx, label_image in label_images.items():
        measurements[trench_idx] = measure_mask_crop(label_image)
    return measurements


def measure_mask_crop(label_image):
    """Shape statistics for each labeled region of ``label_image``.

    Returns:
        DataFrame indexed by ``label`` with area, major/minor axis length,
        orientation and centroid columns (for 2D input ``regionprops_table``
        splits the centroid into ``centroid-0`` / ``centroid-1``).
    """
    shape_properties = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    table = skimage.measure.regionprops_table(label_image, properties=shape_properties)
    return pd.DataFrame(table).set_index("label")


# TODO: use a namedtuple (or typing.NamedTuple, or dataclass) for keys so that fields are named
def handle_image(pipeline, msg):
    """Schedule trench detection, cropping, segmentation and measurement for one frame.

    Keys written to ``pipeline``:
      - ``array[("raw", fov, t, channel)]``: the incoming frame
      - ``table[("trenches", fov)]``: result of ``new.image.find_trench_bboxes``,
        computed once per FOV from its first ``trench_detection_channel`` frame
      - ``array[("crops", fov, t, channel)]``: per-trench crops of each frame
      - ``array[("segmentation", fov, t, segmentation_channel)]``: per-trench masks
      - ``table[("mask_measurements", fov, t, segmentation_channel)]``
      - ``table[("measurements", fov, t, channel)]`` for ``measure_channels``

    Frames that arrive before their prerequisites (trenches for the FOV, or
    the segmentation for their timepoint) are remembered in ``pipeline.state``
    and processed as soon as the prerequisite has been scheduled.

    Args:
        pipeline: a ``Pipeline``.
        msg: dict with ``"image"`` and ``"metadata"``; the metadata must contain
            ``"fov_num"``, ``"t"`` and ``"channel"``.
    """
    image = msg["image"]
    metadata = msg["metadata"]
    fov_num = metadata["fov_num"]
    t = metadata["t"]
    channel = metadata["channel"]
    raw_key = ("raw", fov_num, t, channel)
    # store raw image (in production, we won't do this, we will only store crops as we do below)
    pipeline.array[raw_key] = image
    # TODO: we need a way to store per-frame metadata and write it to disk
    trenches_key = ("trenches", fov_num)
    trenches = pipeline.table.get(trenches_key)
    # trench detection runs once per FOV, on the first trench_detection_channel frame
    if trenches is None and channel == trench_detection_channel:
        trenches = pipeline.delayed(new.image.find_trench_bboxes)(
            image, peak_func=trench_detection.peaks.find_peaks
        )
        pipeline.table[trenches_key] = trenches
    # frames accumulate here until this FOV has trenches; after that the list
    # only ever holds the current frame
    keys_to_crop = pipeline.state.setdefault(("keys_to_crop", fov_num), [])
    keys_to_crop.append(raw_key)
    if trenches is None:
        # cannot crop anything until trenches are detected for this FOV
        return
    for raw_to_crop in keys_to_crop:
        # a buffered frame may come from a different timepoint than the current
        # message, so every downstream key must use the frame's own t/channel
        _, _, crop_t, crop_channel = raw_to_crop
        crop_key = ("crops", fov_num, crop_t, crop_channel)
        pipeline.array[crop_key] = pipeline.delayed(crop_trenches)(
            pipeline.array[raw_to_crop], trenches
        )
        segmentation_key = ("segmentation", fov_num, crop_t, segmentation_channel)
        pending_key = ("keys_to_measure", fov_num, crop_t)
        segmentation = pipeline.array.get(segmentation_key)
        if segmentation is not None:
            # segmentation for this timepoint already exists: measure only this frame
            keys_to_measure = [crop_key] if crop_channel in measure_channels else []
        else:
            # no segmentation yet: queue this frame until one is scheduled
            keys_to_measure = pipeline.state.setdefault(pending_key, [])
            if crop_channel in measure_channels:
                keys_to_measure.append(crop_key)
            if crop_channel == segmentation_channel:
                segmentation = pipeline.delayed(segment_trenches)(
                    pipeline.array[crop_key]
                )
                pipeline.array[segmentation_key] = segmentation
                # once we have the segmentation mask, get measurements for the mask
                pipeline.table[("mask_measurements", fov_num, crop_t, crop_channel)] = (
                    pipeline.delayed(measure_mask_crops)(segmentation)
                )
        if segmentation is not None:
            # measure every frame that was waiting on this segmentation
            for crop_to_measure in keys_to_measure:
                measurements_key = ("measurements", *crop_to_measure[1:])
                pipeline.table[measurements_key] = pipeline.delayed(measure_crops)(
                    segmentation, pipeline.array[crop_to_measure]
                )
            pipeline.state.pop(pending_key, None)
    pipeline.state.pop(("keys_to_crop", fov_num), None)


def handle_fish_barcode(pipeline, msg):
    """Handle a FISH barcode image message; not implemented yet, so it does nothing.

    TODO: implement FISH barcode processing.
    """


# we should pick a name that's better/more intuitive than handle_message
def handle_message(pipeline, msg):
    match msg:
        case {"type": "image", **info}:
            match info:
                case {"image_type": "fish_barcode"}:
                    handle_fish_barcode(pipeline, msg)
                case other:
                    handle_image(pipeline, msg)
        case {"type": "nd2_metadata"}:
            print("got metadata")  # TODO
        case {"type": "event", **info}:
            print("event", info)
        case {"type": "done"}:
            print("DONE")
        case _:
            # this exception should be caught, we don't want malformed messages to crash the pipeline
            raise ValueError("cannot handle message", msg)

In [None]:
filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"

In [None]:
%%time
# Schedule lazy pipeline tasks for a small subset of the ND2 file; slice(1)
# presumably keeps only the first FOV (v) and first timepoint (t).
# filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/220523/220523_library_test_smallfile.nd2"
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")
for msg in new.readers.send_nd2(
    filename,
    slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)

In [None]:
# Lazy (dask-delayed) tables scheduled by the run above.
pipeline.table

In [None]:
%%time
# Same small run on the 220718 library dataset.
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")
for msg in new.readers.send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2",
    slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)

In [None]:
# NOTE(review): these keys spell the channel "RFP-Penta", while the configured
# segmentation/trench-detection channel is "RFP-PENTA"; handle_image only creates
# crops/segmentation when the file's channel names match the configured ones, so
# these lookups raise KeyError unless the spellings agree — confirm.
crops_graph = pipeline.array[("crops", 0, 0, "RFP-Penta")]
masks_graph = pipeline.array[("segmentation", 0, 0, "RFP-Penta")]

In [None]:
# Render the crop task graph (dask's visualize uses graphviz).
crops_graph.visualize()

In [None]:
%%time
# Compute on the dask cluster; the commented line computes in-process instead.
# crops = crops_graph.compute(scheduler="synchronous")
crops = client.compute(crops_graph, sync=True)

In [None]:
# Trench 0's crop, transposed for display.
plt.imshow(crops[0].T)

In [None]:
%%time
# masks = masks_graph.compute(scheduler="synchronous")
masks = client.compute(masks_graph, sync=True)

In [None]:
plt.imshow(masks[0].T)

In [None]:
%%time
trenches = client.compute(pipeline.table[("trenches", 0)], sync=True)

In [None]:
trenches.head()

In [None]:
pipeline.table.keys()

In [None]:
%%time
measurements = client.compute(
    pipeline.table[("measurements", 0, 0, "RFP-Penta")], sync=True
)

In [None]:
# Per-label intensity table for trench 0.
measurements[0]

In [None]:
%%time
mask_measurements = client.compute(
    pipeline.table[("mask_measurements", 0, 0, "RFP-Penta")], sync=True
)

In [None]:
# Per-label shape table for trench 0.
mask_measurements[0]

# Run

This is how the full pipeline could be run for an experiment which has a first phase (stored in ND2) and a second FISH phase (stored in HDF5).

In [None]:
# Fresh pipeline for a full run (no FOV/timepoint slicing).
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")

In [None]:
%%time
# Phase 1: every frame of the ND2 acquisition.
# NOTE(review): other cells call this as new.readers.send_nd2; confirm the bare
# name is exported by the star import.
for msg in send_nd2(
    "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
):
    handle_message(pipeline, msg)

In [None]:
%%time
# Phase 2: FISH images from a directory; the regex presumably parses FOV (v),
# config/channel (c) and timepoint (t) from each filename.
# NOTE(review): elsewhere called as readers.send_eaton_fish — confirm bare name exists.
for msg in send_eaton_fish(
    "/home/jqs1/scratch/jqs1/microscopy/220718/FISH/real_run/",
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
):
    handle_message(pipeline, msg)

In [None]:
# Signal end of input.
handle_message(pipeline, {"type": "done"})

# Test

In [None]:
# Candidate ND2 files; the trailing notes presumably record the FOV (v) /
# timepoint (t) examined in each.
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"

In [None]:
nd2 = nd2reader.ND2Reader(filename)

In [None]:
nd2.metadata["channels"]

In [None]:
# FOV 8, channel 0, timepoint 10.
img = nd2.get_frame_2D(v=8, c=0, t=10)

In [None]:
# Brightened 4x relative to the max for display.
display_image(img / img.max() * 4)

In [None]:
%%time
# Default trench detection; diag collects intermediate results keyed by step name
# (util.tree() is presumably an auto-vivifying nested dict — confirm).
diag = util.tree()
trenches = trench_detection.find_trenches(img, diagnostics=diag)

In [None]:
%%time
# Same detection with trench_detection.peaks.find_peaks as the peak finder, for comparison.
diag2 = util.tree()
trenches2 = trench_detection.find_trenches(
    img, peak_func=trench_detection.peaks.find_peaks, diagnostics=diag2
)

In [None]:
trenches

In [None]:
# Compare the trench overlays produced with each peak finder.
diag["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
diag2["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"].keys()

In [None]:
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"]["spectrum"]

In [None]:
diag["labeling"].keys()

In [None]:
%%time
# Run only the binarization step, with its own diagnostics.
diag3 = util.tree()
img_bin = trench_detection.set_finding.binarize_trench_image(img, diagnostics=diag3)

In [None]:
img_bin

In [None]:
# Log-scale intensity histogram.
plt.hist(img.flat, bins=300, log=True);

In [None]:
from paulssonlab.image_analysis.image import (
    gaussian_box_approximation,
    normalize_componentwise,
    remove_large_objects,
)

In [None]:
# Large-radius blur used as a low-pass background estimate.
lowpass_radius = 500
img_lowpass = gaussian_box_approximation(img, lowpass_radius)

In [None]:
# No-op cell.
0

In [None]:
%%time
# Rolling-ball background estimate (radius 30 px).
rb = skimage.restoration.rolling_ball(img, radius=30)

In [None]:
display_image(rb, scale=True)

In [None]:
# Background-subtracted image, brightened 20x for display.
display_image((img - rb) / img.max() * 20)

In [None]:
?skimage.filters.threshold_sauvola

In [None]:
# Local (Sauvola) threshold with a 7-pixel window.
display_image(img > skimage.filters.threshold_sauvola(img, window_size=7))

In [None]:
# Global Otsu threshold, for comparison with the local Sauvola threshold.
display_image(img > skimage.filters.threshold_otsu(img))

In [None]:
# Brightened 30x relative to the max for display.
display_image(img / img.max() * 30)

In [None]:
# Image minus its low-pass background.
display_image(img - img_lowpass, scale=True)

In [None]:
# NOTE(review): duplicate of the previous cell.
display_image(img - img_lowpass, scale=True)

In [None]:
display_image(img_bin[1])

In [None]:
diag["labeling"]["binarize_trench_image"].keys()

In [None]:
diag["labeling"]["binarize_trench_image"]["thresholded_image"]

In [None]:
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"]["refined_points"]

In [None]:
diag["label_2"]["find_trench_ends"].keys()

In [None]:
?trench_detection.hough.find_periodic_lines

In [None]:
%%time
# Hough-style line transform (per the function name) over 10 angles evenly
# spaced in [-pi/5, pi/5], i.e. +/-36 degrees.
h = trench_detection.hough.hough_line_intensity(
    img, theta=np.linspace(-np.pi / 5, np.pi / 5, 10)
)

In [None]:
display_image(h[0].T, scale=True)

In [None]:
h[0].shape

In [None]:
# 5 degrees expressed as a fraction of pi (= 1/36).
np.deg2rad(5) / np.pi

# Manual FISH trench crops

In [None]:
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"
# FISH images live in a FISH/real_run directory next to the ND2 file.
fish_filename = Path(filename).parent / "FISH/real_run"

In [None]:
nd2 = nd2reader.ND2Reader(filename)

In [None]:
nd2.sizes

In [None]:
# FOV 8, channel 0, timepoint 180.
img = nd2.get_frame_2D(v=8, c=0, t=180)

In [None]:
# Radial-distortion coefficient for image.correct_radial_distortion; the same k1
# is applied to the FISH frames when they are loaded.
k1 = 8.947368421052635e-10
img_t = image.correct_radial_distortion(img, k1=k1)

In [None]:
%%time
# Detect trenches once on the corrected frame; the resulting trench table is
# reused to crop every FISH image.
# diag = util.tree()
diag = None
trenches, info = trench_detection.find_trenches(
    img_t,
    # angle=np.deg2rad(0.001),
    join_info=False,
    width=12,
    # width_to_line_width_ratio=2,
    # width_to_pitch_ratio=None,
    # peak_func=trench_detection.peaks.find_peaks,
    diagnostics=diag,
)

In [None]:
def crop_trenches(img, trenches):
    """Crop ``img`` to every trench in ``trenches``.

    Shadows the earlier ``crop_trenches``; this version crops all trenches.

    Returns:
        dict mapping trench index to its cropped image, in the order yielded
        by ``geometry.iter_crops``.
    """
    # for debugging, wrap in it.islice(..., 10, 13) to crop only a few trenches
    return dict(geometry.iter_crops(img, trenches))


def stack_crops(crops, channels, timepoints):
    """Assemble per-frame trench crops into one (channel, time, y, x) array per trench.

    Args:
        crops: mapping ``(t, channel) -> {trench_idx: 2D crop}``.
        channels: ordered channel names; a channel's position is its index on axis 0.
        timepoints: ordered timepoints; a timepoint's position is its index on axis 1.

    Returns:
        dict mapping trench index to a zarr array of shape
        ``(len(channels), len(timepoints), *crop.shape)`` (shape taken from the
        first crop seen for that trench). Entries with no frame keep the fill
        value: NaN for floating-point/complex crops, 0 for other dtypes.

    Raises:
        ValueError: if a frame's channel or timepoint is not listed.
    """
    stacks = {}
    for (t, channel), frame_crops in crops.items():
        channel_idx = channels.index(channel)
        timepoint_idx = timepoints.index(t)
        for trench_idx, trench_slice in frame_crops.items():
            if trench_idx not in stacks:
                # NaN marks missing frames, but only inexact dtypes can hold it;
                # casting NaN to an integer dtype yields an arbitrary value
                if np.issubdtype(trench_slice.dtype, np.inexact):
                    fill_value = np.nan
                else:
                    fill_value = 0
                stacks[trench_idx] = zarr.create(
                    (len(channels), len(timepoints), *trench_slice.shape),
                    dtype=trench_slice.dtype,
                    fill_value=fill_value,
                )
            stacks[trench_idx][channel_idx, timepoint_idx, :, :] = trench_slice
    return stacks

In [None]:
def calibrate_image(img, k1=0):
    """Convert ``img`` to float32 and correct radial lens distortion.

    Args:
        img: raw image.
        k1: radial-distortion coefficient passed to
            ``image.correct_radial_distortion`` (0 by default).

    Returns:
        the distortion-corrected image.
    """
    return image.correct_radial_distortion(skimage.img_as_float32(img), k1=k1)

In [None]:
%%time
# Build lazy tasks that load, calibrate and crop every FISH frame for FOV 8, then
# stack the crops per trench.
# NOTE(review): this rebinds `delayed` (shadowing `from dask import delayed`) and
# `fish_channels` (shadowing the channel config list); util.get_delayed(True)
# presumably returns dask.delayed — confirm.
delayed = util.get_delayed(True)
fish_frames = {}
fish_crops = {}
fish_channels = set()
fish_timepoints = set()
for msg in readers.send_eaton_fish(
    fish_filename,
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
    slices=dict(t=None, v=[8]),
    delayed=delayed,
):
    # print(msg["metadata"],msg["image"].shape)
    fish_img = msg["image"]
    fish_img_corrected = delayed(calibrate_image)(fish_img, k1=k1)
    t = msg["metadata"]["t"]
    channel = msg["metadata"]["channel"]
    fish_channels.add(channel)
    fish_timepoints.add(t)
    fish_frames[(t, channel)] = fish_img_corrected
    fish_crops[(t, channel)] = delayed(crop_trenches)(fish_img_corrected, trenches)
# sorted order defines the channel/time axes of the stacks
fish_channels = list(sorted(fish_channels))
fish_timepoints = list(sorted(fish_timepoints))
fish_stacks = delayed(stack_crops)(fish_crops, fish_channels, fish_timepoints)

In [None]:
fish_channel_colors = [fish_colors[ch] for ch in fish_channels]

In [None]:
# Compute the calibrated frames and per-trench stacks together.
fish_frames0, fish_stacks0 = dask.compute(fish_frames, fish_stacks)

In [None]:
fish_channels

In [None]:
fish_timepoints

In [None]:
fish_stacks0[10].info

In [None]:
# NOTE(review): re-runs the ND2 pipeline into `pipeline` from an earlier section;
# looks like a leftover cell.
for msg in new.readers.send_nd2(
    filename,
    slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)

# Segmentation

In [None]:
# Trench 13: drop channel index 0 (first of the sorted fish_channels) and keep the
# first 9 timepoints.
x = fish_stacks0[13][1:, :9]
# x = x - x.min(axis=1)[:,np.newaxis,:,:]

In [None]:
def weighted_mean(ary):
    """Intensity-weighted spatial mean of each (channel, timepoint) frame.

    Each pixel is baseline-subtracted by its minimum over time (per channel).
    A pixel's weight is its resulting maximum over time, normalized to sum to 1
    within each channel and then combined across channels, so pixels that light
    up in any channel dominate the mean. Channels with no signal at all
    contribute zero weight.

    Args:
        ary: array of shape ``(channel, time, y, x)``.

    Returns:
        array of shape ``(channel, time)``, or ``None`` if no channel has any
        signal above its baseline (all weights zero).
    """
    ary = np.asarray(ary)
    ary = ary - ary.min(axis=1, keepdims=True)
    # per-pixel dynamic range over time, shape (C, 1, H, W); always >= 0
    lmbda = ary.max(axis=1, keepdims=True)
    totals = lmbda.sum(axis=(2, 3), keepdims=True)
    # a channel whose total is 0 has lmbda == 0 everywhere, so dividing by 1
    # gives it zero weight instead of 0/0 = NaN, which would otherwise spread
    # through the shared weights into every channel's result
    normalized = lmbda / np.where(totals == 0, 1, totals)
    # combine channels; the overall scale is irrelevant because np.average
    # divides by the total weight
    w = normalized.sum(axis=0, keepdims=True)
    if w.sum() == 0:
        return None
    return np.average(ary, axis=(2, 3), weights=np.broadcast_to(w, ary.shape))

In [None]:
# One (channel, timepoint) metric matrix per trench; channel index 0 is dropped and
# only the first 9 timepoints are used (None for trenches with no signal).
fish_metrics = {
    idx: weighted_mean(np.asarray(stack[1:, :9]))
    for idx, stack in tqdm(fish_stacks0.items())
}

In [None]:
# Number of trenches with no signal (weighted_mean returned None).
sum(1 for x in fish_metrics.values() if x is None)

In [None]:
fish_metrics[1][0]

In [None]:
# Column labels matching ary.flatten() order below (channel-major, then timepoint).
# NOTE(review): fish_timepoints[:-1] only lines up with the `:9` slice above when
# there are exactly 10 FISH timepoints — confirm.
bit_names = [(ch, str(t)) for ch in fish_channels[1:] for t in fish_timepoints[:-1]]

In [None]:
# Wide table: one row per trench, one column per (channel, timepoint) bit.
fish_metrics_df = pd.DataFrame.from_dict(
    {
        trench_idx: ary.flatten()
        for trench_idx, ary in fish_metrics.items()
        if ary is not None
    },
    columns=pd.MultiIndex.from_tuples(bit_names, names=["channel", "timepoint"]),
    orient="index",
).rename_axis(index="trench_idx")
fish_metrics_df

In [None]:
# Long table: one row per (trench, channel, timepoint).
fish_metrics_df2 = fish_metrics_df.melt(ignore_index=False).reset_index()

In [None]:
fish_metrics_df2

In [None]:
# Per-channel thresholds on the weighted-mean metric.
fish_thresholds = {"GFP": 0.007, "Cy5": 0.005, "Cy7": 0.001}

In [None]:
# A bit is called "on" when its value exceeds its channel's threshold.
fish_metrics_df2["ground_truth"] = fish_metrics_df2["value"] > fish_metrics_df2[
    "channel"
].map(fish_thresholds)

In [None]:
(fish_metrics_df2.groupby("trench_idx").sum("ground_truth") == 0).sum()

In [None]:
fish_metrics_df2.groupby("channel").apply(lambda x: x["ground_truth"].sum() / len(x))

In [None]:
(fish_metrics_df2.groupby("channel").sum("ground_truth") == 0).sum()

In [None]:
# Trench to inspect.
# idx = 1901
idx = 2350

In [None]:
# Same channel/timepoint selection as used for fish_metrics.
x = fish_stacks0[idx][1:, :9]

In [None]:
fish_metrics_df2[fish_metrics_df2["trench_idx"] == idx]

In [None]:
display_image(image.unstack_multichannel(x))

In [None]:
# Per-pixel baseline subtraction (minimum over time), as inside weighted_mean.
y = x - x.min(axis=1)[:, np.newaxis, :, :]

In [None]:
display_image(image.unstack_multichannel(y))

In [None]:
# The (channel, timepoint) metric matrix shown as an image.
plt.imshow(weighted_mean(y))

In [None]:
# Metric distribution per channel and timepoint.
hv.Violin(fish_metrics_df2, ["channel", "timepoint"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
# Split violins: "on" vs "off" bits per channel and timepoint.
hv.Violin(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
ds = hv.Dataset(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value")

In [None]:
# One split-violin panel per channel, each with its own axes.
ds.to(hv.Violin, ["timepoint", "ground_truth"]).layout("channel").opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
        axiswise=True,
    )
).cols(1)

In [None]:
# NOTE(review): same expression as _stacked_violins in the next cell; `z` is unused here.
z = ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")

In [None]:
# Overlaid "on"/"off" violins per channel; redim labels each panel's value axis
# with its channel name.
_stacked_violins = (
    ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")
)

hv.Layout([v.redim(value=k) for k, v in _stacked_violins.items()]).opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        # violin_width=3,
        inner=None,
        bandwidth=0.2,
        cut=0.05,
    )
).cols(1)

In [None]:
# fish_metrics_df2.groupby("channel").apply(lambda x: hv.Violin(x))

In [None]:
# Empty layout (scratch cell).
hv.Layout()

In [None]:
# Grid of metric distributions, one cell per (timepoint, channel).
hv.GridSpace(
    {
        (timepoint, channel): hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(logy=True))

In [None]:
hv.GridSpace(
    {
        (timepoint, channel): hv.Dataset(df, ["ground_truth"], "value").to(
            hv.Distribution
        )
        # .overlay("ground_truth")
        # hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(show_legend=True))

In [None]:
# Same grid with "On" and "Off" distributions overlaid in each cell.
hv.GridSpace(
    {
        (timepoint, channel): (
            hv.Distribution(df[df["ground_truth"]], "value", label="On")
            * hv.Distribution(df[~df["ground_truth"]], "value", label="Off")
        ).redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
).opts(hv.opts.Distribution(show_legend=True, bandwidth=0.3, cut=0.05))

In [None]:
# Bokeh's iris sample data, apparently as a reference for the gridmatrix API.
from bokeh.sampledata.iris import flowers
from holoviews.operation import gridmatrix

iris_ds = hv.Dataset(flowers)

In [None]:
iris_ds

In [None]:
# Flatten the (channel, timepoint) column MultiIndex to "channel_timepoint" names.
fish_metrics_df3 = fish_metrics_df.set_axis(
    ["_".join(c) for c in fish_metrics_df.columns], axis=1
)

In [None]:
# Keep a hand-picked subset of bits by column position (depends on bit_names order).
fish_metrics_df3 = fish_metrics_df3.loc[
    :, [*fish_metrics_df3.columns[3:6], *fish_metrics_df3.columns[13:16]]
]

In [None]:
# Pairwise bivariate densities between the selected bits, marginals on the diagonal.
density_grid = gridmatrix(
    hv.Dataset(fish_metrics_df3), diagonal_type=hv.Distribution, chart_type=hv.Bivariate
)

In [None]:
density_grid