# Imports

In [None]:
import glob
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import deltalake
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import scipy.signal
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

# Shorthand for building pandas MultiIndex slices, e.g. df.loc[IDX[:, "x"], :].
IDX = pd.IndexSlice

In [None]:
from dask.diagnostics import ProgressBar

# Show a progress bar for dask computations in this session.
# NOTE(review): dask.diagnostics.ProgressBar hooks the local (single-machine)
# schedulers; work submitted through the distributed `Client` created in the
# Config section is not expected to report through it — confirm.
ProgressBar().register()

In [None]:
# Re-import edited modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Project-local image-analysis modules.
# NOTE(review): the `delayed` import below rebinds the name `delayed`, which
# the import cell above bound to `dask.delayed`; from here on `delayed` is the
# project module, not the dask decorator.
# import paulssonlab.image_analysis.mosaic as mosaic
import paulssonlab.image_analysis.delayed as delayed
import paulssonlab.image_analysis.drift as drift
import paulssonlab.image_analysis.geometry as geometry
import paulssonlab.image_analysis.image as image
import paulssonlab.image_analysis.pipeline as pipeline
import paulssonlab.image_analysis.readers as readers
import paulssonlab.image_analysis.segmentation.watershed as watershed
import paulssonlab.image_analysis.trench_detection as trench_detection
import paulssonlab.image_analysis.util as util
import paulssonlab.image_analysis.workflow as workflow
import paulssonlab.util.core as core
import paulssonlab.util.numeric as numeric
from paulssonlab.image_analysis.ui import RevImage, display_image

In [None]:
# Register pyinstrument's IPython profiling magics.
%load_ext pyinstrument

In [None]:
# Render HoloViews objects with the Bokeh backend.
hv.extension("bokeh")
# hv.extension("matplotlib")

# Config

In [None]:
# Science (time-lapse) ND2 file. The commented-out alternative shows how a
# multi-part file is passed as a workflow.SplitFilename instead.
nd2_filename = Path("/home/jqs1/scratch/microscopy/230915/230915_RBS_repressors.nd2")
# nd2_filename = workflow.SplitFilename(
#     sorted(
#         glob.glob(
#             # "/home/jqs1/scratch/jqs1/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
#             "/home/jqs1/scratch/microscopy/230830/230830_repressilators.nd2.split.*"
#         )
#     )
# )
# assert nd2_filename.files
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/231101/231101_FP_calibration.nd2"
# fish_filename = nd2_filename[0].parent / "FISH/real_run"

In [None]:
# Derive sibling paths from the ND2's directory (a SplitFilename is indexed
# to get its first part): FISH input under FISH/real_run, pipeline output
# under test2/.
if isinstance(nd2_filename, workflow.SplitFilename):
    parent_dir = nd2_filename[0].parent
else:
    parent_dir = nd2_filename.parent
fish_filename = parent_dir / "FISH/real_run"
output_dir = parent_dir / "test2"
# output_dir.mkdir(exist_ok=True)

In [None]:
# Open the ND2 reader and record the size of its "t" axis.
nd2 = workflow.get_nd2_reader(nd2_filename)
t_max = nd2.sizes["t"]

In [None]:
nd2.sizes

In [None]:
nd2.metadata["channels"]

In [None]:
# Hex display colors per science-image channel (the commented entries are an
# earlier palette).
colors = {
    "BF": "#ffffff",
    # "CFP-EM": "#6fb2e4",
    # "YFP-EM": "#eee461",
    # "RFP-EM": "#c66526",
    "CFP-EM": "#648FFF",
    "YFP-EM": "#FFB000",
    "RFP-EM": "#DC267F",
}

# Hex display colors per FISH channel.
fish_colors = {
    "BF": "#ffffff",
    "GFP": "#f44336",
    "Cy5": "#03a9f4",
    # "Cy7": "#ffeb3b"
    "Cy7": "#8bc34a",
}

In [None]:
# dask cluster backed by SLURM jobs: each job is one single-core,
# single-process worker with 16 GB on the "short" queue for up to 2 hours.
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    memory="16GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
# Release any running workers.
cluster.scale(0)

In [None]:
# Let the cluster scale adaptively, up to 20 workers.
cluster.adapt(maximum=20)

# Pipeline

In [None]:
# k1 = 1e-9
# center = image.center_from_shape(nd2.get_frame_2D().shape) - np.array([0, -500])

In [None]:
# def get_frame_func(
#     filename, position, channel, t, k1=None, center=None, dark=None, flat=None
# ):
#     img = np.asarray(
#         workflow.get_nd2_frame(
#             filename, position=position, channel=channel, t=t, dark=dark, flat=flat
#         )
#     )
#     if k1 is not None:
#         img = image.correct_radial_distortion(img, k1=k1, center=center)
#     # TODO
#     img = img[:, 300 : img.shape[1] - 300]
#     return img


# def preprocess_func(img, k1=None, center=None, dark=None, flat=None):
def preprocess_func(img, *, y_margin=600, x_margin=1500):
    """Crop a fixed border off an image before it enters the pipeline.

    The defaults drop 600 rows from the top and bottom and 1500 columns from
    the left and right, which is the crop this notebook has always applied.

    Parameters
    ----------
    img : array-like
        Image indexed as ``img[row, col, ...]``; only the first two axes are
        cropped.
    y_margin : int, optional
        Rows removed from each of the top and bottom edges.
    x_margin : int, optional
        Columns removed from each of the left and right edges.

    Returns
    -------
    The sliced image (for NumPy arrays, a view into ``img``). If a margin is
    at least half the corresponding dimension, that axis comes out empty.
    """
    # Slice against the explicit shape rather than using negative stops, so a
    # margin of 0 keeps the whole axis instead of producing ``[0:-0]``.
    return img[
        y_margin : img.shape[0] - y_margin,
        x_margin : img.shape[1] - x_margin,
    ]

In [None]:
# Pipeline configuration passed to pipeline.DefaultPipeline below.
# Commented-out entries record alternative settings.
config = {
    # "composite_func": image.mean_composite,
    # "roi_detection_func": trench_detection.find_trenches,
    # "track_drift": True,
    # "segmentation_func": watershed.watershed_segment,
    # "segmentation_channels": ["RFP-EM", "YFP-EM", "CFP-EM"],
    "segmentation_channels": ["RFP-EM"],
    "trench_detection_channels": None,  # channel(s) for trench detection; almost always the same as segmentation_channels
    # "measure_channels": ["RFP-PENTA", "YFP-DUAL"],
    # "crop_channels": ["Phase-Fluor", "RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    "crop_channels": ["RFP-EM", "YFP-EM"],
    # "measure_channels": ["RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    "measure_channels": ["RFP-EM", "YFP-EM"],
    # "fish_crop_channels": ["BF", "GFP", "Cy5", "Cy7"],
    "fish_measure_channels": ["GFP", "Cy5", "Cy7"],
    "fish_drift_tracking_channel": "BF",
    # "fish_probes": hhh,
    # "roi_detection_kwargs": {"width_to_pitch_ratio": 1.4 / 3.5},
    "roi_detection_kwargs": {"width_to_pitch_ratio": 2.2 / 3.5},  # TODO!!!
    "preprocess_func": preprocess_func,
    # "preprocess_kwargs": {"k1": k1, "center": center},
}

In [None]:
%%time
# Build the pipeline and feed it science images, restricted by `slices` to
# timepoint 61 of FOV 8 for a quick test. delayed_ is passed as the
# `delayed` option to both the pipeline and the reader.
delayed_ = False

p = pipeline.DefaultPipeline(output_dir, config=config, delayed=delayed_)

for msg in readers.send_nd2(
    nd2_filename,
    # slices=dict(t=range(200, 202), v=range(8,10)),
    slices=dict(t=range(61, 62), v=[8]),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "science"})

In [None]:
%%time
# Feed the FISH images for FOV 8 (t=None, presumably all cycles) into the
# same pipeline, tagged as barcode images.
for msg in readers.send_eaton_fish(
    fish_filename,
    # slices=dict(t=None, v=range(8, 10)),
    slices=dict(t=None, v=[8]),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "fish_barcode"})

In [None]:
# Tell the pipeline that no more input is coming.
p.handle_message({"type": "done"})

# FISH drift tracking

In [None]:
# Collect the brightfield ("BF") frame of every FISH timepoint for FOV 8,
# keyed by t, for drift-tracking experiments.
bfs = {}
for msg in readers.send_eaton_fish(
    fish_filename,
    # slices=dict(t=None, v=range(8, 10)),
    slices=dict(t=None, v=[8]),
    delayed=delayed_,
):
    if msg["metadata"]["channel"] != "BF":
        continue
    t = msg["metadata"]["t"]
    bfs[t] = msg["image"]

In [None]:
display_image(bfs[1], scale=1)

In [None]:
# Difference of two BF frames, to eyeball drift between them.
# NOTE(review): if the frames have an unsigned integer dtype this subtraction
# wraps around instead of going negative — confirm dtype or cast first.
display_image(bfs[5] - bfs[6], scale=1)

In [None]:
# Whole-frame shift between t=1 and t=5, without phase normalization.
skimage.registration.phase_cross_correlation(bfs[1], bfs[5], normalization=None)

In [None]:
# Same comparison on a 200x200 crop, registered to 1/10 pixel.
crop = (slice(100, 300), slice(100, 300))
skimage.registration.phase_cross_correlation(
    bfs[1][crop], bfs[5][crop], normalization=None, upsample_factor=10
)

In [None]:
# Shift between t=4 and t=5 on the same crop.
skimage.registration.phase_cross_correlation(
    bfs[4][crop], bfs[5][crop], normalization=None, upsample_factor=10
)

In [None]:
p.fish_first_t

In [None]:
# Shifts the pipeline computed for FOV 8 at t=2..10; t=1 is added as the
# zero-shift reference. NOTE(review): `.result()` suggests the stored values
# are futures — confirm.
shifts = {t: p.fish_shifts.value[(8, t)].result() for t in range(2, 11)}
shifts[1] = np.array([0, 0])

In [None]:
shifts

In [None]:
shifts

In [None]:
# The next four HoloMaps crop the same nominal 200x200 window from every
# timepoint, applying shifts[t] with each combination of sign and axis order,
# to find the convention under which the crops line up.
# Here: rows -shifts[t][0], cols -shifts[t][1].
hv.HoloMap(
    {
        t: hv.Image(
            bfs[t][
                100 - shifts[t][0] : 300 - shifts[t][0],
                100 - shifts[t][1] : 300 - shifts[t][1],
            ]
        )
        for t in range(1, 11)
    }
)

In [None]:
# Rows -shifts[t][1], cols -shifts[t][0].
hv.HoloMap(
    {
        t: hv.Image(
            bfs[t][
                100 - shifts[t][1] : 300 - shifts[t][1],
                100 - shifts[t][0] : 300 - shifts[t][0],
            ]
        )
        for t in range(1, 11)
    }
)

In [None]:
# Rows +shifts[t][1], cols +shifts[t][0].
hv.HoloMap(
    {
        t: hv.Image(
            bfs[t][
                100 + shifts[t][1] : 300 + shifts[t][1],
                100 + shifts[t][0] : 300 + shifts[t][0],
            ]
        )
        for t in range(1, 11)
    }
)

In [None]:
# Rows +shifts[t][0], cols +shifts[t][1].
hv.HoloMap(
    {
        t: hv.Image(
            bfs[t][
                100 + shifts[t][0] : 300 + shifts[t][0],
                100 + shifts[t][1] : 300 + shifts[t][1],
            ]
        )
        for t in range(1, 11)
    }
)

In [None]:
# Uncorrected crops, for comparison.
hv.HoloMap({t: hv.Image(bfs[t][100:300, 100:300]) for t in range(1, 11)})

In [None]:
# Sobel edge map of the t=4 crop.
display_image(skimage.filters.sobel(bfs[4][100:300, 100:300]), scale=1)

In [None]:
new_shifts = {}
new_shifts[1] = np.array([0, 0])
for t in trange(2, 11):
    new_shifts[t] = skimage.registration.phase_cross_correlation(bfs[1], bfs[t])[
        0
    ].astype(np.int16)[::-1]

In [None]:
new_shifts

In [None]:
hv.HoloMap(
    {
        t: hv.Image(
            bfs[t][
                100 - new_shifts[t][1] : 300 - new_shifts[t][1],
                100 - new_shifts[t][0] : 300 - new_shifts[t][0],
            ]
        )
        for t in range(1, 11)
    }
)

In [None]:
hv.HoloMap(
    {
        t: hv.Image(bfs[t][100:300, 100:300])
        + hv.Image(bfs[t][-300:-100, -300:-100])
        + hv.Image(bfs[t][100:300, -300:-100])
        + hv.Image(bfs[t][-300:-100, 100:-100])
        for t in range(1, 11)
    }
)

In [None]:
sk

In [None]:
display_image(
    image.unstack_multichannel(fish_crops[1:, 8, :, 39].squeeze().swapaxes(0, 1)),
    scale=1,
)  # , colors=fish_colors)

In [None]:
display_image(
    image.unstack_multichannel(fish_crops[1:, 8, :, 390].squeeze().swapaxes(0, 1)),
    scale=1,
)  # , colors=fish_colors)

In [None]:
display_image(
    image.unstack_multichannel(fish_crops[1:, 8, :, 39].squeeze().swapaxes(0, 1)),
    scale=1,
)  # , colors=fish_colors)

In [None]:
display_image(
    image.unstack_multichannel(fish_crops[1:, 8, :, 39].squeeze().swapaxes(0, 1)),
    scale=1,
)  # , colors=fish_colors)

# FISH analysis

In [None]:
fish_crops = readers.ZarrSlicer(
    output_dir / "fish_crops",
    r"fov=(?P<v>\d+)/channel=(?P<c>[^/]+)/t=(?P<t>\d+)",
    files=False,
    recursive=True,
)

In [None]:
pd.read_parquet(output_dir / "fish_measurements")

In [None]:
x = pl.scan_parquet(output_dir / "fish_measurements/*/*")

In [None]:
x = pl.read_delta(str(output_dir / "fish_measurements"))

In [None]:
y = deltalake.DeltaTable(str(output_dir / "fish_measurements")).to_pandas()

In [None]:
dataset = ds.dataset(
    output_dir / "fish_measurements", format="parquet", partitioning="hive"
)
# df = dataset.to_table(filter=ds.field("position") == 14).to_pandas()
df = dataset.to_table().to_pandas().sort_values("t")

In [None]:
df

In [None]:
plt.scatter(df["mean"], df["otsu_mean"], s=1)

In [None]:
hv.Violin(df, ["channel", "t"], "otsu_mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(df, ["channel", "t"], "mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(df2, ["channel", "t", "call"], "otsu_mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("call"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
p.fish_crops.keys()

In [None]:
fish_thresholds = {"GFP": 1200, "Cy5": 1000, "Cy7": 800}

In [None]:
df2 = df.assign(call=df["otsu_mean"] > df["channel"].map(fish_thresholds))

In [None]:
df2[(df2["fov_num"] == 8) & (df2["roi"] == 40) & (df2["channel"] == "Cy7")]

In [None]:
display_image(
    image.unstack(image.crop_to_mask(fish_crops[:, 8, "Cy7", 390].squeeze())), scale=1
)

In [None]:
display_image(
    image.unstack_multichannel(
        image.crop_to_mask(fish_crops[1:, 8, :, 39].squeeze().swapaxes(0, 1))
    ),
    scale=1,
)  # , colors=fish_colors)

## Join with codebook

In [None]:
# Extracted sequencing segments stored as Arrow IPC files (path suggests the
# 230818 repressilator run) — presumably to be joined with the FISH calls.
seq = pl.read_ipc(
    "/home/jqs1/scratch/sequencing/230818_repressilators/20230905_1132_1H_PAQ85679_c9d74ddb/extract_segments/*.arrow"
)

In [None]:
seq

# Science analysis

In [None]:
# Indexable view over the science-image crops under output_dir/crops, using
# the same FOV/channel/t path pattern as fish_crops.
crops = readers.ZarrSlicer(
    output_dir / "crops",
    r"fov=(?P<v>\d+)/channel=(?P<c>[^/]+)/t=(?P<t>\d+)",
    files=False,
    recursive=True,
)

In [None]:
# YFP-EM crops of FOV 8, index 33 on the last axis, across all of axis 0,
# trimmed with crop_to_mask.
display_image(
    image.unstack(image.crop_to_mask(crops[:, 8, "YFP-EM", 33].squeeze())), scale=1
)

In [None]:
crops[:, 8, "YFP-EM", 3].squeeze().shape

# Old FISH analysis

In [None]:
fish_metrics_df = pd.DataFrame.from_dict(
    {
        trench_idx: ary.flatten()
        for trench_idx, ary in fish_metrics.items()
        if ary is not None
    },
    columns=pd.MultiIndex.from_tuples(bit_names, names=["channel", "timepoint"]),
    orient="index",
).rename_axis(index="trench_idx")
fish_metrics_df

In [None]:
fish_metrics_df2 = fish_metrics_df.melt(ignore_index=False).reset_index()

In [None]:
fish_metrics_df2

In [None]:
fish_thresholds = {"GFP": 0.007, "Cy5": 0.005, "Cy7": 0.002}

In [None]:
fish_metrics_df2["ground_truth"] = fish_metrics_df2["value"] > fish_metrics_df2[
    "channel"
].map(fish_thresholds)

In [None]:
(fish_metrics_df2.groupby("trench_idx").sum("ground_truth") == 0).sum()

In [None]:
fish_metrics_df2.groupby("channel").apply(lambda x: x["ground_truth"].sum() / len(x))

In [None]:
(fish_metrics_df2.groupby("channel").sum("ground_truth") == 0).sum()

In [None]:
# Trench to inspect in the legacy FISH stacks.
idx = 200
# idx = 3002

In [None]:
# NOTE(review): `fish_stacks0` is not created in this notebook as shown.
# This also rebinds `x`, which the FISH analysis section used for polars
# tables.
x = fish_stacks0[idx][1:]

In [None]:
fish_metrics_df2[fish_metrics_df2["trench_idx"] == idx]

In [None]:
display_image(image.unstack_multichannel(x), scale=0.999)

In [None]:
# Subtract the minimum over axis 1 from every slice along that axis. The
# [:, np.newaxis, :, :] re-insertion requires x to have at least 4 dims.
y = x - x.min(axis=1)[:, np.newaxis, :, :]

In [None]:
display_image(image.unstack_multichannel(y))

In [None]:
# NOTE(review): `weighted_mean` is not defined in this notebook as shown.
plt.imshow(weighted_mean(y))

In [None]:
# Legacy metric distribution per channel and timepoint.
hv.Violin(fish_metrics_df2, ["channel", "timepoint"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
# Split violins by ground_truth call, per channel and timepoint.
hv.Violin(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
ds = hv.Dataset(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value")

In [None]:
ds.to(hv.Violin, ["timepoint", "ground_truth"]).layout("channel").opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
        axiswise=True,
    )
).cols(1)

In [None]:
z = ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")

In [None]:
_stacked_violins = (
    ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")
)

hv.Layout([v.redim(value=k) for k, v in _stacked_violins.items()]).opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        # violin_width=3,
        inner=None,
        bandwidth=0.2,
        cut=0.05,
    )
).cols(1)

In [None]:
# fish_metrics_df2.groupby("channel").apply(lambda x: hv.Violin(x))

In [None]:
hv.Layout()

In [None]:
# Grid of value distributions, one cell per (timepoint, channel); each cell's
# value dimension is renamed to its channel.
hv.GridSpace(
    {
        (timepoint, channel): hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(logy=True))

In [None]:
# Same grid, with each cell converted to distributions grouped by
# ground_truth via Dataset.to.
hv.GridSpace(
    {
        (timepoint, channel): hv.Dataset(df, ["ground_truth"], "value").to(
            hv.Distribution
        )
        # .overlay("ground_truth")
        # hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(show_legend=True))

In [None]:
# Same grid with the "On" (ground_truth True) and "Off" distributions
# overlaid in each cell.
hv.GridSpace(
    {
        (timepoint, channel): (
            hv.Distribution(df[df["ground_truth"]], "value", label="On")
            * hv.Distribution(df[~df["ground_truth"]], "value", label="Off")
        ).redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
).opts(hv.opts.Distribution(show_legend=True, bandwidth=0.3, cut=0.05))

In [None]:
# Bokeh's iris sample data as a reference dataset for gridmatrix.
from bokeh.sampledata.iris import flowers
from holoviews.operation import gridmatrix

iris_ds = hv.Dataset(flowers)

In [None]:
iris_ds

In [None]:
# Flatten the (channel, timepoint) column MultiIndex into "channel_timepoint"
# names. NOTE(review): str.join needs both levels to be strings; if the
# timepoint labels are ints this raises TypeError.
fish_metrics_df3 = fish_metrics_df.set_axis(
    ["_".join(c) for c in fish_metrics_df.columns], axis=1
)

In [None]:
# Keep two hand-picked runs of columns (positions 3-5 and 13-15).
fish_metrics_df3 = fish_metrics_df3.loc[
    :, [*fish_metrics_df3.columns[3:6], *fish_metrics_df3.columns[13:16]]
]

In [None]:
# Pairwise grid of the selected columns: bivariate densities off the
# diagonal, 1D distributions on the diagonal.
density_grid = gridmatrix(
    hv.Dataset(fish_metrics_df3), diagonal_type=hv.Distribution, chart_type=hv.Bivariate
)

In [None]:
density_grid