# Imports

In [None]:
import io
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from glob import glob
from pathlib import Path

import dask
import deltalake
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import scipy.signal
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

IDX = pd.IndexSlice

In [None]:
pl.enable_string_cache()

In [None]:
from dask.diagnostics import ProgressBar

ProgressBar().register()

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# import paulssonlab.image_analysis.mosaic as mosaic
import paulssonlab.image_analysis.delayed as delayed
import paulssonlab.image_analysis.drift as drift
import paulssonlab.image_analysis.geometry as geometry
import paulssonlab.image_analysis.image as image
import paulssonlab.image_analysis.pipeline as pipeline
import paulssonlab.image_analysis.readers as readers
import paulssonlab.image_analysis.segmentation.watershed as watershed
import paulssonlab.image_analysis.trench_detection as trench_detection
import paulssonlab.image_analysis.util as util
import paulssonlab.image_analysis.workflow as workflow
import paulssonlab.util.core as core
import paulssonlab.util.numeric as numeric
from paulssonlab.image_analysis.ui import RevImage, display_image

In [None]:
%load_ext pyinstrument

In [None]:
hv.extension("bokeh")
# hv.extension("matplotlib")

# Functions

In [None]:
def concat_glob(filename):
    """Scan every Arrow IPC file matching a glob pattern and concatenate them.

    Parameters
    ----------
    filename
        Glob pattern handed to ``glob`` (e.g. ``"/data/run1/*.arrow"``).

    Returns
    -------
    The result of ``pl.concat(..., how="diagonal")`` over one ``pl.scan_ipc``
    frame per matching file, in sorted path order.

    Raises
    ------
    ValueError
        If the pattern matches no files.
    """
    # glob() returns matches in arbitrary (filesystem-dependent) order; sort so
    # the row order of the concatenated frame is reproducible between runs.
    paths = sorted(glob(filename))
    if not paths:
        # Fail early and name the pattern, instead of handing pl.concat an
        # empty list and getting an error that does not mention the pattern.
        raise ValueError(f"no files match pattern {filename!r}")
    return pl.concat([pl.scan_ipc(path) for path in paths], how="diagonal")

In [None]:
def label_columns(cols, func=None):
    """Build a chained polars ``when(...).then(...)`` expression labelling rows.

    One ``when(<col> is not null).then(<label>)`` branch is appended per entry
    of ``cols``, in iteration order. The label is the column name itself, or
    ``func(name)`` when ``func`` is given. No ``otherwise`` branch is added.

    Returns ``None`` when ``cols`` is empty.
    """
    chain = None
    for name in cols:
        present = pl.col(name).is_not_null()
        # The first branch starts the chain; later ones extend it.
        branch = pl.when(present) if chain is None else chain.when(present)
        label = name if func is None else func(name)
        chain = branch.then(pl.lit(label))
    return chain

# Config

In [None]:
cycle_channel_to_bit_num = """cycle	channel	bit
1	AF555	1
2	AF555	2
1	Cy5	3
2	Cy5	4
1	Alexa750	5
2	Alexa750	6
3	AF555	7
4	AF555	8
3	Cy5	9
4	Cy5	10
3	Alexa750	11
4	Alexa750	12
5	AF555	13
6	AF555	14
5	Cy5	15
6	Cy5	16
5	Alexa750	17
6	Alexa750	18
7	AF555	19
8	AF555	20
7	Cy5	21
8	Cy5	22
7	Alexa750	23
8	Alexa750	24
9	AF555	25
10	AF555	26
9	Cy5	27
10	Cy5	28
9	Alexa750	29
10	Alexa750	30"""

# Parse the tab-separated (cycle, channel, bit) table above into a DataFrame.
# This rebinds the name from the raw string to the parsed frame. The dye
# labels AF555 and Alexa750 are then renamed to GFP and Cy7, which are the
# channel names used in "fish_measure_channels" in the pipeline config.
cycle_channel_to_bit_num = pl.read_csv(
    io.StringIO(cycle_channel_to_bit_num), separator="\t"
).with_columns(channel=pl.col("channel").replace({"AF555": "GFP", "Alexa750": "Cy7"}))

In [None]:
nd2_filename = Path("/home/jqs1/scratch/microscopy/230915/230915_RBS_repressors.nd2")
# nd2_filename = Path("/home/jqs1/scratch/microscopy/230912/230912_bcd_rbses001.nd2")
# nd2_filename = workflow.SplitFilename(
#     sorted(
#         glob.glob(
#             # "/home/jqs1/scratch/jqs1/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
#             "/home/jqs1/scratch/microscopy/230830/230830_repressilators.nd2.split.*"
#         )
#     )
# )
# assert nd2_filename.files
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/231101/231101_FP_calibration.nd2"
# fish_filename = nd2_filename[0].parent / "FISH/real_run"

In [None]:
# A split acquisition is a sequence of files, so take the directory of its
# first part; a plain path uses its own parent directory.
parent_dir = (
    nd2_filename[0].parent
    if isinstance(nd2_filename, workflow.SplitFilename)
    else nd2_filename.parent
)
# FISH images are read from, and pipeline output is written to, directories
# next to the ND2 file.
fish_filename, output_dir = parent_dir / "FISH/real_run", parent_dir / "test2"
# output_dir.mkdir(exist_ok=True)

In [None]:
nd2 = workflow.get_nd2_reader(nd2_filename)
t_max = nd2.sizes["t"]

In [None]:
nd2.sizes

In [None]:
nd2.metadata["channels"]

In [None]:
colors = {
    "BF": "#ffffff",
    # "CFP-EM": "#6fb2e4",
    # "YFP-EM": "#eee461",
    # "RFP-EM": "#c66526",
    "CFP-EM": "#648FFF",
    "YFP-EM": "#FFB000",
    "RFP-EM": "#DC267F",
}

fish_colors = {
    "BF": "#ffffff",
    "GFP": "#f44336",
    "Cy5": "#03a9f4",
    # "Cy7": "#ffeb3b"
    "Cy7": "#8bc34a",
}

In [None]:
# SLURM-backed Dask cluster: each job is submitted to the "short" queue with a
# two-hour walltime and 16GB of memory, and runs one process on one core.
cluster = SLURMCluster(
    cores=1,
    processes=1,
    memory="16GB",
    queue="short",
    walltime="02:00:00",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
)
# Distributed client attached to the cluster defined above.
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(0)

In [None]:
cluster.adapt(maximum=20)

# Pipeline

In [None]:
# k1 = 1e-9
# center = image.center_from_shape(nd2.get_frame_2D().shape) - np.array([0, -500])

In [None]:
# def get_frame_func(
#     filename, position, channel, t, k1=None, center=None, dark=None, flat=None
# ):
#     img = np.asarray(
#         workflow.get_nd2_frame(
#             filename, position=position, channel=channel, t=t, dark=dark, flat=flat
#         )
#     )
#     if k1 is not None:
#         img = image.correct_radial_distortion(img, k1=k1, center=center)
#     # TODO
#     img = img[:, 300 : img.shape[1] - 300]
#     return img


# def preprocess_func(img, k1=None, center=None, dark=None, flat=None):
def preprocess_func(img):
    """Crop a fixed border off ``img``.

    Removes 600 entries from each end of axis 0 and 1500 entries from each end
    of axis 1; any further axes are left untouched. If the input is too small
    along an axis, the corresponding slice is empty.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    row_window = slice(600, n_rows - 600)
    col_window = slice(1500, n_cols - 1500)
    return img[row_window, col_window]

In [None]:
# Settings passed to pipeline.DefaultPipeline via its `config` argument.
# Commented-out entries are alternative settings kept for reference.
# NOTE(review): the meaning of each key is defined by the pipeline module, not
# in this notebook. The per-key remarks below are the author's own.
config = {
    # "composite_func": image.mean_composite,
    # "roi_detection_func": trench_detection.find_trenches,
    # "track_drift": True,
    # "segmentation_func": watershed.watershed_segment,
    # "segmentation_channels": ["RFP-EM", "YFP-EM", "CFP-EM"],
    "segmentation_channels": ["RFP-EM"],
    "trench_detection_channels": None,  # channel for trench detection, almost always same as segmentation_channel
    # "measure_channels": ["RFP-PENTA", "YFP-DUAL"],
    # "crop_channels": ["Phase-Fluor", "RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    "crop_channels": ["RFP-EM", "YFP-EM"],
    # "measure_channels": ["RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    "measure_channels": ["RFP-EM", "YFP-EM"],
    # "fish_crop_channels": ["BF", "GFP-EM", "Cy5", "Cy7"],
    "fish_measure_channels": ["GFP", "Cy5", "Cy7"],
    "fish_drift_tracking_channel": "BF",
    # "fish_probes": hhh,
    # "roi_detection_kwargs": {"width_to_pitch_ratio": 1.4 / 3.5},
    "roi_detection_kwargs": {"width_to_pitch_ratio": 3 / 3.5},  # TODO!!!
    # Fixed-border crop defined in the cell above.
    "preprocess_func": preprocess_func,
    # "preprocess_kwargs": {"k1": k1, "center": center},
}

In [None]:
%%time
# Flag forwarded to both the pipeline and the readers below.
# NOTE(review): the trailing underscore presumably avoids rebinding the
# imported `delayed` name. Confirm.
delayed_ = False

# Pipeline instance writing under output_dir, configured by `config` above.
p = pipeline.DefaultPipeline(output_dir, config=config, delayed=delayed_)

# Feed the selected ND2 frames (timepoints 30..61 of position 8) to the
# pipeline, tagging every message as a "science" image.
for msg in readers.send_nd2(
    nd2_filename,
    # slices=dict(t=range(200, 202), v=range(8,10)),
    slices=dict(t=range(30, 62), v=[8]),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "science"})

In [None]:
%%time
# Feed the FISH images for position 8 to the same pipeline instance `p`,
# tagging every message as a "fish_barcode" image.
# NOTE(review): t=None presumably selects all timepoints. Confirm against
# readers.send_eaton_fish.
for msg in readers.send_eaton_fish(
    fish_filename,
    # slices=dict(t=None, v=range(8, 10)),
    slices=dict(t=None, v=[8]),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "fish_barcode"})

In [None]:
p.handle_message({"type": "done"})

In [None]:
display_image(
    image.unstack_multichannel(fish_crops[1:, 8, :, 601].squeeze().swapaxes(0, 1)),
    scale=1,
)  # , colors=fish_colors)

# FISH analysis

In [None]:
# Slicer over the FISH crops stored under output_dir / "fish_crops".
# The regex names three path components: v (fov), c (channel), t (timepoint).
# NOTE(review): other cells index this as fish_crops[<t>, <v>, <c>, <roi>],
# e.g. fish_crops[:, 8, ["BF", ...], 19]. That axis order is inferred from
# usage only. Confirm against readers.ZarrSlicer.
fish_crops = readers.ZarrSlicer(
    output_dir / "fish_crops",
    r"fov=(?P<v>\d+)/channel=(?P<c>[^/]+)/t=(?P<t>\d+)",
    files=False,
    recursive=True,
)

In [None]:
# x = pl.scan_parquet(output_dir / "fish_measurements/*/*")
# x = pl.read_delta(str(output_dir / "fish_measurements"))
# y = deltalake.DeltaTable(str(output_dir / "fish_measurements")).to_pandas()
# Open the hive-partitioned parquet output of the pipeline. `ds` here is the
# `pyarrow.dataset` module alias from the imports cell.
dataset = ds.dataset(
    output_dir / "fish_measurements", format="parquet", partitioning="hive"
)
# df = dataset.to_table(filter=ds.field("position") == 14).to_pandas()
# df = dataset.to_table().to_pandas().sort_values("t")
# Load all measurements into memory, ordered by (fov_num, roi, t).
fish = pl.scan_pyarrow_dataset(dataset).sort(["fov_num", "roi", "t"]).collect()
# Attach the barcode bit number to every measurement by matching
# (t, channel) against (cycle, channel) in the table from the Config section.
# NOTE(review): `cycle` is cast to Int32 presumably to match the dtype of `t`.
# Confirm.
fish = fish.join(
    cycle_channel_to_bit_num,
    left_on=["t", "channel"],
    right_on=[pl.col("cycle").cast(pl.Int32), "channel"],
)

In [None]:
plt.scatter(df["mean"], df["otsu_mean"], s=1)

In [None]:
hv.Violin(df, ["channel", "t"], "otsu_mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(df, ["channel", "t"], "mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(df2, ["channel", "t", "call"], "otsu_mean").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("call"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
p.fish_crops.keys()

In [None]:
fish_thresholds = {"GFP": 1200, "Cy5": 1000, "Cy7": 800}

In [None]:
# Look up the per-channel threshold for each measurement.
fish2 = fish.with_columns(
    pl.col("channel")
    .replace(fish_thresholds, return_dtype=pl.Float64)
    .alias("threshold")
)
# Second step, because `call` depends on the `threshold` column added above:
# a bit is called "on" when its otsu_mean exceeds the channel threshold.
fish2 = fish2.with_columns((pl.col("otsu_mean") > pl.col("threshold")).alias("call"))

In [None]:
fish2

In [None]:
fish2.filter(pl.col("roi") == 19, pl.col("channel") == "Cy5")

In [None]:
fish_barcodes.filter(pl.col("roi") == 19)["fish_barcode"].to_numpy()

In [None]:
display_image(
    image.unstack_multichannel(
        image.crop_to_mask(
            fish_crops[:, 8, ["BF", "GFP", "Cy5", "Cy7"], 19].squeeze().swapaxes(0, 1)
        )
    ),
    scale=1,
)  # , colors=fish_colors)

In [None]:
seq_barcodes[595495, "seq_barcode"].to_numpy()

In [None]:
seq_barcodes[595495, "BC:BIT4|seq"]

In [None]:
seq_barcodes[595496, "BC:BIT4|seq"]

In [None]:
seq_barcodes[595497, "BC:BIT4|seq"]

In [None]:
seq_barcodes[595502, "BC:BIT4|seq"]

In [None]:
with pl.Config() as cfg:
    cfg.set_tbl_rows(100)
    display(fish2.filter(pl.col("roi") == 19))

In [None]:
# One row per (fov_num, roi): the list of on/off calls ordered by bit number.
# Only barcodes with exactly 30 calls are kept. The calls are stored as a list
# of Int32 (0/1) rather than booleans.
fish_barcodes = (
    fish2.select("fov_num", "roi", "bit", "call")
    .group_by("fov_num", "roi")
    .agg(fish_barcode=pl.col("call").sort_by("bit"))
    .filter(pl.col("fish_barcode").list.len() == 30)
    .with_columns(
        pl.col("fish_barcode").cast(pl.List(pl.Int32))
        # .cast(pl.Array(pl.Boolean, 30)))
    )
)

## Join with codebook

In [None]:
%%time
# arrow_filename = "/home/jqs1/scratch/sequencing/230818_repressilators/20230905_1132_1H_PAQ85679_c9d74ddb/extract_segments/*.arrow"
arrow_filename = "/home/jqs1/scratch/sequencing/231201_bcd_rbses_run3/20231201_1101_1F_PAU05823_773c75ee/extract_segments/*.arrow"
seq = concat_glob(arrow_filename).collect()

In [None]:
%%time
# Add per-read quality flags to the sequencing table:
#   dup       - True when the read's `name` occurs more than once.
#   e2e       - True when exactly two of the four listed BC:UPSTREAM / UNS3
#               tokens occur in `variants_path`.
#   bc_e2e    - same test, using BC:UPSTREAM / BC:SPACER2 tokens.
#   bc_errors - row-wise sum of all columns whose name matches
#               BC:BIT<n>|insertions, |deletions or |mismatches.
# NOTE(review): the "<" / ">" prefixes presumably encode segment orientation,
# and "e2e" presumably means the read spans both end segments. Confirm
# against the extract_segments output format.
seq2 = seq.with_columns(
    dup=pl.col("name").is_duplicated(),
    e2e=pl.col("variants_path")
    .list.set_intersection(
        [
            "<BC:UPSTREAM",
            "<UNS3",
            ">BC:UPSTREAM",
            ">UNS3",
        ]
    )
    .list.len()
    == 2,
    bc_e2e=pl.col("variants_path")
    .list.set_intersection(
        [
            "<BC:UPSTREAM",
            "<BC:SPACER2",
            ">BC:UPSTREAM",
            ">BC:SPACER2",
        ]
    )
    .list.len()
    == 2,
    bc_errors=pl.sum_horizontal(r"^BC:BIT\d+\|(insertions|deletions|mismatches)$"),
)

In [None]:
seq2.filter(pl.col("bc_errors") == 0).count()

In [None]:
plt.hist(seq2["bc_errors"].to_numpy(), bins=100, log=True);

In [None]:
seq3 = seq2.filter(pl.col("e2e"), ~pl.col("dup"), pl.col("bc_errors") == 0)

In [None]:
%%time
# NOTE(review): this rebinds `seq3` from the previous cell and, unlike that
# cell, does NOT require bc_errors == 0. Confirm which filter is intended
# before relying on downstream results.
#
# Keep reads flagged e2e whose name is not duplicated, and add an "RBS" label:
#   - if one of the three "...|seq" columns is non-null, the label is that
#     column's name up to the first "|" (e.g. "pLIB433:PhlF_pPhlF");
#   - otherwise it is "pLIB431-432:RBS=" followed by the RBS variant.
seq3 = seq2.filter(pl.col("e2e"), ~pl.col("dup")).with_columns(
    pl.coalesce(
        label_columns(
            [
                "pLIB433:PhlF_pPhlF|seq",
                "pLIB434:LacI_pTac|seq",
                "pLIB435:BetI_pBetI|seq",
            ],
            lambda x: x.split("|")[0],
        ),
        pl.concat_str(pl.lit("pLIB431-432:RBS="), pl.col("pLIB431-432:RBS|variant")),
    ).alias("RBS")
)

In [None]:
# Collect the 30 per-bit variant columns into a single list column, so each
# read carries its sequenced barcode as one value.
# NOTE(review): range(30) selects BC:BIT0 .. BC:BIT29, whereas the
# cycle/channel table in the Config section numbers bits 1 .. 30. Confirm
# that the two numberings line up element-for-element before comparing
# seq_barcode against fish_barcode.
seq_barcodes = seq3.with_columns(
    seq_barcode=pl.concat_list(
        [f"BC:BIT{idx}|variant" for idx in range(30)]
    )  # .cast(pl.Array(pl.Boolean, 30))
)

In [None]:
seq_barcodes["grouping_path"][0].to_list()

In [None]:
seq_barcodes["seq_barcode"][0].to_list()

In [None]:
seq_barcodes.group_by("grouping_path").agg(pl.len())["len"].value_counts()

In [None]:
seq_barcodes.group_by("seq_barcode").agg(pl.len())["len"].value_counts()

In [None]:
x = (
    seq_barcodes.group_by("seq_barcode")
    .agg(pl.len(), pl.col("name"), pl.col("grouping_path"), pl.col("variants_path"))
    .filter(pl.col("len") == 2)
)

In [None]:
y = x[0].explode("name", "grouping_path", "variants_path")

In [None]:
y

In [None]:
y[0, "seq_barcode"].to_list()

In [None]:
y[0, "grouping_path"].to_list()

In [None]:
y[0, "name"]

In [None]:
seq3.filter(
    pl.col("name")
    == "consensus_f451f9bee14d46571dececc8227a9769bc5aabec0862f9aea5918d2ed9562401"
)[0, "consensus_seq"]

In [None]:
z = (
    seq3.filter(
        pl.col("name")
        == "consensus_f451f9bee14d46571dececc8227a9769bc5aabec0862f9aea5918d2ed9562401"
    )[0]
    .to_pandas()
    .T
)

In [None]:
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(z)

In [None]:
y[0, "name"]

In [None]:
y[0, "variants_path"].to_list()

In [None]:
fish_barcodes.group_by("fish_barcode").agg(pl.len())["len"].value_counts()

In [None]:
fish_barcodes.with_columns(fish_hash=pl.col("fish_barcode").hash())

In [None]:
fish_barcodes.group_by("fish_barcode").agg(pl.len())

In [None]:
# Match sequenced barcodes to FISH barcodes by joining on a hash of each
# barcode list, then keep only the rows where both sides are present, i.e.
# exact barcode matches.
# NOTE(review): hashing presumably works around joining on list columns
# directly. Equal hashes require equal element values AND dtypes, so confirm
# both list columns have the same inner dtype.
seq_barcodes.with_columns(seq_barcode=pl.col("seq_barcode").hash()).join(
    fish_barcodes.with_columns(fish_barcode=pl.col("fish_barcode").hash()),
    left_on="seq_barcode",
    right_on="fish_barcode",
    how="full",
).select("grouping_path", "fish_barcode").filter(
    pl.col("grouping_path").is_not_null(), pl.col("fish_barcode").is_not_null()
)

In [None]:
seq_barcodes.join(
    fish_barcodes, left_on="seq_barcode", right_on="fish_barcode", how="full"
)

In [None]:
seq_variants.head()

In [None]:
# Dense boolean matrices of barcodes, one row per barcode and 30 columns.
# `fb` rows follow fish_barcodes sorted by (fov_num, roi); `sb` rows follow
# the row order of seq_barcodes.
# NOTE(review): an integer row index into `fb` equals the roi number only if
# a single FOV is present and no roi was dropped by the 30-bit filter.
# Confirm before comparing fb[k] with roi == k.
fb = (
    fish_barcodes.sort(["fov_num", "roi"])["fish_barcode"]
    .cast(pl.Array(pl.Boolean, 30))
    .to_numpy()
)
sb = seq_barcodes["seq_barcode"].cast(pl.Array(pl.Boolean, 30)).to_numpy()

In [None]:
# Number of differing bits (Hamming distance) between FISH barcode row 503 and
# every sequenced barcode. dists[i] corresponds to row i of `sb`.
dists = np.count_nonzero(fb[503, np.newaxis] != sb, axis=1)
# Histogram of those distances with log-scaled counts.
plt.hist(dists, bins=20, log=True);

In [None]:
np.argmin(dists)

In [None]:
# Display the distances in ascending order WITHOUT sorting `dists` in place.
# dists[i] must keep corresponding to row i of `sb`; otherwise re-running
# np.argmin(dists) after this cell would always return 0 and no longer point
# at the closest sequenced barcode.
np.sort(dists)

In [None]:
fb[1]

In [None]:
# For each of the first 50 rows of `fb`, the smallest number of differing bits
# to any row of `sb` (0 means an exact match exists).
[np.count_nonzero(fb[idx, np.newaxis] != sb, axis=1).min() for idx in trange(50)]

In [None]:
fish_barcodes.filter(pl.col("roi") == 18)["fish_barcode"][0].to_numpy()

In [None]:
fb[18]

In [None]:
sb.shape

In [None]:
(fb[18, np.newaxis] != sb).shape

In [None]:
np.count_nonzero(fb[18, np.newaxis] != sb, axis=1)

In [None]:
np.count_nonzero(fb[18, np.newaxis] != sb, axis=1).min()

In [None]:
sb[:10]

# Science analysis

In [None]:
crops = readers.ZarrSlicer(
    output_dir / "crops",
    r"fov=(?P<v>\d+)/channel=(?P<c>[^/]+)/t=(?P<t>\d+)",
    files=False,
    recursive=True,
)

In [None]:
display_image(
    image.unstack(image.crop_to_mask(crops[:, 8, "YFP-EM", 33].squeeze())), scale=1
)

In [None]:
crops[:, 8, "YFP-EM", 3].squeeze().shape

In [None]:
display_image(
    image.unstack_multichannel(crops[:, 8, :, 601].squeeze().swapaxes(0, 1)),
    scale=1,
)

# Old FISH analysis

In [None]:
# Wide table with one row per trench and one column per (channel, timepoint)
# pair. Trenches whose entry in `fish_metrics` is None are skipped; every
# remaining array is flattened into a single row.
# NOTE(review): `fish_metrics` and `bit_names` are not defined anywhere in the
# visible part of this notebook, presumably left over from an earlier
# session. Confirm before re-running this section.
fish_metrics_df = pd.DataFrame.from_dict(
    {
        trench_idx: ary.flatten()
        for trench_idx, ary in fish_metrics.items()
        if ary is not None
    },
    columns=pd.MultiIndex.from_tuples(bit_names, names=["channel", "timepoint"]),
    orient="index",
).rename_axis(index="trench_idx")
fish_metrics_df

In [None]:
fish_metrics_df2 = fish_metrics_df.melt(ignore_index=False).reset_index()

In [None]:
fish_metrics_df2

In [None]:
fish_thresholds = {"GFP": 0.007, "Cy5": 0.005, "Cy7": 0.002}

In [None]:
# Call a measurement "on" when its value exceeds the threshold of its channel.
# A channel missing from fish_thresholds maps to NaN, and a comparison with
# NaN is False, so such rows are silently called "off".
fish_metrics_df2["ground_truth"] = fish_metrics_df2["value"] > fish_metrics_df2[
    "channel"
].map(fish_thresholds)

In [None]:
(fish_metrics_df2.groupby("trench_idx").sum("ground_truth") == 0).sum()

In [None]:
fish_metrics_df2.groupby("channel").apply(lambda x: x["ground_truth"].sum() / len(x))

In [None]:
(fish_metrics_df2.groupby("channel").sum("ground_truth") == 0).sum()

In [None]:
idx = 200
# idx = 3002

In [None]:
x = fish_stacks0[idx][1:]

In [None]:
fish_metrics_df2[fish_metrics_df2["trench_idx"] == idx]

In [None]:
display_image(image.unstack_multichannel(x), scale=0.999)

In [None]:
y = x - x.min(axis=1)[:, np.newaxis, :, :]

In [None]:
display_image(image.unstack_multichannel(y))

In [None]:
plt.imshow(weighted_mean(y))

In [None]:
hv.Violin(fish_metrics_df2, ["channel", "timepoint"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        violin_color=hv.dim("channel").str(),
        inner=None,
        # violin_width=1,
    )
)

In [None]:
hv.Violin(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value").opts(
    hv.opts(
        width=700,
        show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
    )
)

In [None]:
ds = hv.Dataset(fish_metrics_df2, ["channel", "timepoint", "ground_truth"], "value")

In [None]:
ds.to(hv.Violin, ["timepoint", "ground_truth"]).layout("channel").opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        split=hv.dim("ground_truth"),
        violin_width=3,
        inner=None,
        axiswise=True,
    )
).cols(1)

In [None]:
z = ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")

In [None]:
_stacked_violins = (
    ds.to(hv.Violin, ["timepoint"]).overlay("ground_truth").layout("channel")
)

hv.Layout([v.redim(value=k) for k, v in _stacked_violins.items()]).opts(
    hv.opts.Violin(
        width=700,
        # show_legend=True,
        # violin_color=hv.dim("channel").str(),
        # violin_width=3,
        inner=None,
        bandwidth=0.2,
        cut=0.05,
    )
).cols(1)

In [None]:
# fish_metrics_df2.groupby("channel").apply(lambda x: hv.Violin(x))

In [None]:
hv.Layout()

In [None]:
hv.GridSpace(
    {
        (timepoint, channel): hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(logy=True))

In [None]:
hv.GridSpace(
    {
        (timepoint, channel): hv.Dataset(df, ["ground_truth"], "value").to(
            hv.Distribution
        )
        # .overlay("ground_truth")
        # hv.Distribution(df, "value").redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
)  # .opts(hv.opts.Distribution(show_legend=True))

In [None]:
# Grid of density plots with one cell per (timepoint, channel): the value
# distribution of measurements called "on" overlaid on that of measurements
# called "off". The value dimension of each cell is renamed to the channel.
hv.GridSpace(
    {
        (timepoint, channel): (
            hv.Distribution(df[df["ground_truth"]], "value", label="On")
            * hv.Distribution(df[~df["ground_truth"]], "value", label="Off")
        ).redim(value=channel)
        for (timepoint, channel), df in fish_metrics_df2.groupby(
            ["timepoint", "channel"]
        )
    },
    kdims=["timepoint", "channel"],
).opts(hv.opts.Distribution(show_legend=True, bandwidth=0.3, cut=0.05))

In [None]:
from bokeh.sampledata.iris import flowers
from holoviews.operation import gridmatrix

iris_ds = hv.Dataset(flowers)

In [None]:
iris_ds

In [None]:
fish_metrics_df3 = fish_metrics_df.set_axis(
    ["_".join(c) for c in fish_metrics_df.columns], axis=1
)

In [None]:
fish_metrics_df3 = fish_metrics_df3.loc[
    :, [*fish_metrics_df3.columns[3:6], *fish_metrics_df3.columns[13:16]]
]

In [None]:
density_grid = gridmatrix(
    hv.Dataset(fish_metrics_df3), diagonal_type=hv.Distribution, chart_type=hv.Bivariate
)

In [None]:
density_grid