# Imports

In [None]:
import itertools as it
from functools import partial
from glob import glob
from pathlib import Path

import dask
import distributed
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import scipy.signal
import skimage.measure
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

# Shorthand for building MultiIndex slices, e.g. df.loc[IDX[...], ...].
IDX = pd.IndexSlice

In [None]:
from dask.diagnostics import ProgressBar

# Register a global dask progress bar for local (dask.diagnostics) computations.
ProgressBar().register()

In [None]:
# Automatically reload edited modules (e.g. the paulssonlab packages) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# import paulssonlab.image_analysis.mosaic as mosaic
# NOTE(review): the next line rebinds `delayed`, shadowing `from dask import delayed`
# in the first cell; after this cell `delayed` refers to the project module, not dask.delayed.
import paulssonlab.image_analysis.delayed as delayed
import paulssonlab.image_analysis.drift as drift
import paulssonlab.image_analysis.geometry as geometry
import paulssonlab.image_analysis.image as image
import paulssonlab.image_analysis.pipeline as pipeline
import paulssonlab.image_analysis.readers as readers
import paulssonlab.image_analysis.segmentation.watershed as watershed
import paulssonlab.image_analysis.trench_detection as trench_detection
import paulssonlab.image_analysis.util as util
import paulssonlab.image_analysis.workflow as workflow
import paulssonlab.util.core as core
import paulssonlab.util.numeric as numeric
from paulssonlab.image_analysis.ui import RevImage, display_image

In [None]:
# Enable the pyinstrument profiling magics (%pyinstrument / %%pyinstrument).
%load_ext pyinstrument

In [None]:
# Render HoloViews/hvPlot figures with the Bokeh backend.
hv.extension("bokeh")
# hv.extension("matplotlib")

# Config

In [None]:
# Input acquisition. Earlier experiments are kept commented out for quick switching.
# nd2_filename = Path("/home/jqs1/scratch/microscopy/230915/230915_RBS_repressors.nd2")
# nd2_filename = Path("/home/jqs1/scratch/microscopy/230912/230912_bcd_rbses001.nd2")
# The 230830 acquisition is stored as split chunks (*.nd2.split.*). The glob matches are
# sorted so chunk order is deterministic (glob order is arbitrary), then wrapped in
# workflow.SplitFilename (presumably so they can be read as one logical ND2 file).
nd2_filename = workflow.SplitFilename(
    sorted(
        glob(
            # "/home/jqs1/scratch/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
            "/home/jqs1/scratch/microscopy/230830/230830_repressilators.nd2.split.*"
        )
    )
)
# Fail fast if the glob matched nothing (presumably `.files` is the list of chunk paths).
assert nd2_filename.files
# nd2_filename = Path("/home/jqs1/scratch/microscopy/231101/231101_FP_calibration.nd2")
# fish_filename = nd2_filename[0].parent / "FISH/real_run"

In [None]:
# FISH data and analysis outputs live next to the raw acquisition; a split
# acquisition is located through its first chunk.
parent_dir = (
    nd2_filename[0]
    if isinstance(nd2_filename, workflow.SplitFilename)
    else nd2_filename
).parent
fish_filename = parent_dir / "FISH/real_run"
output_dir = parent_dir / "test"
# output_dir.mkdir(exist_ok=True)

In [None]:
# Open the acquisition (a single .nd2 path or a SplitFilename) and record its number of timepoints.
nd2 = workflow.get_nd2_reader(nd2_filename)
t_max = nd2.sizes["t"]

In [None]:
# Axis sizes keyed by dimension letter (this notebook indexes "t", "v" and "c").
nd2.sizes

In [None]:
# Channel metadata; presumably the channel index `c` used below indexes into this list.
nd2.metadata["channels"]

# Cluster

In [None]:
# Dask workers run as SLURM jobs: one single-core, single-process worker per job with
# 16 GB of memory and a 2 h walltime on the "short" queue. Worker scratch space is /tmp;
# job logs go to /home/jqs1/log.
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    memory="16GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
# Show the cluster's status.
cluster

In [None]:
# Release all workers.
cluster.scale(0)

In [None]:
# Let dask add and remove worker jobs on demand, up to 20.
cluster.adapt(maximum=20)

# Trench detection

In [None]:
# k1 = 1e-9
# center = image.center_from_shape(nd2.get_frame_2D().shape) - np.array([0, -500])
# img = image.correct_radial_distortion(img, k1=k1, center=center)

In [None]:
def preprocess_func(img, margin_y=600, margin_x=1500):
    """Crop a fixed border off the first two axes of ``img``.

    Parameters
    ----------
    img : array_like
        Image indexed as ``img[row, col, ...]``; only the first two axes are cropped.
    margin_y : int, optional
        Number of rows removed from both the top and the bottom (default 600).
    margin_x : int, optional
        Number of columns removed from both the left and the right (default 1500).

    Returns
    -------
    array_like
        The cropped slice (a view for NumPy arrays) with shape
        ``(rows - 2 * margin_y, cols - 2 * margin_x, ...)``. An axis becomes
        empty when it is not longer than twice its margin.
    """
    rows, cols = img.shape[:2]
    return img[margin_y : rows - margin_y, margin_x : cols - margin_x]

In [None]:
# Raw reference frame for trench detection: timepoint 0, v=17 (presumably the
# field-of-view index), channel 0. preprocess_func is NOT applied here, so it is uncropped.
img = nd2.get_frame_2D(t=0, v=17, c=0)

In [None]:
# Quick look at the frame (downsample=4 presumably reduces display resolution 4x).
display_image(img, scale=1, downsample=4)

In [None]:
%%time
# Detect trenches in `img`. `rois` is a per-trench table whose top_x/top_y and
# bottom_x/bottom_y columns are read below; `info` is displayed for inspection.
# To collect diagnostics, swap in `diag = util.tree()` (presumably filled in by find_trenches).
# diag = util.tree()
diag = None
rois, info = trench_detection.find_trenches(
    img,
    # angle=np.deg2rad(0.001),
    join_info=False,
    width=12,
    # width_to_line_width_ratio=2,
    # width_to_pitch_ratio=None,
    diagnostics=diag,
)
info

In [None]:
# Number of detected trenches.
len(rois)

In [None]:
# Inspect one trench: take its endpoints and reverse (x, y) -> (y, x), presumably the
# (row, col) order that profile_line expects (as in skimage.measure.profile_line).
idx = 3100
top = np.asarray(rois.loc[idx, ["top_x", "top_y"]])[::-1]
bottom = np.asarray(rois.loc[idx, ["bottom_x", "bottom_y"]])[::-1]

In [None]:
# Sample a 20-pixel-wide band along the top->bottom line (cval=0 outside the image,
# order=3 presumably cubic interpolation). The result is 2-D: it is shown with imshow
# and averaged over axis 0 below.
res = trench_detection.profile.profile_line(
    img, src=top, dst=bottom, linewidth=20, cval=0, order=3
)

In [None]:
# The sampled band as an image.
plt.imshow(res);

In [None]:
# Band averaged over axis 0.
plt.plot(res.mean(axis=0));

In [None]:
# Sample the same 20-pixel-wide top->bottom band as the single-trench check above,
# for every detected trench in `img`.
# Keys are rois index labels, so offsets[k] corresponds to rois.loc[k]. Iterating
# rois.index (not range(len(rois))) keeps that true even when the index is not a
# 0..n-1 RangeIndex. Endpoints are extracted column-wise once and already ordered
# (y, x), matching the [::-1] of (x, y) used for a single trench. The comprehension
# also avoids overwriting the `idx`/`top`/`bottom`/`res` globals from that check.
# NOTE(review): despite its name, `offsets` holds the sampled profile arrays.
offsets = {
    trench_idx: trench_detection.profile.profile_line(
        img, src=top_yx, dst=bottom_yx, linewidth=20, cval=0, order=3
    )
    for trench_idx, top_yx, bottom_yx in zip(
        tqdm(rois.index),
        rois[["top_y", "top_x"]].to_numpy(),
        rois[["bottom_y", "bottom_x"]].to_numpy(),
    )
}

In [None]:
# Wide table of FISH metrics: one row per trench (index "trench_idx") and one column per
# (channel, timepoint) pair from bit_names. Trenches whose metric array is None are skipped.
# Each array is flattened into one row, so its length must match the number of columns.
# NOTE(review): fish_metrics and bit_names are not defined in any cell shown here; they must
# come from elsewhere in the session, otherwise this raises NameError.
fish_metrics_df = pd.DataFrame.from_dict(
    {
        trench_idx: ary.flatten()
        for trench_idx, ary in fish_metrics.items()
        if ary is not None
    },
    columns=pd.MultiIndex.from_tuples(bit_names, names=["channel", "timepoint"]),
    orient="index",
).rename_axis(index="trench_idx")
fish_metrics_df

In [None]:
# Long format: one row per (trench_idx, channel, timepoint) with the metric in "value";
# ignore_index=False keeps trench_idx, which reset_index turns into a column.
fish_metrics_df2 = fish_metrics_df.melt(ignore_index=False).reset_index()

In [None]:
fish_metrics_df2

In [None]:
# Per-channel metric thresholds above which a FISH signal is called positive.
fish_thresholds = {"GFP": 0.007, "Cy5": 0.005, "Cy7": 0.002}

In [None]:
# ground_truth: True when the metric exceeds its channel's threshold. Channels missing
# from fish_thresholds map to NaN, so they always compare False (negative).
fish_metrics_df2["ground_truth"] = fish_metrics_df2["value"] > fish_metrics_df2[
    "channel"
].map(fish_thresholds)

In [None]:
# Number of trenches with no positive call in any channel/timepoint.
# The ground_truth column is selected before summing: `.sum("ground_truth")` passes the
# string as the first positional parameter, `numeric_only`, which sums (and then counts)
# every numeric column instead of just ground_truth.
(fish_metrics_df2.groupby("trench_idx")["ground_truth"].sum() == 0).sum()

In [None]:
# Fraction of (trench, timepoint) observations called positive, per channel.
# The mean of a boolean column is its positive fraction. Selecting the column first
# also avoids GroupBy.apply operating on the grouping column (deprecated in pandas 2.2).
fish_metrics_df2.groupby("channel")["ground_truth"].mean()

In [None]:
# Number of channels with no positive call in any trench/timepoint.
# The ground_truth column is selected explicitly: GroupBy.sum's first positional
# parameter is `numeric_only`, so `.sum("ground_truth")` would sum every numeric column.
(fish_metrics_df2.groupby("channel")["ground_truth"].sum() == 0).sum()

In [None]:
# Trench to inspect (an alternative is kept commented out).
idx = 200
# idx = 3002

In [None]:
# Image stack for that trench, dropping the first entry along axis 0.
# NOTE(review): fish_stacks0 is not defined in any cell shown here.
x = fish_stacks0[idx][1:]

In [None]:
# Metric values and calls for the same trench.
fish_metrics_df2[fish_metrics_df2["trench_idx"] == idx]

In [None]:
# Display the stack (unstack_multichannel presumably lays the channels out for display).
display_image(image.unstack_multichannel(x), scale=0.999)

In [None]:
# Subtract from every element the minimum taken along axis 1 (presumably a per-pixel
# baseline; confirm what axis 1 indexes). keepdims=True keeps axis 1 as length 1 so the
# minimum broadcasts back over it for any array rank, not only 4-D stacks.
y = x - x.min(axis=1, keepdims=True)

In [None]:
# Same view after subtracting the axis-1 minimum.
display_image(image.unstack_multichannel(y))

In [None]:
# NOTE(review): weighted_mean is not defined in any cell shown here.
plt.imshow(weighted_mean(y))

In [None]:
# Metric distribution for every (channel, timepoint), with violins colored by channel.
hv.Violin(fish_metrics_df2, ["channel", "timepoint"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
# Same, additionally split into ground_truth True/False halves.
hv.Violin(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
# Dataset over the long table: key dims channel/timepoint/ground_truth, value dim "value".
ds = hv.Dataset(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value")

In [None]:
# One split-violin plot per channel (timepoint on x, halves split by ground_truth),
# stacked in a single column with independent axes (axiswise=True).
ds.to(hv.Violin, ["timepoint", "ground_truth"]).layout("channel").opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
        axiswise=True,
    )
).cols(1)

In [None]:
# NOTE(review): `z` is built from the same expression as `_stacked_violins` in the next
# cell and is not used in any cell shown here.
z = ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")

In [None]:
# Per channel: violins over timepoint with ground_truth True/False overlaid.
_stacked_violins = (
    ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")
)

# Rename each panel's "value" dimension to its channel key so every panel's value
# axis is labeled with its channel, then stack the panels in one column.
hv.Layout([v.redim(value=k) for k, v in _stacked_violins.items()]).opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        # violin_width=3,
        inner=None,
        bandwidth=0.2,
        cut=0.05,
    )
).cols(1)

In [None]:
# fish_metrics_df2.groupby("channel").apply(lambda x: hv.Violin(x))

In [None]:
# NOTE(review): renders an empty layout; looks like a leftover scratch cell.
hv.Layout()

In [None]:
# Grid of metric distributions with one panel per (timepoint, channel); each panel's
# "value" dimension is renamed to the channel name.
hv.GridSpace(
    {
        (timepoint, channel): hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(logy=True))

In [None]:
# Same grid, but each panel is a Dataset keyed by ground_truth converted to
# Distributions (presumably a HoloMap over ground_truth).
hv.GridSpace(
    {
        (timepoint, channel): hv.Dataset(df, ["ground_truth"], "value").to(
            hv.Distribution
        )
        # .overlay("ground_truth")
        # hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(show_legend=True))

In [None]:
# Same grid with the ground_truth True ("On") and False ("Off") distributions overlaid
# in each panel (boolean masks rely on ground_truth being a bool column).
hv.GridSpace(
    {
        (timepoint, channel): (
            hv.Distribution(df[df["ground_truth"]], "value", label="On")
            * hv.Distribution(df[~df["ground_truth"]], "value", label="Off")
        ).redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
).opts(hv.opts.Distribution(show_legend=True, bandwidth=0.3, cut=0.05))

In [None]:
# Scratch: iris example data from bokeh, presumably for trying out gridmatrix.
# gridmatrix itself is also used for the FISH density grid below.
from bokeh.sampledata.iris import flowers
from holoviews.operation import gridmatrix

iris_ds = hv.Dataset(flowers)

In [None]:
iris_ds

In [None]:
# Flatten the (channel, timepoint) column MultiIndex into "channel_timepoint" labels.
# Each level is passed through str() so non-string levels (e.g. integer timepoints)
# don't make the join raise TypeError; string-only labels are unchanged.
fish_metrics_df3 = fish_metrics_df.set_axis(
    ["_".join(map(str, c)) for c in fish_metrics_df.columns], axis=1
)

In [None]:
# Keep the columns at positions 3-5 and 13-15.
# NOTE(review): which (channel, timepoint) pairs these are depends on the column order from
# bit_names; confirm. Selecting by position with iloc means duplicate labels cannot pull in
# extra columns. The slices truncate gracefully if the frame has fewer columns.
fish_metrics_df3 = pd.concat(
    [fish_metrics_df3.iloc[:, 3:6], fish_metrics_df3.iloc[:, 13:16]], axis=1
)

In [None]:
# Pairwise matrix over the selected metric columns: Distribution plots on the diagonal
# and Bivariate plots off the diagonal.
density_grid = gridmatrix(
    hv.Dataset(fish_metrics_df3), diagonal_type=hv.Distribution, chart_type=hv.Bivariate
)

In [None]:
density_grid