# Imports

In [None]:
import glob
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm

# Shorthand for MultiIndex slicing, e.g. df.loc[IDX[40:], :].
IDX = pd.IndexSlice

In [None]:
from dask.diagnostics import ProgressBar

# Register a global dask progress bar (reports on local-scheduler computations).
ProgressBar().register()

In [None]:
# Reload edited modules (e.g. paulssonlab.image_analysis) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# The star import presumably provides the submodules used below (workflow,
# geometry, image, trench_detection, drift, segmentation, util, readers, ui)
# — confirm against paulssonlab.image_analysis.__init__.
from paulssonlab.image_analysis import *
from paulssonlab.image_analysis.ui import display_image

In [None]:
# Enables the %%pyinstrument cell magic (left commented out in the processing cell).
%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Input data. Earlier acquisitions are kept (commented out) for reference.
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
# filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"
# filename = workflow.SplitFilename(
#     sorted(
#         glob.glob(
#             "/home/jqs1/scratch/jqs1/microscopy/230619/230619_NAO745_repressilators_split.nd2*"
#         )
#     )
# )
# The current acquisition is stored as several *.nd2.split.a* parts; SplitFilename
# presumably presents the sorted parts as one logical file — confirm in workflow.
filename = workflow.SplitFilename(
    sorted(
        glob.glob(
            "/home/jqs1/scratch/jqs1/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
        )
    )
)
# FISH images are read from FISH/real_run next to the live-imaging data.
fish_filename = Path(filename).parent / "FISH/real_run"

In [None]:
nd2 = workflow.get_nd2_reader(filename)
# Number of timepoints; used as the time range for processing below.
t_max = nd2.sizes["t"]

In [None]:
# Display color per live-imaging channel.
colors = {
    "BF": "#ffffff",
    "CFP-EM": "#f44336",  # TODO
    "YFP-EM": "#03a9f4",
    "RFP-EM": "#8bc34a",
}

# Display color per FISH channel (looked up by channel name for the FISH stacks).
fish_colors = {
    "BF": "#ffffff",
    "GFP": "#f44336",
    "Cy5": "#03a9f4",
    # "Cy7": "#ffeb3b"
    "Cy7": "#8bc34a",
}

In [None]:
# Dask cluster on SLURM: each job is one single-core worker with 16 GB.
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    memory="16GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(10)

In [None]:
# Switch from the fixed 10 workers to adaptive scaling (at most 20 workers).
cluster.adapt(maximum=20)

# Handler

In [None]:
# NOTE(review): segmentation_channel is reassigned to "RFP-EM" in a later cell,
# and fish_channels is rebuilt from the FISH files, so these values are
# superseded. measure_channels and trench_detection_channel are not used in
# the visible cells.
segmentation_channel = "RFP-PENTA"
trench_detection_channel = segmentation_channel  # channel for trench detection, almost always same as segmentation_channel
measure_channels = ["RFP-PENTA", "YFP-DUAL"]
fish_channels = ["RFP-PENTA", "Cy5-PENTA", "Cy7"]

In [None]:
def crop_rois(img, rois):
    """Cut every ROI out of ``img``.

    Returns a dict mapping each ROI index yielded by
    ``geometry.iter_roi_crops`` to the corresponding cropped image.
    """
    # For quick tests, wrap the iterator in it.islice(..., 100) to handle only a
    # few trenches per FOV; otherwise each dask task takes a long time.
    roi_crops = geometry.iter_roi_crops(img, rois)
    return {roi_idx: crop for roi_idx, crop in roi_crops}


def segment_crops(crops):
    """Run watershed segmentation on each crop; return the masks keyed like ``crops``."""
    return {
        roi_idx: segmentation.watershed.segment(crop)
        for roi_idx, crop in crops.items()
    }


# TODO: this is really boilerplatey, also we want finer task granularity than doing a whole FOV at once
# def measure_crops(label_images, intensity_images):
#     keys = label_images.keys() & intensity_images.keys()
#     return {k: measure_crop(label_images[k], intensity_images[k]) for k in keys}
def measure_crops(intensity_images):
    """Apply ``measure_crop`` to every crop.

    Returns a dict with the same keys as ``intensity_images``; each value is
    the ``pd.Series`` of summary statistics for that crop.
    """
    measured = {}
    for roi_idx, crop in intensity_images.items():
        measured[roi_idx] = measure_crop(crop)
    return measured


# def measure_crop(label_image, intensity_image):
# return pd.DataFrame(
#     skimage.measure.regionprops_table(
#         label_image,
#         intensity_image,
#         properties=(
#             "label",
#             "intensity_mean",
#         ),
#     )
# ).set_index("label")
def measure_crop(intensity_image):
    """Summarize the pixel intensities of a single 2D crop.

    Returns a ``pd.Series`` named ``"value"`` whose index (named
    ``"observable"``) holds the 1st/50th/90th/99th percentiles and the mean of
    the whole crop, plus the mean and median of its middle column.
    """
    # Middle column of the crop. Assumes the trench runs along axis 0, so this
    # samples along the trench centerline — TODO confirm orientation.
    mid_col = intensity_image[:, intensity_image.shape[1] // 2]
    observables = dict(
        p1=np.percentile(intensity_image, 1),
        p50=np.median(intensity_image),
        p90=np.percentile(intensity_image, 90),
        p99=np.percentile(intensity_image, 99),
        mean=np.mean(intensity_image),
        centerline_mean=np.mean(mid_col),
        centerline_median=np.median(mid_col),
    )
    stats = pd.Series(observables, name="value")
    return stats.rename_axis(index="observable")


def measure_mask_crops(label_images):
    """Apply ``measure_mask_crop`` to each label image; return a dict keyed like the input."""
    results = {}
    for roi_idx, labels in label_images.items():
        results[roi_idx] = measure_mask_crop(labels)
    return results


def measure_mask_crop(label_image):
    """Tabulate per-object shape properties of a labeled mask.

    Returns a DataFrame indexed by ``label``, with the area, major/minor axis
    lengths, orientation and centroid columns produced by
    ``skimage.measure.regionprops_table`` (which splits the centroid into one
    column per image axis).
    """
    props = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    table = skimage.measure.regionprops_table(label_image, properties=props)
    return pd.DataFrame(table).set_index("label")


def write_parquet(output_dir, measurements, position, t):
    """Write one frame's measurements into the hive-partitioned parquet dataset.

    ``measurements`` maps channel -> {roi_idx: pd.Series} (as returned by
    ``measure_crops``). The nested dicts are flattened into a long table with
    ``channel`` and ``roi_idx`` columns plus the Series' own index level(s),
    tagged with ``position`` and ``t`` (as uint16). The table is written
    under ``output_dir / "measurements"``, partitioned by position and t,
    replacing any existing data in the matching partitions.
    """
    per_channel = {}
    for channel, roi_measurements in measurements.items():
        per_channel[channel] = pd.concat(roi_measurements, names=["roi_idx"])
    frame = pd.concat(per_channel, names=["channel"]).reset_index()
    # The scalars broadcast to every row; uint16 keeps the partition columns compact.
    frame["position"] = np.array(position).astype(np.uint16)
    frame["t"] = np.array(t).astype(np.uint16)
    table = pa.Table.from_pandas(frame, preserve_index=False)
    pq.write_to_dataset(
        table,
        Path(output_dir) / "measurements",
        partition_cols=["position", "t"],
        existing_data_behavior="delete_matching",
    )


def stack_dict(d):
    """Densify an int-keyed dict of arrays into a list indexed 0..max(key).

    Gaps are filled with an all-NaN array shaped like the first value; the
    same NaN array object is reused for every gap.
    """
    template = next(iter(d.values()))
    filler = np.full(template.shape, np.nan)
    size = max(d.keys()) + 1
    return [d[i] if i in d else filler for i in range(size)]


def _pad(ary, shape):
    return np.pad(
        ary, [(0, max(goal - current, 0)) for goal, current in zip(shape, ary.shape)]
    )


def write_zarr(filename, crops, t, max_t, channels):
    """Write one timepoint's ROI crops into a per-FOV zarr array.

    The array is laid out as ``(roi, t, channel, height, width)``. On the first
    call (store path does not exist yet) it is created NaN-filled and sized
    from the crops passed in: ``max(roi_idx) + 1`` ROIs, ``max_t``
    timepoints, ``len(channels)`` channels, and the largest crop shape. Every
    call then fills the slice for timepoint ``t``.

    Parameters
    ----------
    filename : pathlib.Path
        Path of the zarr directory store.
    crops : dict
        ``{channel: {roi_idx: 2D array}}``. Every channel is assumed to have
        the same ROI keys.
    t : int
        Index along the time axis.
    max_t : int
        Length of the time axis; used only when the array is created.
    channels : list
        Channels to write, in channel-axis order.

    Raises
    ------
    ValueError
        If this frame has more ROI rows than the existing array.
    """
    first_crops = crops[channels[0]]
    if not first_crops:
        # No ROIs survived filtering for this frame; its slice stays NaN.
        return
    store = zarr.DirectoryStore(filename)  # DirectoryStoreV3(filename)
    # NOTE(review): process_fov schedules one write_zarr task per timepoint for
    # the same store. If several run concurrently before the array exists, more
    # than one may take the creation branch — confirm these writes are serialized.
    if not filename.exists():
        num_rois = max(first_crops.keys()) + 1
        num_channels = len(channels)
        max_shape = np.max([crop.shape for crop in first_crops.values()], axis=0)
        shape = (num_rois, max_t, num_channels, *max_shape)
        chunks = (5, 1, num_channels, None, None)
        ary = zarr.open_array(
            store,
            mode="a",
            zarr_version=2,
            shape=shape,
            chunks=chunks,
            fill_value=np.nan,
        )
    else:
        ary = zarr.open_array(store, mode="a", zarr_version=2)
        max_shape = ary.shape[-2:]
    # Crop anything larger than the stored frame shape (_pad only pads), then pad
    # up to it, so every ROI stacks into a regular array.
    frame_index = tuple(slice(0, size) for size in max_shape)
    stack = np.array(
        [
            stack_dict(
                {
                    idx: _pad(crop[frame_index].astype(np.float32), max_shape)
                    for idx, crop in crops[channel].items()
                }
            )
            for channel in channels
        ]
    ).swapaxes(0, 1)
    # process_fov filters the shifted ROIs separately for each timepoint, so the
    # highest ROI index (and hence the row count here) can differ from the
    # array's. Write only the leading rows; the rest keep their NaN fill. Assigning
    # a shorter stack to ary[:, t] would fail, or, for a single row, broadcast it
    # over every ROI.
    num_rows = stack.shape[0]
    if num_rows > ary.shape[0]:
        raise ValueError(
            f"frame t={t} has {num_rows} ROI rows but {filename} holds only "
            f"{ary.shape[0]}"
        )
    ary[:num_rows, t, ...] = stack

In [None]:
# All processing outputs go under test_output next to the input data.
output_dir = Path(filename).parent / "test_output"
output_dir.mkdir(exist_ok=True)

In [None]:
# Overrides the segmentation channel set in the Handler config cell.
segmentation_channel = "RFP-EM"
measurement_channels = ["CFP-EM", "YFP-EM", "RFP-EM"]
# Presumably trench width / trench pitch (1.4 and 3.5 in the same units) — confirm.
width_to_pitch_ratio = 1.4 / 3.5
# Radial distortion coefficient passed to image.correct_radial_distortion.
k1 = 8.947368421052635e-10

In [None]:
def get_frame_func(
    filename,
    position,
    channel,
    t,
    k1=k1,
    dark=None,
    flat=None,
    crop=(slice(550, 2350), slice(1500, 3500)),
):
    """Load one frame, correct its radial distortion, and crop it.

    Parameters
    ----------
    filename
        File passed to ``workflow.get_nd2_frame``.
    position, channel, t
        Frame coordinates.
    k1
        Radial distortion coefficient. The default is the notebook-level ``k1``
        at the time this cell runs.
    dark, flat
        Optional dark/flat frames forwarded to ``workflow.get_nd2_frame``.
    crop
        Index applied to the corrected frame. The default is the fixed
        ``[550:2350, 1500:3500]`` region used throughout this notebook; pass
        ``None`` to get the full corrected frame.

    Returns
    -------
    numpy.ndarray
    """
    frame = np.asarray(
        workflow.get_nd2_frame(
            filename, position=position, channel=channel, t=t, dark=dark, flat=flat
        )
    )
    corrected = image.correct_radial_distortion(frame, k1=k1)
    if crop is None:
        return corrected
    return corrected[crop]

In [None]:
%%time
# Reference frame: position 11, t=0. Its shape defines image_limits, which
# process_fov uses to filter ROIs for every position.
img0 = get_frame_func(filename, 11, segmentation_channel, 0)
image_limits = geometry.get_image_limits(img0.shape)

In [None]:
display_image(img0, scale=0.9, downsample=4)

In [None]:
%%time
# Detect trenches on the reference frame. The fitted angle and pitch are reused
# as fixed find_trenches_kwargs for every position processed below.
diag = util.tree()
rois, info = trench_detection.find_trenches(
    img0,
    width_to_pitch_ratio=width_to_pitch_ratio,
    join_info=False,
    diagnostics=diag,
)
angle = info["angle"]
pitch = info["pitch"]

In [None]:
# Inspect trench-detection diagnostics.
diag["bboxes"]

In [None]:
diag["labeling"]["binarize_trench_image"].keys()

In [None]:
diag["labeling"]["binarize_trench_image"]["num_components"]

In [None]:
diag["labeling"]["binarize_trench_image"].keys()

In [None]:
diag["labeling"]["binarize_trench_image"]["normalized_image"]

In [None]:
def process_fov(
    get_frame_func,
    position,
    ts,
    output_dir,
    segmentation_channel,
    measurement_channels,
    image_limits,
    find_trenches_kwargs=None,
    dark=None,
    flats=None,
    delayed=True,
    max_t=None,
):
    """Build the (optionally delayed) processing graph for one field of view.

    For each timepoint in ``ts``, the graph does the following:

    - Load the segmentation-channel frame.
    - On the first frame only, detect trenches.
    - On later frames, estimate drift relative to the first frame, seeded
      with the previous timepoint's shift.
    - Shift and filter the ROIs, then crop every needed channel.
    - Measure the measurement channels.
    - Schedule writes of the measurements (parquet) and the crops (zarr at
      ``output_dir / f"crops_v={position}.zarr"``).

    Parameters
    ----------
    get_frame_func : callable
        Called as ``get_frame_func(position, channel, t)``.
    position : int
    ts : sequence of int
        Timepoints to process, in order.
    output_dir : pathlib.Path
    segmentation_channel : str
        Channel used for trench detection, drift estimation and cropping.
    measurement_channels : list of str
        Channels that are measured and written to the zarr crops.
    image_limits
        Passed to ``geometry.filter_rois``.
    find_trenches_kwargs : dict, optional
        Overrides for ``trench_detection.find_trenches`` (base: ``join_info=True``).
    dark, flats
        Accepted but currently unused.
    delayed : bool or callable
        Passed to ``util.get_delayed``. Presumably ``True`` builds dask delayed
        tasks — confirm.
    max_t : int, optional
        Length of the zarr time axis when the store is created. Defaults to
        ``max(300, max(ts) + 1)``. An already-existing store keeps its size.

    Returns
    -------
    list
        Write tasks, two per timepoint (parquet, then zarr).
    """
    delayed = util.get_delayed(delayed)
    if find_trenches_kwargs is None:
        find_trenches_kwargs = {}
    if max_t is None:
        # Previously hard-coded to 300. Keep that as the minimum capacity, but
        # grow it when ts extends further, since t indexes the zarr time axis directly.
        max_t = max(300, int(max(ts)) + 1) if len(ts) else 300
    # Segmentation channel first, then the remaining measurement channels in
    # their given order, deduplicated. (A set would make the order vary between runs.)
    channels = [
        segmentation_channel,
        *(
            channel
            for channel in dict.fromkeys(measurement_channels)
            if channel != segmentation_channel
        ),
    ]
    zarr_filename = output_dir / f"crops_v={position}.zarr"
    rois = None
    shifts = {}
    write_tasks = []
    for prev_t, t in tqdm(list(zip(it.chain([None], ts[:-1]), ts))):
        segmentation_img = delayed(get_frame_func)(position, segmentation_channel, t)
        if rois is None:
            # First timepoint: detect trenches and record the reference features
            # used to estimate drift for all later frames.
            rois = delayed(trench_detection.find_trenches)(
                segmentation_img, **{**dict(join_info=True), **find_trenches_kwargs}
            )
            shifts[t] = np.array([0, 0])
            initial_drift_features = delayed(drift.get_drift_features)(
                segmentation_img, rois, shifts[t]
            )
        else:
            # Later timepoints: drift relative to the first frame, seeded with the
            # previous timepoint's estimate.
            shifts[t] = delayed(drift.find_feature_drift)(
                initial_drift_features,
                segmentation_img,
                rois,
                initial_shift2=shifts[prev_t],
            )
        shifted_rois = delayed(geometry.filter_rois)(
            delayed(geometry.shift_rois)(rois, shifts[t]), image_limits
        )
        crops = {}
        measurements = {}
        for channel in channels:
            if channel == segmentation_channel:
                crops[channel] = delayed(crop_rois)(segmentation_img, shifted_rois)
                # mask_crops = delayed(segment_crops)(crops[channel])
                # mask_measurements = delayed(measure_mask_crops)(mask_crops)
            else:
                img = delayed(get_frame_func)(position, channel, t)
                crops[channel] = delayed(crop_rois)(img, shifted_rois)
            if channel in measurement_channels:
                # measurements[channel] = delayed(measure_crops)(mask_crops, crops[channel])
                measurements[channel] = delayed(measure_crops)(crops[channel])
        write_tasks.append(
            delayed(write_parquet)(output_dir, measurements, position, t)
        )
        write_tasks.append(
            delayed(write_zarr)(
                zarr_filename,
                crops,
                t,
                max_t,
                measurement_channels,
            )
        )
        # TODO: also persist rois and metadata (e.g. shifts)
    return write_tasks

In [None]:
%%time
# %%pyinstrument
# Build one list of write tasks per position (13-39). With delayed=True nothing
# is computed here; the work is submitted with client.compute below.
# Trench detection reuses the angle/pitch fitted on the position-11 reference
# frame, with a wider width_to_pitch_ratio (2.2/3.5) than that fit used.
ts = np.arange(t_max)
# ts = np.arange(2)
res = []
for position in np.arange(13, 40):
    res.append(
        process_fov(
            partial(get_frame_func, filename),
            position,
            ts,
            output_dir / "test2",
            segmentation_channel,
            measurement_channels,
            image_limits,
            find_trenches_kwargs=dict(
                angle=angle, pitch=pitch, width_to_pitch_ratio=2.2 / 3.5
            ),
            delayed=True,
        )
    )

In [None]:
# First write task (parquet) of the first position, for inspection.
res[0][0]

In [None]:
%%time
# Submit every position's task list to the cluster; futures hold the results.
futures = [client.compute(x) for x in tqdm(res)]

In [None]:
# NOTE(review): dropping the last reference lets the scheduler release (and, if
# unfinished, forget) the submitted tasks. The gather/inspect cells below need
# `futures` to still exist, so run this only when done with them.
del futures

In [None]:
# Block until all writes finish (raises if any task failed).
client.gather(futures)

In [None]:
futures

In [None]:
# NOTE(review): lists test1, but the run above writes to test2 — confirm which is intended.
!ls $output_dir/test1

In [None]:
# WARNING: deletes every output of the run above; do not run before reading them.
!rm -rf $output_dir/test2

In [None]:
# pyarrow dataset over the hive-partitioned parquet written by write_parquet.
dataset = ds.dataset(
    output_dir / "test2/measurements", format="parquet", partitioning="hive"
)

In [None]:
%%time
# Load only position 11 (filter on the partition column).
# NOTE(review): the run above covers positions 13-39, so position-11 data must
# come from an earlier run into test2 — confirm.
df = dataset.to_table(filter=ds.field("position") == 11).to_pandas()

In [None]:
df.info(verbose=True, memory_usage="deep")

In [None]:
%%time
# Wide table of the p90 observable: one row per (position, roi_idx, t), one column per channel.
df2 = (
    df[df["observable"] == "p90"]
    .pivot_table(
        columns=["channel"], values=["value"], index=["position", "roi_idx", "t"]
    )
    .droplevel(0, axis=1)
    # .droplevel(["position"])
    # .reset_index("roi_idx")
)

In [None]:
df2.info(verbose=True, memory_usage="deep")

In [None]:
df2

In [None]:
list(df2.groupby(["position", "roi_idx"]))[1]

In [None]:
def norm(x, quantile=0.9):
    """Scale ``x`` by its ``quantile``-th quantile.

    A ``quantile`` of ``None`` or ``0`` disables scaling and returns ``x``
    itself. ``x`` must provide ``.quantile`` (e.g. a pandas Series, or a
    DataFrame, in which case each column is scaled by its own quantile).
    """
    if quantile is not None and quantile != 0:
        return x / x.quantile(quantile)
    return x

In [None]:
# Position 11, ROI 21, timepoints from 40 on; each channel scaled by its 90th percentile.
x = norm(df2.loc[(11, 21)].loc[IDX[40:], :])

In [None]:
pd.plotting.autocorrelation_plot(x["CFP-EM"])

In [None]:
df2.reset_index()[df2.reset_index()["roi_idx"] == 244]

In [None]:
norm(df2.loc[(11, 244)].loc[IDX[40:], :]).hvplot()

In [None]:
(df2.index.get_level_values("roi_idx") == 244).sum()

In [None]:
roi_idxs = np.unique(
    df2.index.get_level_values("roi_idx")
)  # df2.index.levels[df2.index.names.index("roi_idx")]

In [None]:
# List of ((position, roi_idx), per-ROI frame) pairs.
groups = list(df2.groupby(["position", "roi_idx"], as_index=False, group_keys=False))

In [None]:
norm(groups[5][1].droplevel(["position", "roi_idx"])).hvplot()

In [None]:
# NOTE: `t` here indexes into groups (i.e. selects ROIs), not timepoints.
hv.HoloMap(
    {
        t: norm(
            groups[t][1].droplevel(["position", "roi_idx"]).loc[IDX[40:], :]
        ).hvplot()
        for t in range(2)
    }
)

In [None]:
# Hand-picked roi_idx values (selection criterion not recorded).
# 21, 57, 67, 103, 105, 107, 116, 149, 162, 170, 185, 191, 215, 237, 246, 252, 268, 285, 302, 319, 321, 342, 346, 375, 417, 432, 453, 454, 457, 462, 463
# 535, 567, 588, 600, 638, 644, 650, 677, 680, 690, 707

In [None]:
# One normalized time-series plot per ROI, keyed by roi_idx, for groups 400 onward.
hv.HoloMap(
    {
        idx[1]: norm(group.droplevel(["position", "roi_idx"]).loc[IDX[40:], :]).hvplot()
        for idx, group in groups[400:]
    }
)

In [None]:
z = zarr.open_array(output_dir / "test2/crops_v=11.zarr", mode="r")

In [None]:
z.shape

In [None]:
plt.imshow(z[1][0][1].T)

In [None]:
measurement_channels

In [None]:
t_max0 = 20
(
    hv.HoloMap({t: ui.RevImage(z[21][t][0].T) for t in range(t_max0)}).options(
        frame_width=400
    )
    + hv.HoloMap({t: ui.RevImage(z[21][t][1].T) for t in range(t_max0)}).options(
        frame_width=400
    )
    + hv.HoloMap({t: ui.RevImage(z[21][t][2].T) for t in range(t_max0)}).options(
        frame_width=400
    )
).cols(1).opts(hv.opts.Image(axiswise=True), hv.opts.Layout())

In [None]:
%%output backend='matplotlib'
# %%opts Layout [normalize=False fig_inches=2 vspace=0 aspect_weight=1 sublabel_format='' tight=True title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(**label_stream.contents) fontsize=20]
# %%opts Scatter [aspect=6]
# NOTE(review): stale cell. label_stream, trench_key, detected_bursts,
# trench_movie, trench_bboxes, scatter_movie, labelwise_df and holomap_to_video
# are not defined in this notebook (unless the star import provides some), so
# this raises NameError as written. It also rebinds `ts`, which the processing
# cells use.
key = tuple(getattr(label_stream, attr) for attr in trench_key)
index = detected_bursts.groupby(trench_key).get_group(key).index
# NOTE: _get_level_values/_get_level_number are private pandas APIs;
# index.get_level_values("t").unique() is likely the public equivalent.
ts = index._get_level_values(index._get_level_number("t"), unique=True)
# ts = list(range(3))

movie = (
    trench_movie(trench_bboxes, key, "MCHERRY", ts)
    + trench_movie(trench_bboxes, key, "YFP", ts)
    + scatter_movie(labelwise_df, label_stream.contents, ts)
    * hv.HoloMap(
        {t: hv.VLine(t).options(color="red", backend="matplotlib") for t in ts}
    )
).cols(1)
movie2 = movie.options(
    {
        "Layout": dict(
            normalize=False,
            framewise=True,
            fig_inches=7,
            vspace=0,
            aspect_weight=1,
            sublabel_format="",
            tight=False,
            fontsize=15,
            title_format="{filename:}\npos: {position:} trench: {trench_set:}.{trench:} t: {t:}".format(
                **label_stream.contents
            ),
        ),
        "Scatter": dict(aspect=6, s=20),
    },
    backend="matplotlib",
)
m = holomap_to_video(movie2, out="/tmp/jqsmovie.mp4", size=100, dpi=100)

In [None]:
import os

# Opt in to zarr's experimental v3 API and sharding (private zarr 2.x modules).
# NOTE(review): zarr is already imported at the top of the notebook; confirm the
# flags still take effect when set this late (setting them before `import zarr`
# is the safe order). DirectoryStoreV3 and ShardingStorageTransformer are not
# used in the visible cells (write_zarr mentions DirectoryStoreV3 only in a comment).
os.environ["ZARR_V3_EXPERIMENTAL_API"] = "1"
os.environ["ZARR_V3_SHARDING"] = "1"

from zarr._storage.v3 import DirectoryStoreV3
from zarr._storage.v3_storage_transformers import ShardingStorageTransformer

# Manual FISH trench crops

In [None]:
# NOTE(review): rebinds `nd2` (previously from workflow.get_nd2_reader), and
# `filename` is a workflow.SplitFilename here — confirm ND2Reader accepts it.
nd2 = nd2reader.ND2Reader(filename)

In [None]:
nd2.sizes

In [None]:
# Position 8 matches the FISH FOV loaded below (slices v=[8]).
img = nd2.get_frame_2D(v=8, c=0, t=180)

In [None]:
# Same distortion coefficient as the live-imaging parameter cell.
k1 = 8.947368421052635e-10
img_t = image.correct_radial_distortion(img, k1=k1)

In [None]:
%%time
# Trench ROIs on the uncropped, distortion-corrected live frame; reused to crop
# the FISH images below.
# diag = util.tree()
diag = None
trenches, info = trench_detection.find_trenches(
    img_t,
    # angle=np.deg2rad(0.001),
    join_info=False,
    width=12,
    # width_to_line_width_ratio=2,
    # width_to_pitch_ratio=None,
    # peak_func=trench_detection.peaks.find_peaks,
    diagnostics=diag,
)

In [None]:
def crop_trenches(img, trenches):
    """Cut each trench out of ``img``; return ``{trench_idx: crop}`` from ``geometry.iter_crops``."""
    # For quick tests, slice the iterator instead, e.g. it.islice(..., 10, 13).
    return dict(geometry.iter_crops(img, trenches))


def stack_crops(crops, channels, timepoints):
    """Regroup per-frame trench crops into one in-memory zarr array per trench.

    ``crops`` maps ``(t, channel)`` to ``{trench_idx: 2D crop}``. The result
    maps ``trench_idx`` to an array of shape
    ``(len(channels), len(timepoints), height, width)``, whose first two axes
    follow the order of ``channels`` and ``timepoints``. Each array is allocated
    from the first crop seen for that trench (its shape and dtype) and filled
    with NaN, so (channel, timepoint) pairs absent from ``crops`` stay NaN.
    The NaN fill presumably requires a floating dtype — confirm upstream.
    """
    stacks = {}
    for (t, channel), trench_crops in crops.items():
        c = channels.index(channel)
        ti = timepoints.index(t)
        for trench_idx, crop in trench_crops.items():
            stack = stacks.get(trench_idx)
            if stack is None:
                stack = zarr.create(
                    (len(channels), len(timepoints), *crop.shape),
                    dtype=crop.dtype,
                    fill_value=np.nan,
                )
                stacks[trench_idx] = stack
            stack[c, ti, :, :] = crop
    return stacks

In [None]:
def calibrate_image(img, k1=0):
    """Convert ``img`` to float32 via ``skimage.img_as_float32``, then correct radial distortion with ``k1``."""
    as_float = skimage.img_as_float32(img)
    return image.correct_radial_distortion(as_float, k1=k1)

In [None]:
%%time
# Build a lazy graph for FISH position 8: load each image, correct distortion,
# crop with the live-image `trenches`, then stack the crops per trench.
# NOTE(review): rebinds the global `delayed` (imported from dask at the top) and
# `fish_channels` (set in the Handler config cell).
delayed = util.get_delayed(True)
fish_frames = {}
fish_crops = {}
fish_channels = set()
fish_timepoints = set()
for msg in readers.send_eaton_fish(
    fish_filename,
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
    slices=dict(t=None, v=[8]),
    delayed=delayed,
):
    # print(msg["metadata"],msg["image"].shape)
    fish_img = msg["image"]
    fish_img_corrected = delayed(calibrate_image)(fish_img, k1=k1)
    fov = msg["metadata"]["fov"]  # unused: only v=8 is requested
    t = msg["metadata"]["t"]
    channel = msg["metadata"]["channel"]
    fish_channels.add(channel)
    fish_timepoints.add(t)
    fish_frames[(t, channel)] = fish_img_corrected
    fish_crops[(t, channel)] = delayed(crop_trenches)(fish_img_corrected, trenches)
# Sorted lists fix the channel/time axis order of the per-trench stacks.
fish_channels = list(sorted(fish_channels))
fish_timepoints = list(sorted(fish_timepoints))
fish_stacks = delayed(stack_crops)(fish_crops, fish_channels, fish_timepoints)

In [None]:
# Raises KeyError if a FISH channel name has no entry in fish_colors.
fish_channel_colors = [fish_colors[ch] for ch in fish_channels]

In [None]:
# Materialize the corrected frames and per-trench stacks.
fish_frames0, fish_stacks0 = dask.compute(fish_frames, fish_stacks)

In [None]:
fish_channels

In [None]:
fish_timepoints

In [None]:
fish_stacks0[10].info

In [None]:
# NOTE(review): stale cell — `new`, `handle_message` and `pipeline` are not
# defined in this notebook (unless the star import provides them).
for msg in new.readers.send_nd2(
    filename,
    slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)

# Segmentation

In [None]:
# Trench 13: drop channel index 0 (first of the sorted fish_channels) and keep
# the first 9 timepoints; axes are (channel, timepoint, y, x).
x = fish_stacks0[13][1:, :9]
# x = x - x.min(axis=1)[:,np.newaxis,:,:]

In [None]:
def weighted_mean(ary):
    """Spatially weighted mean intensity of each (channel, timepoint) frame.

    ``ary`` is indexed as (channel, timepoint, y, x), the layout produced by
    ``stack_crops``. The steps are:

    1. Background-subtract each pixel by its minimum over time.
    2. Build a weight map from each channel's per-pixel maximum over time
       (i.e. its temporal range).
    3. Normalize that map to sum to 1 over the image, then average it
       across channels.

    As a result, pixels with the largest temporal range carry the most weight.

    Returns
    -------
    numpy.ndarray or None
        ``(channel, timepoint)`` weighted means, or ``None`` when the weight
        map is all zero (empty crop, or no channel changes over time).
    """
    ary = ary - ary.min(axis=1)[:, np.newaxis, :, :]
    # lmbda = (ary.max(axis=1) - ary.min(axis=1))[:,np.newaxis,:,:]
    lmbda = ary.max(axis=1)[:, np.newaxis, :, :]
    totals = lmbda.sum(axis=(2, 3))[:, :, np.newaxis, np.newaxis]
    # A channel that never changes over time has an all-zero lmbda. Give it zero
    # weight instead of 0/0 = NaN, which would make every output NaN without
    # tripping the None check below.
    with np.errstate(divide="ignore", invalid="ignore"):
        per_channel = np.where(totals != 0, lmbda / totals, 0)
    # Average over however many channels there are. np.average renormalizes the
    # weights, so only the relative weighting matters.
    w = per_channel.mean(axis=0)[np.newaxis, :, :, :]
    if w.sum() == 0:
        return None
    return np.average(ary, axis=(2, 3), weights=np.broadcast_to(w, ary.shape))

In [None]:
# Per trench: (channel, timepoint) weighted-mean intensities over FISH channels
# 1: and the first 9 timepoints, or None where weighted_mean rejects the trench.
fish_metrics = {
    idx: weighted_mean(np.asarray(stack[1:, :9]))
    for idx, stack in tqdm(fish_stacks0.items())
}

In [None]:
# Number of rejected trenches.
sum(1 for x in fish_metrics.values() if x is None)

In [None]:
fish_metrics[1][0]

In [None]:
# Column labels for the flattened (channel, timepoint) metrics: channel-major,
# matching ary.flatten() below.
# NOTE(review): assumes fish_timepoints[:-1] has the same length as the :9
# slice above (i.e. exactly 10 timepoints); otherwise the column count won't match.
bit_names = [(ch, str(t)) for ch in fish_channels[1:] for t in fish_timepoints[:-1]]

In [None]:
fish_metrics_df = pd.DataFrame.from_dict(
    {
        trench_idx: ary.flatten()
        for trench_idx, ary in fish_metrics.items()
        if ary is not None
    },
    columns=pd.MultiIndex.from_tuples(bit_names, names=["channel", "timepoint"]),
    orient="index",
).rename_axis(index="trench_idx")
fish_metrics_df

In [None]:
# Long format: one row per (trench_idx, channel, timepoint) with its value.
fish_metrics_df2 = fish_metrics_df.melt(ignore_index=False).reset_index()

In [None]:
fish_metrics_df2

In [None]:
fish_thresholds = {"GFP": 0.007, "Cy5": 0.005, "Cy7": 0.002}

In [None]:
fish_metrics_df2["ground_truth"] = fish_metrics_df2["value"] > fish_metrics_df2[
    "channel"
].map(fish_thresholds)

In [None]:
(fish_metrics_df2.groupby("trench_idx").sum("ground_truth") == 0).sum()

In [None]:
fish_metrics_df2.groupby("channel").apply(lambda x: x["ground_truth"].sum() / len(x))

In [None]:
(fish_metrics_df2.groupby("channel").sum("ground_truth") == 0).sum()

In [None]:
# Trench to inspect.
# idx = 1901
idx = 3002

In [None]:
# Same slicing as the fish_metrics computation: channels 1:, first 9 timepoints.
x = fish_stacks0[idx][1:, :9]

In [None]:
fish_metrics_df2[fish_metrics_df2["trench_idx"] == idx]

In [None]:
display_image(image.unstack_multichannel(x))

In [None]:
# Background-subtracted by each pixel's minimum over time (as in weighted_mean).
y = x - x.min(axis=1)[:, np.newaxis, :, :]

In [None]:
display_image(image.unstack_multichannel(y))

In [None]:
# weighted_mean returns a (channel, timepoint) array (or None).
plt.imshow(weighted_mean(y))

In [None]:
# Distribution of bit values per channel and timepoint.
hv.Violin(fish_metrics_df2, ["channel", "timepoint"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
# Same, split by the thresholded on/off call.
hv.Violin(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
ds = hv.Dataset(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value")

In [None]:
ds.to(hv.Violin, ["timepoint", "ground_truth"]).layout("channel").opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
        axiswise=True,
    )
).cols(1)

In [None]:
z = ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")

In [None]:
_stacked_violins = (
    ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")
)

hv.Layout([v.redim(value=k) for k, v in _stacked_violins.items()]).opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        # violin_width=3,
        inner=None,
        bandwidth=0.2,
        cut=0.05,
    )
).cols(1)

In [None]:
# fish_metrics_df2.groupby("channel").apply(lambda x: hv.Violin(x))

In [None]:
# Empty layout (placeholder cell).
hv.Layout()

In [None]:
# Grid of value distributions, one cell per (timepoint, channel).
hv.GridSpace(
    {
        (timepoint, channel): hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(logy=True))

In [None]:
# Same grid built from a ground_truth-keyed Dataset; presumably yields one
# Distribution per ground_truth value in each cell — confirm.
hv.GridSpace(
    {
        (timepoint, channel): hv.Dataset(df, ["ground_truth"], "value").to(
            hv.Distribution
        )
        # .overlay("ground_truth")
        # hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(show_legend=True))

In [None]:
# On and off distributions overlaid in each (timepoint, channel) cell.
hv.GridSpace(
    {
        (timepoint, channel): (
            hv.Distribution(df[df["ground_truth"]], "value", label="On")
            * hv.Distribution(df[~df["ground_truth"]], "value", label="Off")
        ).redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
).opts(hv.opts.Distribution(show_legend=True, bandwidth=0.3, cut=0.05))

In [None]:
from bokeh.sampledata.iris import flowers
from holoviews.operation import gridmatrix

# Example dataset for trying out gridmatrix; not used by the FISH analysis.
iris_ds = hv.Dataset(flowers)

In [None]:
iris_ds

In [None]:
# Flatten the (channel, timepoint) column MultiIndex into "channel_timepoint" names.
fish_metrics_df3 = fish_metrics_df.set_axis(
    ["_".join(c) for c in fish_metrics_df.columns], axis=1
)

In [None]:
# Hand-picked subset of bit columns (positions 3-5 and 13-15 of the flattened list).
# NOTE(review): selecting by name would survive changes in channel/timepoint ordering.
fish_metrics_df3 = fish_metrics_df3.loc[
    :, [*fish_metrics_df3.columns[3:6], *fish_metrics_df3.columns[13:16]]
]

In [None]:
# Pairwise bivariate densities of the selected bits, with 1D distributions on the diagonal.
density_grid = gridmatrix(
    hv.Dataset(fish_metrics_df3), diagonal_type=hv.Distribution, chart_type=hv.Bivariate
)

In [None]:
density_grid