# Imports

In [None]:
import glob
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import deltalake
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import scipy.signal
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

IDX = pd.IndexSlice

In [None]:
from dask.diagnostics import ProgressBar

ProgressBar().register()

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# import paulssonlab.image_analysis.mosaic as mosaic
import paulssonlab.image_analysis.delayed as delayed
import paulssonlab.image_analysis.drift as drift
import paulssonlab.image_analysis.geometry as geometry
import paulssonlab.image_analysis.image as image
import paulssonlab.image_analysis.pipeline as pipeline
import paulssonlab.image_analysis.readers as readers
import paulssonlab.image_analysis.segmentation.watershed as watershed
import paulssonlab.image_analysis.trench_detection as trench_detection
import paulssonlab.image_analysis.util as util
import paulssonlab.image_analysis.workflow as workflow
import paulssonlab.util.core as core
from paulssonlab.image_analysis.ui import RevImage, display_image

In [None]:
%load_ext pyinstrument

In [None]:
hv.extension("bokeh")
# hv.extension("matplotlib")

# Config

In [None]:
# nd2_filename = "/home/jqs1/scratch/microscopy/230915/230915_RBS_repressors.nd2"
nd2_filename = workflow.SplitFilename(
    sorted(
        glob.glob(
            # "/home/jqs1/scratch/jqs1/microscopy/230619/230619_NAO745_repressilators_split.nd2*"
            # "/home/jqs1/scratch/jqs1/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
            "/home/jqs1/scratch/microscopy/230830/230830_repressilators.nd2.split.*"
        )
    )
)
assert nd2_filename.files
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/231101/231101_FP_calibration.nd2"
fish_filename = nd2_filename[0].parent / "FISH/real_run"

In [None]:
output_dir = Path(nd2_filename).parent / "test"
# output_dir.mkdir(exist_ok=True)

In [None]:
nd2 = workflow.get_nd2_reader(nd2_filename)
t_max = nd2.sizes["t"]

In [None]:
nd2.sizes

In [None]:
nd2.metadata["channels"]

In [None]:
colors = {
    "BF": "#ffffff",
    # "CFP-EM": "#6fb2e4",
    # "YFP-EM": "#eee461",
    # "RFP-EM": "#c66526",
    "CFP-EM": "#648FFF",
    "YFP-EM": "#FFB000",
    "RFP-EM": "#DC267F",
}

fish_colors = {
    "BF": "#ffffff",
    "GFP": "#f44336",
    "Cy5": "#03a9f4",
    # "Cy7": "#ffeb3b"
    "Cy7": "#8bc34a",
}

In [None]:
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    memory="16GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(0)

In [None]:
cluster.adapt(maximum=20)

# Pipeline

In [None]:
# k1 = 1e-9
# center = image.center_from_shape(nd2.get_frame_2D().shape) - np.array([0, -500])

In [None]:
# def get_frame_func(
#     filename, position, channel, t, k1=None, center=None, dark=None, flat=None
# ):
#     img = np.asarray(
#         workflow.get_nd2_frame(
#             filename, position=position, channel=channel, t=t, dark=dark, flat=flat
#         )
#     )
#     if k1 is not None:
#         img = image.correct_radial_distortion(img, k1=k1, center=center)
#     # TODO
#     img = img[:, 300 : img.shape[1] - 300]
#     return img


# def preprocess_func(img, k1=None, center=None, dark=None, flat=None):
def preprocess_func(img, row_margin=600, col_margin=1500):
    """Crop a fixed border off a frame, keeping only the central region.

    Parameters
    ----------
    img : array-like
        Image exposing ``shape`` and supporting ``img[rows, cols]`` slicing.
    row_margin : int, optional
        Rows discarded at both the top and the bottom. Defaults to 600, the
        value that used to be hard-coded.
    col_margin : int, optional
        Columns discarded at both the left and the right. Defaults to 1500,
        the value that used to be hard-coded.

    Returns
    -------
    The slice ``img[row_margin : n_rows - row_margin,
    col_margin : n_cols - col_margin]``. With NumPy slicing this is empty
    along any axis whose length does not exceed twice its margin.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    return img[row_margin : n_rows - row_margin, col_margin : n_cols - col_margin]

In [None]:
# Options handed to pipeline.DefaultPipeline below. Commented-out entries are
# alternatives kept for reference; how each key is interpreted is defined by
# the pipeline module, not here.
config = {
    # "composite_func": image.mean_composite,
    # "roi_detection_func": trench_detection.find_trenches,
    # "track_drift": True,
    # "segmentation_func": watershed.watershed_segment,
    "segmentation_channels": ["RFP-EM", "YFP-EM", "CFP-EM"],
    # "segmentation_channels": ["CFP-EM"],
    "trench_detection_channels": None,  # channel for trench detection, almost always same as segmentation_channel
    # "measure_channels": ["RFP-PENTA", "YFP-DUAL"],
    "crop_channels": ["Phase-Fluor", "RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    "measure_channels": ["RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    # "fish_crop_channels": ["BF", "GFP-EM", "Cy5", "Cy7"],
    "fish_measure_channels": ["GFP", "Cy5", "Cy7"],
    # "fish_probes": hhh,
    # "roi_detection_kwargs": {"width_to_pitch_ratio": 1.4 / 3.5},
    "roi_detection_kwargs": {"width_to_pitch_ratio": 1.8 / 3.5},  # TODO: settle on a ratio (see the commented alternative above)
    # Crops every frame to its central region (see preprocess_func).
    "preprocess_func": preprocess_func,
    # "preprocess_kwargs": {"k1": k1, "center": center},
}

In [None]:
%%time
delayed_ = False

p = pipeline.DefaultPipeline(output_dir, config=config, delayed=delayed_)

for msg in readers.send_nd2(
    nd2_filename,
    # slices=dict(t=range(200, 202), v=range(8,10)),
    slices=dict(t=range(200, 202), v=[8]),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "science"})

In [None]:
%%time
for msg in readers.send_eaton_fish(
    fish_filename,
    slices=dict(t=None, v=range(8, 10)),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "fish_barcode"})

# FISH analysis

In [None]:
pd.read_parquet(output_dir / "fish_measurements")

In [None]:
x = pl.scan_parquet(output_dir / "fish_measurements/*/*")

In [None]:
x = pl.read_delta(str(output_dir / "fish_measurements"))

In [None]:
y = deltalake.DeltaTable(str(output_dir / "fish_measurements")).to_pandas()

In [None]:
dataset = ds.dataset(
    output_dir / "fish_measurements", format="parquet", partitioning="hive"
)
# df = dataset.to_table(filter=ds.field("position") == 14).to_pandas()
df = dataset.to_table().to_pandas().sort_values("t")

In [None]:
df

In [None]:
plt.scatter(df["mean"], df["otsu_mean"], s=1)

In [None]:
hv.Violin(df, ["channel", "t"], "otsu_mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(df, ["channel", "t"], "mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(df2, ["channel", "t", "call"], "otsu_mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("call"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
p.fish_crops.keys()

In [None]:
def stack_fish_crops(crops, channels, timepoints):
    """Assemble per-frame crops into one (channel, timepoint, ...) stack per trench.

    Parameters
    ----------
    crops : mapping
        Keyed by ``(fov_num, channel, t)``; each value maps a trench index to
        a 2-D crop (the crop is written into the last two axes of the stack).
    channels : sequence
        Channels to keep. A channel's position here is its index along the
        first axis of every stack; frames of other channels are skipped.
    timepoints : sequence
        Timepoints to keep. A timepoint's position here is its index along the
        second axis. A frame whose ``t`` is missing raises ``ValueError``.

    Returns
    -------
    dict
        ``{fov_num: {trench_idx: zarr array}}``. Each array has shape
        ``(len(channels), len(timepoints), *crop.shape)``, takes the dtype of
        the first crop seen for that trench, and is zero wherever no crop was
        written. Every FOV present in ``crops`` gets an entry, even when all
        of its frames were skipped.
    """
    stacks = {}
    for key, frame_crops in crops.items():
        fov_num, channel, t = key
        # Register the FOV before filtering so skipped FOVs still show up.
        fov_stacks = stacks.setdefault(fov_num, {})
        if channel not in channels:
            continue
        c_idx = channels.index(channel)
        t_idx = timepoints.index(t)
        for trench_idx, crop in frame_crops.items():
            try:
                stack = fov_stacks[trench_idx]
            except KeyError:
                # First crop seen for this trench: allocate lazily so the
                # spatial shape and dtype come from real data.
                stack = zarr.create(
                    (len(channels), len(timepoints), *crop.shape),
                    dtype=crop.dtype,
                    fill_value=0,
                )
                fov_stacks[trench_idx] = stack
            stack[c_idx, t_idx, :, :] = crop
    return stacks

In [None]:
%%time

stacks = stack_fish_crops(p.fish_crops, ["BF", "GFP", "Cy5", "Cy7"], list(range(1, 11)))

In [None]:
stacks[8][2].shape

In [None]:
fish_thresholds = {"GFP": 1200, "Cy5": 1000, "Cy7": 800}

In [None]:
df2 = df.assign(call=df["otsu_mean"] > df["channel"].map(fish_thresholds))

In [None]:
df2[(df2["fov_num"] == 8) & (df2["roi"] == 40) & (df2["channel"] == "Cy7")]

In [None]:
display_image(image.unstack_multichannel(stacks[8][40]))  # , colors=fish_colors)

In [None]:
stacks[8][2][3].mean(axis=(1, 2))

In [None]:
display_image(stacks[8][2][3][2], scale=1)

In [None]:
x = stacks[8][2][3][2]
m = skimage.filters.threshold_otsu(x)
display_image(1 * (x > m), scale=1)

In [None]:
stacks[8][2][3][2].mean()

# Science analysis

In [None]:
p.fish_crops.keys()

In [None]:
display_image(p.fish_crops[(8, "Cy5", 1)][2], scale=1)

In [None]:
len(p._queue._items)

In [None]:
display_image(p.crops[(8, "RFP-EM", 200)][2], scale=True)

In [None]:
x = p.crops[(8, "RFP-EM", 200)][2]
y = watershed.segment(x)

In [None]:
display_image(skimage.color.label2rgb(y))

In [None]:
x = p.segmentation_masks[(8, 200)][2]

In [None]:
display_image(skimage.color.label2rgb(x))

In [None]:
p.measurements[(8, "RFP-EM", 200)][2]

In [None]:
p.fish_crops.keys()

In [None]:
len(p._queue._items)

In [None]:
q = p._queue._items

In [None]:
for k, v in p._queue._items.items():
    print("k", k)
    print(v.func)
    print(v.is_ready())
    print()
    print([x.__class__ for x in v.dependencies])
    print()

In [None]:
q[139884345499696].args

In [None]:
q[139884345499696].dependencies[1]

In [None]:
p.processed_frames[(8, "RFP-EM", 0)].shape

In [None]:
p.processed_frames.keys()

In [None]:
p.processed_frames[(8, "CFP-EM", 4)]

In [None]:
p.shifts[(8, 1)]

In [None]:
p.rois[(8, 0)]

In [None]:
p.crops[(8, "RFP-EM", 0)][1]

In [None]:
imgs = p.rois.delayed[(8, 0)].args

In [None]:
img_c = image.power_law_composite(*imgs)

In [None]:
display_image(img_c, scale=True, downsample=4)

In [None]:
%%time
diag = util.tree()
# diag = None
# img = p.rois[(8, 'RFP-EM', 0)]
# img = imgs[0]
trenches, info = trench_detection.find_trenches(
    img_c,  # [300:-300,900:-900],
    # angle=np.deg2rad(0.001),
    join_info=False,
    # width=12,
    # width_to_line_width_ratio=2,
    width_to_pitch_ratio=1.4 / 3.5,
    # width_to_pitch_ratio=None,
    # peak_func=trench_detection.peaks.find_peaks,
    diagnostics=diag,
)

In [None]:
info

In [None]:
trenches

In [None]:
diag.keys()

In [None]:
diag["end_finding"].keys()

In [None]:
skimage.filters.threshold_otsu(np.random.random(10))

In [None]:
diag["end_finding"]["reduced_profile"].Curve.I.data[:, 1]

In [None]:
diag["end_finding"]["reduced_profile"]

In [None]:
diag["end_finding"]["image_with_lines"]

In [None]:
regrid(diag["bboxes"])

In [None]:
p.raw_frames[(8, "RFP-EM", 0)]

In [None]:
p.measurements.keys()

In [None]:
for msg in readers.send_eaton_fish(
    fish_filename,
    slices=dict(t=None, v=[8]),
    delayed=True,
):
    # This loop was left without a body, which is a SyntaxError. Forward each
    # message to the pipeline the same way the earlier FISH ingestion cell does.
    # NOTE(review): confirm that `p` accepts delayed messages here.
    p.handle_message({**msg, "image_type": "fish_barcode"})

In [None]:
%%time
# %%pyinstrument
ts = np.arange(t_max)
# ts = np.arange(2)
res = []
for position in trange(13, 40):
    res.append(
        process_fov(
            partial(get_frame_func, filename),
            position,
            ts,
            output_dir / "central_crop2",
            segmentation_channel,
            measurement_channels,
            image_limits,
            find_trenches_kwargs=dict(
                angle=angle, pitch=pitch, width_to_pitch_ratio=2.2 / 3.5
            ),
            delayed=True,
        )
    )

In [None]:
%%time
futures = [client.compute(x) for x in tqdm(res)]

In [None]:
del futures

In [None]:
client.gather(futures)

In [None]:
futures

In [None]:
[fov for fov in futures if any(f.status == "error" for f in fov)]

In [None]:
errored = [e for fov in futures if (e := [f for f in fov if f.status == "error"])]

# Analysis

In [None]:
def trim_axes(ary, mask):
    """Drop, along each axis of ``ary``, the positions where ``mask`` is all False.

    For every axis ``i`` of ``ary`` the mask is reduced with ``np.any`` over
    all of its *other* axes, giving a 1-D boolean selector for axis ``i``.
    That selector is then applied to ``ary`` along axis ``i``.

    Parameters
    ----------
    ary : numpy.ndarray
        Array to trim.
    mask : numpy.ndarray
        Boolean-like array; intended to have the same shape as ``ary``.

    Returns
    -------
    numpy.ndarray
        ``ary`` restricted, axis by axis, to the indices at which ``mask``
        has at least one truthy entry.
    """
    for axis in range(ary.ndim):
        other_axes = tuple(j for j in range(mask.ndim) if j != axis)
        keep = np.any(mask, axis=other_axes)
        # Full slices for the leading axes, boolean selector on this one.
        selector = (slice(None),) * axis + (keep,)
        ary = ary[selector]
    return ary


# def norm(x, quantile=0.9, min_quantile=0.1, mask_value=None):
#     if mask_value is not None:
#         x = x[x != mask_value]
#     x = x - np.nanquantile(x, min_quantile)
#     if quantile is None or quantile == 0:
#         return x
#     elif quantile == 1:
#         return x / np.nanmax(x)
#     else:
#         if hasattr(x, "quantile"):
#             return x / x.quantile(quantile)
#         else:
#             return x / np.nanquantile(x, quantile)

# Segmentation

In [None]:
x = fish_stacks0[13][1:, :9]
# x = x - x.min(axis=1)[:,np.newaxis,:,:]

In [None]:
def weighted_mean(ary):
    """Spatially average a 4-D stack with weights derived from its own signal.

    The input is indexed as ``ary[a0, a1, a2, a3]`` (presumably
    ``(channel, timepoint, row, col)`` -- TODO confirm against the caller).
    Steps:

    1. Subtract the minimum along axis 1 from every pixel (baseline removal).
    2. Take the maximum along axis 1 as each pixel's signal amplitude.
    3. Normalise each axis-0 slice's amplitude map to sum to 1 over axes
       (2, 3), then average those maps over axis 0 to get one weight map.
    4. Return the weighted average of the baseline-subtracted stack over
       axes (2, 3) using that weight map.

    A slice whose amplitude map is identically zero contributes zero weight
    rather than the NaNs that an unguarded 0/0 would produce. The weight map
    is averaged over however many axis-0 slices are present, not a fixed 3.

    Parameters
    ----------
    ary : array-like
        4-D numeric array.

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``ary.shape[:2]``, or ``None`` when every weight is
        zero (i.e. the stack is flat along axis 1 everywhere).
    """
    ary = np.asarray(ary)
    ary = ary - ary.min(axis=1, keepdims=True)
    lmbda = ary.max(axis=1, keepdims=True)
    totals = lmbda.sum(axis=(2, 3), keepdims=True)
    # Guarded division: where a slice has no signal at all, leave its
    # normalised amplitude at 0 instead of computing 0/0.
    normalized = np.divide(
        lmbda,
        totals,
        out=np.zeros(lmbda.shape, dtype=float),
        where=totals != 0,
    )
    # The overall scale of `w` does not affect np.average; dividing by the
    # slice count just keeps it a mean. max(..., 1) avoids 0/0 on empty input.
    w = normalized.sum(axis=0, keepdims=True) / max(ary.shape[0], 1)
    if w.sum() == 0:
        return None
    return np.average(ary, axis=(2, 3), weights=np.broadcast_to(w, ary.shape))

In [None]:
fish_metrics = {
    idx: weighted_mean(np.asarray(stack[1:, ...]))
    for idx, stack in tqdm(fish_stacks0.items())
}

In [None]:
sum(1 for x in fish_metrics.values() if x is None)

In [None]:
fish_metrics[1][0]

In [None]:
bit_names = [(ch, str(t)) for ch in fish_channels[1:] for t in fish_timepoints]

# FISH analysis

In [None]:
fish_metrics_df = pd.DataFrame.from_dict(
    {
        trench_idx: ary.flatten()
        for trench_idx, ary in fish_metrics.items()
        if ary is not None
    },
    columns=pd.MultiIndex.from_tuples(bit_names, names=["channel", "timepoint"]),
    orient="index",
).rename_axis(index="trench_idx")
fish_metrics_df

In [None]:
fish_metrics_df2 = fish_metrics_df.melt(ignore_index=False).reset_index()

In [None]:
fish_metrics_df2

In [None]:
fish_thresholds = {"GFP": 0.007, "Cy5": 0.005, "Cy7": 0.002}

In [None]:
fish_metrics_df2["ground_truth"] = fish_metrics_df2["value"] > fish_metrics_df2[
    "channel"
].map(fish_thresholds)

In [None]:
(fish_metrics_df2.groupby("trench_idx").sum("ground_truth") == 0).sum()

In [None]:
fish_metrics_df2.groupby("channel").apply(lambda x: x["ground_truth"].sum() / len(x))

In [None]:
(fish_metrics_df2.groupby("channel").sum("ground_truth") == 0).sum()

In [None]:
idx = 200
# idx = 3002

In [None]:
x = fish_stacks0[idx][1:]

In [None]:
fish_metrics_df2[fish_metrics_df2["trench_idx"] == idx]

In [None]:
display_image(image.unstack_multichannel(x), scale=0.999)

In [None]:
y = x - x.min(axis=1)[:, np.newaxis, :, :]

In [None]:
display_image(image.unstack_multichannel(y))

In [None]:
plt.imshow(weighted_mean(y))

In [None]:
hv.Violin(fish_metrics_df2, ["channel", "timepoint"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
ds = hv.Dataset(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value")

In [None]:
ds.to(hv.Violin, ["timepoint", "ground_truth"]).layout("channel").opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
        axiswise=True,
    )
).cols(1)

In [None]:
z = ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")

In [None]:
_stacked_violins = (
    ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")
)

hv.Layout([v.redim(value=k) for k, v in _stacked_violins.items()]).opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        # violin_width=3,
        inner=None,
        bandwidth=0.2,
        cut=0.05,
    )
).cols(1)

In [None]:
# fish_metrics_df2.groupby("channel").apply(lambda x: hv.Violin(x))

In [None]:
hv.Layout()

In [None]:
hv.GridSpace(
    {
        (timepoint, channel): hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(logy=True))

In [None]:
hv.GridSpace(
    {
        (timepoint, channel): hv.Dataset(df, ["ground_truth"], "value").to(
            hv.Distribution
        )
        # .overlay("ground_truth")
        # hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(show_legend=True))

In [None]:
hv.GridSpace(
    {
        (timepoint, channel): (
            hv.Distribution(df[df["ground_truth"]], "value", label="On")
            * hv.Distribution(df[~df["ground_truth"]], "value", label="Off")
        ).redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
).opts(hv.opts.Distribution(show_legend=True, bandwidth=0.3, cut=0.05))

In [None]:
from bokeh.sampledata.iris import flowers
from holoviews.operation import gridmatrix

iris_ds = hv.Dataset(flowers)

In [None]:
iris_ds

In [None]:
fish_metrics_df3 = fish_metrics_df.set_axis(
    ["_".join(c) for c in fish_metrics_df.columns], axis=1
)

In [None]:
fish_metrics_df3 = fish_metrics_df3.loc[
    :, [*fish_metrics_df3.columns[3:6], *fish_metrics_df3.columns[13:16]]
]

In [None]:
density_grid = gridmatrix(
    hv.Dataset(fish_metrics_df3), diagonal_type=hv.Distribution, chart_type=hv.Bivariate
)

In [None]:
density_grid