# Imports

In [None]:
import itertools as it
import operator
import os
import pickle
import re
from collections import namedtuple
from functools import partial, reduce
from pathlib import Path

import dask
import distributed
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import panel as pn
import pyarrow as pa
import pyarrow.parquet as pq
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.image_analysis.new as new
from paulssonlab.image_analysis import *

In [None]:
%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

## Load outputs from pickle

In [None]:
pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1.pickle"

In [None]:
%%time
# NOTE: pickle.load can execute arbitrary code while unpickling, so only point
# pickle_filename at files from a trusted source.
# The pickled object is unpacked into the two names `table` and `array`.
with open(pickle_filename, "rb") as f:
    table, array = pickle.load(f)

# Helper functions

In [None]:
def reformat_table(
    table, prefix, flatten_column_names=False, truncate_column_names=False
):
    """Gather every entry of ``table`` stored under ``prefix`` into one DataFrame.

    Keys of ``table`` are tuples whose first element is compared to ``prefix``;
    the remaining elements become the ``fov``, ``t`` and ``channel`` index
    levels. Each selected value is passed to ``pd.concat`` with its own keys
    labelled ``roi``. The ``channel`` level is then moved from the index into
    the columns, so the columns become ``(measurement, channel)`` tuples.

    Parameters
    ----------
    table : mapping
        Maps ``(prefix, fov, t, channel)`` tuples to a collection accepted by
        ``pd.concat`` (presumably ROI id -> DataFrame; confirm with the
        code that produces the pickle).
    prefix : hashable
        Only keys whose first element equals this value are used.
    flatten_column_names : bool
        Replace the column MultiIndex with slash-separated strings of the
        form ``"<channel>/<measurement>"``, e.g. ``"RFP-Penta/intensity_mean"``.
    truncate_column_names : bool
        Replace the column MultiIndex with the measurement name only, e.g.
        ``"intensity_mean"`` for ``("intensity_mean", "RFP-Penta")``.

    Returns
    -------
    pandas.DataFrame or None
        ``None`` when no key of ``table`` starts with ``prefix``.

    Raises
    ------
    ValueError
        If both ``flatten_column_names`` and ``truncate_column_names`` are set.
    """
    # Validate the arguments before doing any work, so that a bad call is
    # reported even when no key matches and before the costly concatenation.
    if flatten_column_names and truncate_column_names:
        raise ValueError(
            "flatten_column_names and truncate_column_names cannot both be True"
        )
    keys = sorted(k for k in table.keys() if k[0] == prefix)
    if not keys:
        return None
    df = pd.concat(
        {k[1:]: pd.concat(table[k], names=["roi"]) for k in keys},
        names=["fov", "t", "channel"],
    )
    # After unstacking, each column label is a (measurement, channel) tuple.
    df = df.unstack("channel")
    if flatten_column_names:
        # Reverse each tuple so the channel comes first: "<channel>/<measurement>".
        df.columns = ["/".join(col[::-1]) for col in df.columns.values]
    elif truncate_column_names:
        # Keep only the measurement name and drop the channel.
        df.columns = [col[0] for col in df.columns.values]
    return df

In [None]:
def stack_crops(array, prefix, fov, channel):
    """Stack the crops of each trench across all matching entries of ``array``.

    The 4-tuple keys of ``array`` of the form ``(prefix, fov, x, channel)`` are
    selected and sorted (``x`` is presumably the timepoint; confirm with
    the code that produces ``array``). Only trenches present under every
    selected key are kept.

    Parameters
    ----------
    array : mapping
        Maps key tuples to a mapping of trench -> array.
    prefix, fov, channel
        Values the first, second and fourth key elements must equal.

    Returns
    -------
    dict
        trench -> ``np.stack`` of that trench's crops, one entry per selected
        key in sorted key order. Empty when no key matches.
    """
    keys = sorted(
        k
        for k in array.keys()
        if len(k) == 4 and k[:2] == (prefix, fov) and k[3] == channel
    )
    if not keys:
        # reduce() without an initial value raises TypeError on an empty
        # sequence; an empty selection simply has no crops to stack.
        return {}
    # Intersect the trench sets so every stacked trench has a crop for every key.
    trenches = reduce(operator.and_, (array[k].keys() for k in keys))
    return {
        trench: np.stack([array[k][trench] for k in keys]) for trench in trenches
    }

In [None]:
def unstack(ary):
    """Lay the leading-axis slices of ``ary`` side by side in one 2-D array.

    Axes 0 and 1 are swapped and all axes after the new leading one are
    flattened, so the result has ``ary.shape[1]`` rows. For a 3-D input of
    shape ``(n, rows, cols)`` the result has shape ``(rows, n * cols)``, with
    slice ``ary[0]`` leftmost.
    """
    n_rows = ary.shape[1]
    swapped = np.swapaxes(ary, 0, 1)
    return swapped.reshape(n_rows, -1)


def pad_and_stack(arys, fill_value=0):
    """Pad arrays to a common shape and stack them along a new leading axis.

    Each array is padded at the *start* of every axis (top/left for 2-D
    input) until it matches the largest extent found along that axis.

    Parameters
    ----------
    arys : iterable of array-like
        Arrays with the same number of dimensions. Any iterable is accepted,
        including a generator.
    fill_value : scalar
        Value written into the padded region.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(arys), *max_shape)``.

    Raises
    ------
    ValueError
        If ``arys`` is empty or the arrays differ in number of dimensions.
    """
    # Materialize once: the input is traversed twice (shape scan, then padding).
    arys = [np.asarray(ary) for ary in arys]
    if not arys:
        raise ValueError("pad_and_stack requires at least one array")
    ndim = arys[0].ndim
    if any(ary.ndim != ndim for ary in arys):
        raise ValueError("all arrays must have the same number of dimensions")
    # Largest extent along each axis across all inputs.
    target = [max(extents) for extents in zip(*(ary.shape for ary in arys))]
    return np.stack(
        [
            np.pad(
                ary,
                # (before, after) per axis: all padding goes in front.
                [(int(full - size), 0) for full, size in zip(target, ary.shape)],
                constant_values=fill_value,
            )
            for ary in arys
        ]
    )


def pad_unstack(arys, fill_value=0):
    """Pad ``arys`` to a common shape and lay them side by side in one 2-D array.

    Combines ``pad_and_stack`` and ``unstack``.

    Parameters
    ----------
    arys : sequence of arrays
        Forwarded to ``pad_and_stack``.
    fill_value : scalar
        Value used for the padded region; forwarded to ``pad_and_stack``.
        Defaults to 0, the value that was previously always used.
    """
    padded = pad_and_stack(arys, fill_value=fill_value)
    return unstack(padded)

# Reformat data

In [None]:
%%time
measurements = reformat_table(table, "measurements", flatten_column_names=True)

In [None]:
%%time
mask_measurements = reformat_table(
    table, "mask_measurements", truncate_column_names=True
)

In [None]:
all_measurements = pd.concat((measurements, mask_measurements), axis=1)

In [None]:
# Small working subset for the interactive-selection plots further down:
# take the first 1000 rows (positional slice), then keep only rows whose
# RFP-Penta mean intensity is below 20000.
all_measurements_subset = all_measurements[:1000]
all_measurements_subset = all_measurements_subset[
    all_measurements_subset["RFP-Penta/intensity_mean"] < 20000
]

# Streaming

In [None]:
import hvplot.streamz
from streamz.dataframe import PeriodicDataFrame

In [None]:
state = {}


def poll_table(last, now, **kwargs):
    """Return a growing slice of ``all_measurements`` on each call.

    Used as the polling callback of the ``PeriodicDataFrame`` below. Every
    call increments ``state["counter"]`` (starting from 0) and returns the
    rows whose second index level is at most the new counter value (label
    slice via ``.loc``). ``last``, ``now`` and ``kwargs`` are accepted but
    not used.
    """
    state["counter"] = state.get("counter", 0) + 1
    upto = state["counter"]
    return all_measurements.loc[IDX[:, :upto, :, :]]


# Streaming dataframe fed by poll_table, configured with a 300 ms interval.
measurements_stream = PeriodicDataFrame(poll_table, interval="300ms")

In [None]:
def filter_plots(data, singles, pairs):
    """Build a Panel column with KDE plots on top and scatter plots below.

    Parameters
    ----------
    data
        Object exposing the ``.hvplot`` accessor (called below with the
        streaming dataframe ``measurements_stream``).
    singles : iterable of str
        Column names; one 200x200 KDE plot is drawn per name, laid out with
        at most 6 plots per row.
    pairs : iterable of tuple
        Each tuple is unpacked as the positional arguments of
        ``data.hvplot.scatter`` (here ``(x, y)`` column names); one 300x300
        scatter plot is drawn per tuple.

    Returns
    -------
    panel.Column
        Two HoloViews panes: the KDE layout and the scatter layout.
    """
    kde_layout = hv.Layout(
        [
            # `backlog` is forwarded to hvplot (presumably the number of
            # streamed rows retained for plotting; confirm in the hvplot docs).
            data.hvplot.kde(k, height=200, width=200, yaxis="bare", backlog=100)
            for k in singles
        ]
    ).cols(6)
    scatter_layout = hv.Layout(
        [
            data.hvplot.scatter(
                *k,
                height=300,
                width=300,
                hover=False,
                size=2,
                backlog=1000,
            )
            for k in pairs
        ]
    )
    return pn.Column(
        pn.pane.HoloViews(kde_layout),
        pn.pane.HoloViews(scatter_layout, sizing_mode="stretch_width"),
    )


# Live dashboard over the streaming measurements: one KDE per listed column
# and one scatter plot per listed column pair. The bare `p` on the last line
# makes the notebook display the panel.
p = filter_plots(
    measurements_stream,
    [
        "RFP-Penta/intensity_mean",
        "YFP-DUAL/intensity_mean",
        "area",
        "axis_minor_length",
        "axis_major_length",
    ],
    [
        ("RFP-Penta/intensity_mean", "YFP-DUAL/intensity_mean"),
        ("axis_minor_length", "axis_major_length"),
        ("area", "RFP-Penta/intensity_mean"),
    ],
)
p

# Tabular visualizations

## Median+MAD (median absolute deviation) plots

In [None]:
import astropy.stats

In [None]:
%%time
# Turn the index levels into columns and keep only "t" plus the observables
# that are summarized below.
measurements_subset = all_measurements.reset_index()[
    [
        "t",
        "RFP-Penta/intensity_mean",
        "YFP-DUAL/intensity_mean",
        "area",
        "axis_major_length",
        "axis_minor_length",
    ]
]
# For every value of "t", compute the median and the median absolute deviation
# of each observable. The result has two-level columns:
# (observable, "median" | "median_absolute_deviation").
medians = measurements_subset.groupby(["t"]).agg(
    ["median", astropy.stats.median_absolute_deviation]
)


def get_limits(x):
    """Return the median -/+ MAD band for one observable.

    ``x`` is a DataFrame with two-level columns; the top level is dropped and
    the remaining level must contain ``"median"`` and
    ``"median_absolute_deviation"``. The result keeps the index of ``x`` and
    has the columns ``"lower"`` (median - MAD) and ``"upper"`` (median + MAD).
    """
    stats = x.droplevel(0, axis=1)
    center = stats["median"]
    spread = stats["median_absolute_deviation"]
    return pd.DataFrame({"lower": center - spread, "upper": center + spread})


# Apply get_limits to each observable (top column level) separately; the result
# is read later as limits[observable] with "lower" and "upper" columns.
# NOTE(review): groupby(..., axis=1) is deprecated in recent pandas releases;
# confirm against the pandas version this notebook is pinned to.
limits = medians.groupby(level=0, axis=1).apply(get_limits)

In [None]:
def plot_median_mad(observable, medians, limits):
    """Plot the median of ``observable`` over ``t`` with its MAD band.

    ``medians[observable]`` must provide a ``"median"`` column and
    ``limits[observable]`` the ``"lower"`` and ``"upper"`` columns; both are
    indexed by ``t``. Returns the overlay of the median line and a
    semi-transparent area between the limits, drawn on a log y-axis and
    sized 800x300.
    """
    center = medians[observable].reset_index()
    band = limits[observable].reset_index()
    median_line = center.hvplot.line("t", "median", logy=True)
    mad_area = band.hvplot.area(
        x="t", y="lower", y2="upper", stacked=False, alpha=0.2, logy=True
    )
    overlay = median_line * mad_area
    return overlay.opts(width=800, height=300)

In [None]:
plot_median_mad("YFP-DUAL/intensity_mean", medians, limits)

In [None]:
plot_median_mad("RFP-Penta/intensity_mean", medians, limits)

In [None]:
(
    plot_median_mad("RFP-Penta/intensity_mean", medians, limits)
    * plot_median_mad("YFP-DUAL/intensity_mean", medians, limits)
    # * plot_median_mad("area", medians, limits)
)

## Heatmap

In [None]:
import hvplot.xarray
import xarray as xr

In [None]:
%%time
# Column whose distribution over time is shown in the heatmap.
channel = "YFP-DUAL/intensity_mean"
# Same column selection as in the median/MAD section above.
measurements_subset = all_measurements.reset_index()[
    [
        "t",
        "RFP-Penta/intensity_mean",
        "YFP-DUAL/intensity_mean",
        "area",
        "axis_major_length",
        "axis_minor_length",
    ]
]
# 100 geometrically spaced bin edges (99 bins) spanning the observed range.
# NOTE(review): np.geomspace needs nonzero endpoints, so this assumes the
# channel values are strictly positive; confirm for the data being loaded.
bins = np.geomspace(
    measurements_subset[channel].min(), measurements_subset[channel].max(), 100
)
# One histogram per value of "t": rows are timepoints, columns are the counts
# per bin, labelled by the left edge of each bin.
heatmap = measurements_subset.groupby(["t"]).apply(
    lambda x: pd.Series(np.histogram(x[channel], bins=bins)[0], index=bins[:-1])
)
heatmap.columns.name = channel
# Transpose so the bins come first and "t" second, then wrap as a DataArray
# for hvplot's xarray interface.
heatmap = xr.DataArray(heatmap.T)

In [None]:
heatmap.hvplot.quadmesh(
    cmap="blues",
    logy=True,
    logz=True,
    clim=(1, 1e4),
)

# Interactive selections

In [None]:
from holoviews.selection import link_selections

In [None]:
# Note: responsive=True behaves oddly in holoviews/hvplot.
# See: https://github.com/holoviz/panel/issues/1394

In [None]:
def filter_plots(data, singles, pairs):
    """Build a Panel column of KDE and scatter plots with linked selections.

    This definition replaces the streaming-oriented ``filter_plots`` defined
    earlier in the notebook.

    Parameters
    ----------
    data
        Object exposing the ``.hvplot`` accessor (called below with the
        DataFrame ``all_measurements_subset``).
    singles : iterable of str
        Column names; one responsive, 200-high KDE plot is drawn per name,
        laid out with at most 6 plots per row.
    pairs : iterable of tuple
        Each tuple is unpacked as the positional arguments of
        ``data.hvplot.scatter`` (here ``(x, y)`` column names); one 300x300
        scatter plot is drawn per tuple.

    Returns
    -------
    panel.Column
        Two stretch-width HoloViews panes: the KDE layout and the scatter
        layout, each wrapped in ``link_selections``.
    """
    kde_layout = hv.Layout(
        [
            data.hvplot.kde(k, height=200, responsive=True, yaxis="bare")
            for k in singles
        ]
    ).cols(6)
    scatter_layout = hv.Layout(
        [
            data.hvplot.scatter(*k, height=300, width=300, hover=False)
            for k in pairs
        ]
    )
    # NOTE(review): link_selections is applied to each layout separately, so
    # the KDE pane and the scatter pane presumably do not share a selection;
    # confirm whether a single shared link_selections instance was intended.
    return pn.Column(
        pn.pane.HoloViews(
            link_selections(kde_layout),
            sizing_mode="stretch_width",
        ),
        pn.pane.HoloViews(
            link_selections(scatter_layout),
            sizing_mode="stretch_width",
        ),
    )


# Linked-selection dashboard over the static subset prepared earlier
# (all_measurements_subset); the bare `p` on the last line makes the notebook
# display the panel.
p = filter_plots(
    all_measurements_subset,
    [
        "RFP-Penta/intensity_mean",
        "YFP-DUAL/intensity_mean",
        "area",
        "axis_minor_length",
        "axis_major_length",
    ],
    [
        ("RFP-Penta/intensity_mean", "YFP-DUAL/intensity_mean"),
        ("axis_minor_length", "axis_major_length"),
        ("area", "RFP-Penta/intensity_mean"),
    ],
)
p

# Image visualizations

In [None]:
%%time
rfp_stacks = stack_crops(array, "crops", 30, "RFP-Penta")

In [None]:
%%time
yfp_stacks = stack_crops(array, "crops", 30, "YFP-DUAL")

## Kymographs

In [None]:
# Kymograph of trench 200 in the RFP channel: the stacked crops are laid side
# by side along the horizontal axis. The swapaxes/reshape expression is the
# same transform as the unstack() helper defined above.
a = rfp_stacks[200]
plt.figure(figsize=(20, 20))
plt.imshow(np.swapaxes(a, 0, 1).reshape(a.shape[1], -1))

In [None]:
# Kymograph of trench 300 in the YFP channel: the stacked crops are laid side
# by side along the horizontal axis. The swapaxes/reshape expression is the
# same transform as the unstack() helper defined above.
a = yfp_stacks[300]
plt.figure(figsize=(20, 20))
plt.imshow(np.swapaxes(a, 0, 1).reshape(a.shape[1], -1))

## Many-trenches viewer

In [None]:
# Show trenches 330-369 side by side at a single position along the stacked
# axis: entry 93 of each trench's stack (presumably one timepoint; confirm
# against the key order used by stack_crops). The crops are padded to a common
# shape first. A trench id missing from yfp_stacks raises KeyError.
plt.figure(figsize=(20, 20))
plt.imshow(pad_unstack([yfp_stacks[i][93] for i in range(330, 370)]))