# Imports

In [None]:
import itertools as it
import operator
import os
import pickle
import re
from collections import namedtuple
from functools import partial, reduce
from pathlib import Path

import dask
import distributed
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import panel as pn
import pyarrow as pa
import pyarrow.parquet as pq
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.image_analysis.new as new
from paulssonlab.image_analysis import *

In [None]:
#%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2.5GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(0)

In [None]:
cluster.adapt(maximum=20)

# Handler

In [None]:
# Channel names used by the handlers below. handle_image compares them against
# each frame's metadata["channel"] with `==` / `in`, so spelling and case must
# match the acquisition's channel names exactly.
drift_correction_channel = "Phase-Fluor"
segmentation_channel = "RFP-PENTA"
trench_detection_channel = segmentation_channel  # channel for trench detection, almost always same as segmentation_channel
measure_channels = ["RFP-PENTA", "GFP-PENTA"]
# NOTE(review): "RFP-Penta" below differs in case from "RFP-PENTA" above — confirm which spelling is correct
fish_channels = ["RFP-Penta", "Cy5-PENTA", "Cy7"]

In [None]:
import logging


class Pipeline:
    """Container for the bookkeeping shared by the message handlers.

    Attributes are plain dicts keyed by tuples:

    * ``state``  -- scratch bookkeeping (e.g. lists of frames awaiting processing)
    * ``array``  -- image-like results (raw frames, crops, segmentation masks)
    * ``table``  -- tabular results (trench tables, measurements)
    """

    def __init__(self, output_dir):
        """Create an empty pipeline that targets ``output_dir``."""
        self.output_dir = Path(output_dir)
        self.logger = logging.getLogger("Pipeline")
        # each mapping must be its own dict object
        self.state, self.array, self.table = {}, {}, {}

    def delayed(self, func, *args, **kwargs):
        """Wrap ``func`` with ``dask.delayed``; extra arguments are forwarded to ``dask.delayed`` itself."""
        # TODO:
        # log exceptions
        # log warnings (deduplicated, count instances)
        # optionally retry with diag if func takes "diagnostics" argument
        # log benchmarking/profiling? or collect stats, only log outliers (+ call arguments)
        lazy_func = dask.delayed(func, *args, **kwargs)
        return lazy_func


def crop_trenches(img, trenches):
    """Return a dict mapping each trench id to its crop of ``img``.

    The (id, crop) pairs come from ``new.image.iter_crops(img, trenches)``.
    """
    # TODO: for quick testing, wrap the iterator in it.islice(..., 3) so that each
    # dask task only handles three trenches per FOV (otherwise every task takes a long time)
    return dict(new.image.iter_crops(img, trenches))


def segment_trenches(crops):
    """Segment every trench crop, skipping (and logging) crops that fail.

    Parameters
    ----------
    crops : mapping
        Trench id -> crop, as produced by ``crop_trenches``.

    Returns
    -------
    dict
        Trench id -> result of ``trench_segmentation.segment(crop)``. Trenches
        whose segmentation raised are omitted, so downstream code must only
        rely on the keys that are present.
    """
    logger = logging.getLogger("Pipeline")
    masks = {}
    for i, crop in crops.items():
        try:
            masks[i] = trench_segmentation.segment(crop)
        except Exception:
            # Best effort: one bad trench must not fail the whole FOV. Catch
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit still
            # propagate, and record the failure instead of hiding it.
            logger.warning(
                "segmentation failed for trench %r; skipping", i, exc_info=True
            )
    return masks


# TODO: this is really boilerplatey, also we want finer task granularity than doing a whole FOV at once
def measure_crops(label_images, intensity_images):
    """Measure every trench present in both mappings.

    Returns a dict mapping each shared key to
    ``measure_crop(label_images[key], intensity_images[key])``.
    """
    measurements = {}
    for key in label_images.keys() & intensity_images.keys():
        measurements[key] = measure_crop(label_images[key], intensity_images[key])
    return measurements


def measure_crop(label_image, intensity_image):
    """Return a DataFrame of per-label mean intensity, indexed by ``label``.

    The values come from ``skimage.measure.regionprops_table`` evaluated on
    ``label_image`` with ``intensity_image`` as the intensity source.
    """
    wanted = ("label", "intensity_mean")
    props = skimage.measure.regionprops_table(
        label_image, intensity_image, properties=wanted
    )
    frame = pd.DataFrame(props)
    return frame.set_index("label")


def measure_mask_crops(label_images):
    """Apply ``measure_mask_crop`` to every label image, preserving the keys."""
    measurements = {}
    for key, label_image in label_images.items():
        measurements[key] = measure_mask_crop(label_image)
    return measurements


def measure_mask_crop(label_image):
    """Return a DataFrame of per-label shape properties, indexed by ``label``.

    The properties (area, major/minor axis length, orientation, centroid) are
    computed by ``skimage.measure.regionprops_table`` from the label image alone.
    """
    wanted = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    props = skimage.measure.regionprops_table(label_image, properties=wanted)
    frame = pd.DataFrame(props)
    return frame.set_index("label")


# TODO: use a namedtuple (or typing.NamedTuple, or dataclass) for keys so that fields are named
def handle_image(pipeline, msg):
    """Register one raw frame and schedule every task that it unblocks.

    Parameters
    ----------
    pipeline : Pipeline
        Its ``array``, ``table`` and ``state`` dicts are read and updated in place.
    msg : mapping
        Must provide ``msg["image"]`` and ``msg["metadata"]``; the metadata must
        have the keys ``"fov_num"``, ``"t"`` and ``"channel"``.

    Keys written (all tuples):

    * ``array[("raw", fov, t, channel)]`` -- the frame itself
    * ``table[("trenches", fov)]`` -- delayed trench detection, created from the
      first frame of ``trench_detection_channel`` seen for the FOV
    * ``array[("crops", fov, t, channel)]`` -- delayed trench crops
    * ``array[("segmentation", fov, t, segmentation_channel)]`` -- delayed masks
    * ``table[("mask_measurements", fov, t, channel)]`` -- delayed mask properties
    * ``table[("measurements", fov, t, channel)]`` -- delayed intensity
      measurements, for channels listed in ``measure_channels``

    Frames that arrive before trenches are known for their FOV are queued in
    ``state[("keys_to_crop", fov)]``; cropped frames that arrive before their
    timepoint has a segmentation are queued in
    ``state[("keys_to_measure", fov, t)]``. Every queued frame is processed
    under its *own* timepoint, which may differ from the timepoint of ``msg``.
    """
    image = msg["image"]
    metadata = msg["metadata"]
    fov_num = metadata["fov_num"]
    t = metadata["t"]
    channel = metadata["channel"]
    raw_key = ("raw", fov_num, t, channel)
    # store raw image (in production, we won't do this, we will only store crops as we do below)
    pipeline.array[raw_key] = image
    # TODO: we need a way to store per-frame metadata and write it to disk
    trenches_key = ("trenches", fov_num)
    trenches = pipeline.table.get(trenches_key)
    # check if we have done trench detection for this FOV
    if trenches is None and channel == trench_detection_channel:
        # if not, find trenches and save the resulting table
        trenches = pipeline.delayed(new.image.find_trench_bboxes)(
            image, peak_func=trench_detection.peaks.find_peaks
        )
        pipeline.table[trenches_key] = trenches
    # this list keeps track of all the raw frames that need to be cropped
    # frames for multiple channels will accumulate in this list until we get a frame for trench_detection_channel
    # if we have already processed such a frame, then keys_to_crop will contain only the current frame (raw_key)
    keys_to_crop = pipeline.state.setdefault(("keys_to_crop", fov_num), [])
    keys_to_crop.append(raw_key)
    # we only can do further processing if we have already detected trenches for this FOV
    if trenches is None:
        return
    for raw_to_crop in keys_to_crop:
        # Use the queued frame's own coordinates: frames in the backlog may come
        # from an earlier timepoint than the message being handled right now.
        _, frame_fov, frame_t, frame_channel = raw_to_crop
        crop_key = ("crops", frame_fov, frame_t, frame_channel)
        # save trench crops for every frame in keys_to_crop
        pipeline.array[crop_key] = pipeline.delayed(crop_trenches)(
            pipeline.array[raw_to_crop], trenches
        )
        segmentation_key = ("segmentation", frame_fov, frame_t, segmentation_channel)
        pending_key = ("keys_to_measure", frame_fov, frame_t)
        segmentation = pipeline.array.get(segmentation_key)
        if segmentation is not None:
            # we already have segmentation masks for this timepoint, so we can
            # immediately measure this frame (if it is a channel we measure)
            keys_to_measure = [crop_key] if frame_channel in measure_channels else []
        else:
            # we don't have a segmentation mask yet, so we need to add to the keys_to_measure list
            keys_to_measure = pipeline.state.setdefault(pending_key, [])
            if frame_channel in measure_channels:
                # we want to measure this frame
                keys_to_measure.append(crop_key)
            if frame_channel == segmentation_channel:
                # if this frame's channel is the segmentation channel, run segmentation
                segmentation = pipeline.delayed(segment_trenches)(
                    pipeline.array[crop_key]
                )
                pipeline.array[segmentation_key] = segmentation
                # once we have the segmentation mask, get measurements for the mask
                pipeline.table[("mask_measurements", *crop_key[1:])] = pipeline.delayed(
                    measure_mask_crops
                )(segmentation)
        # if we now have the segmentation mask, measure all frames in the keys_to_measure list
        if segmentation is not None:
            for crop_to_measure in keys_to_measure:
                measurements_key = ("measurements", *crop_to_measure[1:])
                pipeline.table[measurements_key] = pipeline.delayed(measure_crops)(
                    segmentation, pipeline.array[crop_to_measure]
                )
            pipeline.state.pop(pending_key, None)
    # the backlog for this FOV has been fully processed
    pipeline.state.pop(("keys_to_crop", fov_num), None)


def handle_fish_barcode(pipeline, msg):
    """Handle an image message whose ``image_type`` is ``"fish_barcode"``.

    Not implemented yet: the message is currently ignored.
    """
    pass  # TODO


# TODO: pick a better, more intuitive name than handle_message
def handle_message(pipeline, msg):
    match msg:
        case {"type": "image", **info}:
            match info:
                case {"image_type": "fish_barcode"}:
                    handle_fish_barcode(pipeline, msg)
                case other:
                    handle_image(pipeline, msg)
        case {"type": "nd2_metadata"}:
            print("got metadata")  # TODO
        case {"type": "event", **info}:
            print("event", info)
        case {"type": "done"}:
            print("DONE")
        case _:
            # this exception should be caught, we don't want malformed messages to crash the pipeline
            raise ValueError("cannot handle message", msg)

In [None]:
%%time
# filename = "/home/jqs1/scratch/jqs1/microscopy/210511/RBS_ramp.nd2"
filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")
for msg in new.readers.send_nd2(
    filename,
    # slices=dict(v=[30], t=slice(40,None)),
    slices=dict(v=[30], t=slice(None)),
):
    handle_message(pipeline, msg)

In [None]:
%%time
futures = util.apply_map_futures(client.compute, (pipeline.table, pipeline.array))

In [None]:
%%time
table, array = client.gather(futures)

## Save outputs to pickle

In [None]:
pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1.pickle"

In [None]:
%%time
with open(pickle_filename, "wb") as f:
    pickle.dump((table, array), f)

In [None]:
!du -hs "/home/jqs1/group/221108rbsdeglibrary_1.pickle"

## Load outputs from pickle

In [None]:
# pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1.pickle"
pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1_table.pickle"

In [None]:
%%time
with open(pickle_filename, "rb") as f:
    table, array = pickle.load(f)

In [None]:
!mamba install -y ibis-framework

# Helper functions

In [None]:
def reformat_table(
    table, prefix, flatten_column_names=False, truncate_column_names=False
):
    """Combine the per-frame, per-trench tables under ``prefix`` into one DataFrame.

    Parameters
    ----------
    table : mapping
        Tuple keys -> mapping of ``roi -> DataFrame``. Only keys that start
        with ``prefix`` are used; the remaining key fields are taken to be
        ``(fov, t, channel)``.
    prefix : str or tuple
        Leading key field(s) selecting which tables to combine.
    flatten_column_names : bool
        Replace the (property, channel) column MultiIndex with strings of the
        form ``"channel/property"``, e.g. ``"GFP-PENTA/intensity_mean"``.
    truncate_column_names : bool
        Replace the column MultiIndex with the property name only, e.g.
        ``"area"``. Mutually exclusive with ``flatten_column_names``.

    Returns
    -------
    pandas.DataFrame or None
        Rows indexed by (fov, t, roi, <inner index>) with ``channel`` unstacked
        into the columns; ``None`` if no non-empty table matches ``prefix``.

    Raises
    ------
    ValueError
        If both ``flatten_column_names`` and ``truncate_column_names`` are set.
    """
    # validate the options before doing any (potentially expensive) work
    if flatten_column_names and truncate_column_names:
        raise ValueError(
            "flatten_column_names and truncate_column_names cannot both be True"
        )
    if not isinstance(prefix, tuple):
        prefix = (prefix,)
    depth = len(prefix)
    # Skip entries with no per-roi tables (e.g. every trench failed to segment):
    # pd.concat raises "No objects to concatenate" on an empty mapping.
    keys = sorted(
        k for k in table.keys() if k[:depth] == prefix and len(table[k]) > 0
    )
    if not keys:
        return None
    df = pd.concat(
        {k[depth:]: pd.concat(table[k], names=["roi"]) for k in keys},
        names=["fov", "t", "channel"],
    )
    # after unstacking, each column is a (property, channel) tuple
    df = df.unstack("channel")
    if flatten_column_names:
        # replace MultiIndex with Index of slash-separated names like "GFP-PENTA/intensity_mean"
        df.columns = ["/".join(col[::-1]) for col in df.columns.values]
        #df.columns = [re.sub(r"^(\w+)-[^/]*/intensity_mean", r"\1", col) for col in df.columns.values]
    elif truncate_column_names:
        # replace MultiIndex with Index holding only the property name,
        # e.g., "area" instead of ("area", "RFP-Penta")
        df.columns = [col[0] for col in df.columns.values]
    return df

In [None]:
def stack_crops(array, prefix, fov, channel):
    """Stack the crops of one FOV/channel across frames, per trench.

    Parameters
    ----------
    array : mapping
        Tuple keys -> mapping of ``trench -> crop``. Only 4-tuples of the form
        ``(prefix, fov, <t>, channel)`` are used, in sorted key order.
    prefix, fov, channel
        Key fields selecting which frames to stack.

    Returns
    -------
    dict
        ``trench -> np.stack`` of that trench's crops over the selected frames.
        Only trenches present in every selected frame are included. Empty if no
        frame matches.
    """
    keys = sorted(
        k
        for k in array.keys()
        if len(k) == 4 and k[:2] == (prefix, fov) and k[3] == channel
    )
    if not keys:
        # nothing to stack; intersecting zero key sets would be an error
        return {}
    # trenches present in every selected frame
    trenches = set(array[keys[0]].keys())
    for k in keys[1:]:
        trenches.intersection_update(array[k].keys())
    return {
        trench: np.stack([array[k][trench] for k in keys]) for trench in trenches
    }

In [None]:
def unstack(ary):
    """Lay the sub-arrays ``ary[0], ary[1], ...`` side by side.

    The result has ``ary.shape[1]`` rows; row ``r`` is the concatenation of
    ``ary[i][r]`` (flattened) over all ``i``.
    """
    n_rows = ary.shape[1]
    rows_first = np.moveaxis(ary, 1, 0)
    return rows_first.reshape(n_rows, -1)


def pad_and_stack(arys, fill_value=0):
    """Pad arrays to a common shape and stack them along a new first axis.

    Each array is padded at the *start* of every axis (top/left for 2-D
    arrays) with ``fill_value`` until it reaches the element-wise maximum shape.

    Parameters
    ----------
    arys : iterable of array-like
        Arrays with the same number of dimensions (any dimensionality).
    fill_value : scalar
        Value used for the padding.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(arys), *max_shape)``.

    Raises
    ------
    ValueError
        If ``arys`` is empty.
    """
    # materialize once so that one-shot iterables can be traversed twice
    arys = [np.asarray(ary) for ary in arys]
    if not arys:
        raise ValueError("pad_and_stack requires at least one array")
    shape = np.max([ary.shape for ary in arys], axis=0)
    return np.stack(
        [
            np.pad(
                ary,
                # (before, after) per axis: all padding goes in front
                [(int(full - size), 0) for full, size in zip(shape, ary.shape)],
                constant_values=fill_value,
            )
            for ary in arys
        ]
    )


def pad_unstack(arys):
    """Pad ``arys`` to a common shape (``pad_and_stack``), then lay them side by side (``unstack``)."""
    stacked = pad_and_stack(arys)
    return unstack(stacked)

# Streaming

In [None]:
import hvplot.streamz
from streamz.dataframe import PeriodicDataFrame

In [None]:
state = {}


def poll_table(last, now, **kwargs):
    """Data callback passed to ``PeriodicDataFrame``; replays ``table`` one timepoint per call.

    Reads the module-level ``table`` and keeps its progress in the module-level
    ``state`` dict: ``state["counter"]`` is incremented on every call and
    ``state["df"]`` accumulates the measurements selected so far. ``last``,
    ``now`` and ``kwargs`` are accepted but not used.

    Returns the accumulated DataFrame ``state["df"]`` (all rows so far, not
    only the rows added by this call).
    """
    # advance to the next timepoint (first call uses counter == 1)
    counter = state.setdefault("counter", 0) + 1
    state["counter"] = counter
    # keep only 4-tuple keys whose third field equals the counter
    table_subset = {k: v for k, v in table.items() if len(k) == 4 and k[2] == counter}
    measurements = reformat_table(
        table_subset, "measurements", flatten_column_names=True
    )
    mask_measurements = reformat_table(
        table_subset, "mask_measurements", truncate_column_names=True
    )
    # reformat_table returns None when nothing matches; only append complete rows
    if measurements is not None and mask_measurements is not None:
        all_measurements = pd.concat((measurements, mask_measurements), axis=1)
        if state.get("df") is not None:
            state["df"] = pd.concat((state["df"], all_measurements))
        else:
            state["df"] = all_measurements
    # freq = kwargs.get("freq", pd.Timedelta("50ms"))
    # index = pd.date_range(start=last + freq, end=now, freq=freq)
    # return pd.DataFrame({'x': np.random.random(len(index))}, index=index)
    # NOTE(review): raises KeyError if no call has produced measurements yet
    # (state has no "df" entry until the branch above runs) — confirm intended
    return state["df"]


measurements_stream = PeriodicDataFrame(poll_table, interval="300ms")

In [None]:
def filter_plots(data, singles, pairs):
    """Build a two-row panel: KDE plots for ``singles``, scatter plots for ``pairs``.

    ``singles`` is an iterable of column names; ``pairs`` is an iterable of
    (x, y) column-name tuples. Plots are created through ``data.hvplot``.
    """
    # (link_selections is deliberately not applied to these layouts)
    kde_plots = [
        data.hvplot.kde(k, height=200, width=200, yaxis="bare", backlog=100)
        for k in singles
    ]
    scatter_plots = [
        data.hvplot.scatter(
            *k, height=300, width=300, hover=False, size=2, backlog=100
        )
        for k in pairs
    ]
    kde_pane = pn.pane.HoloViews(hv.Layout(kde_plots).cols(6))
    scatter_pane = pn.pane.HoloViews(
        hv.Layout(scatter_plots), sizing_mode="stretch_width"
    )
    return pn.Column(kde_pane, scatter_pane)


p = filter_plots(
    measurements_stream,
    [
       "RFP-Penta/intensity_mean",
       "YFP-DUAL/intensity_mean",
       "area",
       "axis_minor_length",
       "axis_major_length",
    ],
    [
       ("RFP-Penta/intensity_mean", "YFP-DUAL/intensity_mean"),
       ("axis_minor_length", "axis_major_length"),
       ("area", "RFP-Penta/intensity_mean"),
    ],
)
p

In [None]:
measurements_stream["area"].mean()

In [None]:
measurements_stream["area"].hvplot.kde()

In [None]:
measurements_stream.hvplot.scatter(
    "area", "axis_major_length", backlog=100000
).redim.range(area=(0, 300), axis_major_length=(0, 45)).opts(size=2)

In [None]:
measurements_stream.hvplot.bivariate("area", "axis_major_length")

In [None]:
measurements_stream.hvplot.bivariate("area", "axis_major_length").redim.range(
    area=(0, 300), axis_major_length=(0, 45)
).opts(filled=True, bandwidth=10)

# Tabular visualizations

In [None]:
%%time
measurements = reformat_table(table, "measurements", flatten_column_names=True)

In [None]:
%%time
mask_measurements = reformat_table(
    table, "mask_measurements", truncate_column_names=True
)

In [None]:
all_measurements = pd.concat((measurements, mask_measurements), axis=1)

In [None]:
# take the first 1000 rows, then drop rows with a large "RFP" value
all_measurements_subset = all_measurements[:1000]
# NOTE(review): other cells in this notebook refer to this measurement as
# "RFP-Penta/intensity_mean" — confirm that a bare "RFP" column exists here
all_measurements_subset = all_measurements_subset[all_measurements_subset["RFP"] < 20000]

## Median+MAD (median absolute deviation) plots

In [None]:
import astropy.stats

In [None]:
all_measurements

In [None]:
%%time
measurements_subset = all_measurements.reset_index()[
    [
        "t",
        "RFP-Penta/intensity_mean",
        "YFP-DUAL/intensity_mean",
        "area",
        "axis_major_length",
        "axis_minor_length",
    ]
]
medians = measurements_subset.groupby(["t"]).agg(
    ["median", astropy.stats.median_absolute_deviation]
)


def get_limits(x):
    """Return the median ± MAD band for one observable.

    ``x`` has two-level columns whose second level contains ``"median"`` and
    ``"median_absolute_deviation"``; the first level is dropped. The result has
    columns ``"lower"`` and ``"upper"`` on the same index.
    """
    stats = x.droplevel(0, axis=1)
    center = stats["median"]
    spread = stats["median_absolute_deviation"]
    return pd.DataFrame({"lower": center - spread, "upper": center + spread})


limits = medians.groupby(level=0, axis=1).apply(get_limits)

In [None]:
def plot_median_mad(observable, medians, limits):
    """Plot the median of ``observable`` over ``t`` with a shaded lower/upper band.

    ``medians[observable]`` must provide a ``"median"`` column and
    ``limits[observable]`` the ``"lower"``/``"upper"`` columns; both are plotted
    against ``t`` on a logarithmic y axis.
    """
    center = medians[observable].reset_index()
    band = limits[observable].reset_index()
    line_plot = center.hvplot.line("t", "median", logy=True)
    band_plot = band.hvplot.area(
        x="t", y="lower", y2="upper", stacked=False, alpha=0.2, logy=True
    )
    overlay = line_plot * band_plot
    return overlay.opts(width=800, height=300)

In [None]:
plot_median_mad("YFP-DUAL/intensity_mean", medians, limits)

In [None]:
plot_median_mad("RFP-Penta/intensity_mean", medians, limits)

In [None]:
(
    plot_median_mad("RFP-Penta/intensity_mean", medians, limits)
    * plot_median_mad("YFP-DUAL/intensity_mean", medians, limits)
    * plot_median_mad("area", medians, limits)
)

## Heatmap

In [None]:
import hvplot.xarray
import xarray as xr

In [None]:
%%time
channel = "YFP-DUAL/intensity_mean"
measurements_subset = all_measurements.reset_index()[
    [
        "t",
        "RFP-Penta/intensity_mean",
        "YFP-DUAL/intensity_mean",
        "area",
        "axis_major_length",
        "axis_minor_length",
    ]
]
bins = np.geomspace(
    measurements_subset[channel].min(), measurements_subset[channel].max(), 100
)
heatmap = measurements_subset.groupby(["t"]).apply(
    lambda x: pd.Series(np.histogram(x[channel], bins=bins)[0], index=bins[:-1])
)
heatmap.columns.name = channel
heatmap = xr.DataArray(heatmap.T)

In [None]:
heatmap.hvplot.quadmesh(
    cmap="blues",
    logy=True,
    logz=True,
    clim=(1, 1e4),
)

# Interactive selections

In [None]:
from holoviews.selection import link_selections

In [None]:
# for weirdness with responsive=True in holoviews/hvplot
# SEE: https://github.com/holoviz/panel/issues/1394

In [None]:
def filter_plots(data, singles, pairs):
    """Build a two-row panel: KDE plots for ``singles``, linked scatter plots for ``pairs``.

    ``singles`` is an iterable of column names; ``pairs`` is an iterable of
    (x, y) column-name tuples. Only the scatter layout is wrapped in
    ``link_selections``; the KDE layout is not.
    """
    kde_plots = [
        data.hvplot.kde(k, height=200, responsive=True, yaxis="bare")
        for k in singles
    ]
    scatter_plots = [
        data.hvplot.scatter(*k, height=300, width=300, hover=False) for k in pairs
    ]
    kde_pane = pn.pane.HoloViews(
        hv.Layout(kde_plots).cols(6), sizing_mode="stretch_width"
    )
    scatter_pane = pn.pane.HoloViews(
        link_selections(hv.Layout(scatter_plots)), sizing_mode="stretch_width"
    )
    return pn.Column(kde_pane, scatter_pane)


p = filter_plots(
    all_measurements_subset,
    [
        "RFP-Penta/intensity_mean",
        "YFP-DUAL/intensity_mean",
        "area",
        "axis_minor_length",
        "axis_major_length",
    ],
    [
        ("RFP-Penta/intensity_mean", "YFP-DUAL/intensity_mean"),
        ("axis_minor_length", "axis_major_length"),
        ("area", "RFP-Penta/intensity_mean"),
    ],
)
p

# Image visualizations

In [None]:
%%time
rfp_stacks = stack_crops(array, "crops", 30, "RFP-Penta")

In [None]:
%%time
yfp_stacks = stack_crops(array, "crops", 30, "YFP-DUAL")

## Kymographs

In [None]:
a = rfp_stacks[200]
plt.figure(figsize=(20, 20))
plt.imshow(np.swapaxes(a, 0, 1).reshape(a.shape[1], -1))

In [None]:
a = yfp_stacks[300]
plt.figure(figsize=(20, 20))
plt.imshow(np.swapaxes(a, 0, 1).reshape(a.shape[1], -1))

## Many-trenches viewer

In [None]:
# show 40 consecutive entries (330..369) side by side; entry 93 is taken from each
plt.figure(figsize=(20, 20))
# NOTE(review): `d2` is not defined anywhere in the visible notebook — confirm what it should refer to
plt.imshow(pad_unstack([d2[i][93] for i in range(330, 370)]))