# Imports

In [None]:
import io
import itertools as it
import os
import re
from collections import Counter, namedtuple
from functools import partial
from glob import glob
from pathlib import Path

import dask
import dask.distributed
import deltalake
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import scipy.signal
import skimage.measure
import zarr
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

# Shorthand for pandas MultiIndex slicing, e.g. df.loc[IDX[:, 3], :].
IDX = pd.IndexSlice

In [None]:
# Enable polars' global string cache so Categorical columns created in
# different frames can be combined/joined.
pl.enable_string_cache()

In [None]:
from dask.diagnostics import ProgressBar

# Register a global progress bar for dask computations that run on the local
# (non-distributed) schedulers.
ProgressBar().register()

In [None]:
# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Project modules. Note: the name `delayed` below is
# paulssonlab.image_analysis.delayed, not dask.delayed.
# import paulssonlab.image_analysis.mosaic as mosaic
import paulssonlab.image_analysis.delayed as delayed
import paulssonlab.image_analysis.drift as drift
import paulssonlab.image_analysis.geometry as geometry
import paulssonlab.image_analysis.image as image
import paulssonlab.image_analysis.pipeline as pipeline
import paulssonlab.image_analysis.readers as readers
import paulssonlab.image_analysis.segmentation.watershed as watershed
import paulssonlab.image_analysis.trench_detection as trench_detection
import paulssonlab.image_analysis.util as util
import paulssonlab.image_analysis.workflow as workflow
import paulssonlab.util.core as core
import paulssonlab.util.numeric as numeric
from paulssonlab.image_analysis.ui import RevImage, display_image

In [None]:
%load_ext pyinstrument

In [None]:
# Render HoloViews/hvplot output with the bokeh backend.
hv.extension("bokeh")
# hv.extension("matplotlib")

# Functions

In [None]:
def concat_glob(filename):
    """Lazily concatenate every Arrow IPC file matching a glob pattern.

    Parameters
    ----------
    filename : str
        Glob pattern (as accepted by :func:`glob.glob`) selecting IPC files.

    Returns
    -------
    polars.LazyFrame
        Diagonal concatenation (columns missing from some files are filled
        with nulls) of one ``pl.scan_ipc`` frame per matching file, taken in
        sorted path order.

    Raises
    ------
    ValueError
        If the pattern matches no files.
    """
    # glob() returns paths in arbitrary, filesystem-dependent order; sort so
    # the row order of the concatenated frame is reproducible between runs.
    paths = sorted(glob(filename))
    if not paths:
        # Report the offending pattern instead of an opaque error from
        # pl.concat on an empty list.
        raise ValueError(f"no files match glob pattern {filename!r}")
    return pl.concat([pl.scan_ipc(path) for path in paths], how="diagonal")

In [None]:
def label_columns(cols, func=None):
    """Build a polars when/then chain naming the first non-null column.

    One branch ``when(col is not null).then(lit(label))`` is appended per
    entry of ``cols``, in order, so the earliest non-null column wins. The
    label is ``func(col)`` when ``func`` is given, otherwise the column name
    itself. No ``otherwise`` clause is attached, so rows where every listed
    column is null evaluate to null.

    Returns ``None`` when ``cols`` is empty.
    """
    chain = None
    for name in cols:
        label = name if func is None else func(name)
        condition = pl.col(name).is_not_null()
        # The first branch starts a fresh pl.when(...) chain; later branches
        # extend the existing chain.
        add_branch = pl.when if chain is None else chain.when
        chain = add_branch(condition).then(pl.lit(label))
    return chain

# Config

In [None]:
# Mapping from (FISH cycle, acquisition channel) to barcode bit.
# Rows are listed exactly as recorded (acquisition channel names, one-indexed
# bits). They are given as Python tuples rather than a tab-separated string
# literal, so the table cannot be silently broken by an editor or formatter
# converting tabs to spaces.
# After the with_columns below: bit uses zero-indexing, cycle stays
# one-indexed (eventually microscope acquisition should start with cycle 0).
cycle_channel_to_bit_num = pl.DataFrame(
    [
        (1, "AF555", 1),
        (2, "AF555", 2),
        (1, "Cy5", 3),
        (2, "Cy5", 4),
        (1, "Alexa750", 5),
        (2, "Alexa750", 6),
        (3, "AF555", 7),
        (4, "AF555", 8),
        (3, "Cy5", 9),
        (4, "Cy5", 10),
        (3, "Alexa750", 11),
        (4, "Alexa750", 12),
        (5, "AF555", 13),
        (6, "AF555", 14),
        (5, "Cy5", 15),
        (6, "Cy5", 16),
        (5, "Alexa750", 17),
        (6, "Alexa750", 18),
        (7, "AF555", 19),
        (8, "AF555", 20),
        (7, "Cy5", 21),
        (8, "Cy5", 22),
        (7, "Alexa750", 23),
        (8, "Alexa750", 24),
        (9, "AF555", 25),
        (10, "AF555", 26),
        (9, "Cy5", 27),
        (10, "Cy5", 28),
        (9, "Alexa750", 29),
        (10, "Alexa750", 30),
    ],
    schema=["cycle", "channel", "bit"],
    orient="row",
).with_columns(
    # Rename fluorophores to the channel names used elsewhere in this
    # notebook (e.g. fish_colors: "GFP", "Cy5", "Cy7").
    channel=pl.col("channel").replace({"AF555": "GFP", "Alexa750": "Cy7"}),
    # Shift bits from one-indexing to zero-indexing.
    bit=pl.col("bit") - 1,
)

In [None]:
# Time-lapse ND2 acquisition to process; earlier datasets are kept commented out.
nd2_filename = Path("/home/jqs1/scratch/microscopy/240627/LIB533.nd2")
# nd2_filename = Path("/home/jqs1/scratch/microscopy/240612/LIB533_isolates_restart.nd2")
# nd2_filename = Path("/home/jqs1/scratch/microscopy/230915/230915_RBS_repressors.nd2")
# nd2_filename = Path("/home/jqs1/scratch/microscopy/230912/230912_bcd_rbses001.nd2")
# nd2_filename = Path("/home/jqs1/scratch/microscopy/231101/231101_FP_calibration.nd2")

# nd2_filename = workflow.SplitFilename(
#     sorted(
#         glob(
#             # "/home/jqs1/scratch/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
#             "/home/jqs1/scratch/microscopy/230830/230830_repressilators.nd2.split.*"
#         )
#     )
# )
# assert nd2_filename.files

In [None]:
# Derive sibling paths from the directory holding the ND2 file.
# NOTE(review): the SplitFilename branch assumes SplitFilename supports
# integer indexing and yields a Path; the commented-out cell above accesses
# `.files` instead — confirm nd2_filename[0] works.
if isinstance(nd2_filename, workflow.SplitFilename):
    parent_dir = nd2_filename[0].parent
else:
    parent_dir = nd2_filename.parent
fish_filename = parent_dir / "FISH/real_run"
# Pipeline outputs (crops, segmentation_masks, measurements, initial_rois)
# are read back from this directory later in the notebook.
# NOTE(review): it is not created here (mkdir is commented out); presumably
# the pipeline writers create it — confirm.
output_dir = parent_dir / "zarr_test"
# output_dir.mkdir(exist_ok=True)

In [None]:
nd2 = workflow.get_nd2_reader(nd2_filename)
# Length of the "t" (time) axis of the acquisition.
t_max = nd2.sizes["t"]

In [None]:
nd2.sizes

In [None]:
nd2.metadata["channels"]

In [None]:
# Hex display colors per time-lapse channel (alternative palette commented out).
colors = {
    "BF": "#ffffff",
    # "CFP-EM": "#6fb2e4",
    # "YFP-EM": "#eee461",
    # "RFP-EM": "#c66526",
    "CFP-EM": "#648FFF",
    "YFP-EM": "#FFB000",
    "RFP-EM": "#DC267F",
}

# Hex display colors per FISH channel.
fish_colors = {
    "BF": "#ffffff",
    "GFP": "#f44336",
    "Cy5": "#03a9f4",
    # "Cy7": "#ffeb3b"
    "Cy7": "#8bc34a",
}

In [None]:
# Dask cluster on SLURM: each job provides one single-core, single-process
# worker with 4 GB of memory, on the "short" queue with a 6 h walltime.
cluster = SLURMCluster(
    queue="short",
    walltime="6:00:00",
    memory="4GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
# Release all workers.
cluster.scale(0)

In [None]:
# Let dask add/remove workers on demand, up to 100.
cluster.adapt(maximum=100)

# ROI test

In [None]:
%%time
# Quick trench-detection check on a single frame (t=0, v=8, channel index 0),
# cropped by the same fixed margins preprocess_func applies.
img = nd2.get_frame_2D(t=0, v=8, c=0)
img = img[600 : img.shape[0] - 600, 1500 : img.shape[1] - 1500]
rois, info = trench_detection.find_trenches(img, join_info=False, width=12)

In [None]:
rois

In [None]:
# Restrict ROIs to the image limits with a (0, 0) offset — presumably drops
# or clips ROIs that fall outside the frame; confirm in geometry.
rois2 = geometry.filter_and_shift_rois(
    rois, (0, 0), geometry.get_image_limits(img.shape)
)

In [None]:
crops = pipeline.crop_rois(img, rois2)

In [None]:
crops.keys()

In [None]:
# Count how many crops share each shape.
shapes = Counter()
for crop in crops.values():
    shapes[tuple(crop.shape)] += 1

In [None]:
shapes

In [None]:
shapes

In [None]:
info

# Zarr writing test

In [None]:
def dummy_value(shape, val):
    """Return ``val`` unchanged, or an array of ``shape`` filled with ``val``.

    ``shape=None`` stands for "scalar": ``val`` itself is returned. Any other
    ``shape`` yields ``np.full(shape, val)``.
    """
    return val if shape is None else np.full(shape, val)

In [None]:
# Sanity check: a shape of (1,) gives array([3]).
dummy_value((1,), 3)

In [None]:
# The next four cells exercise DelayedBatchedZarrStore: each builds a fresh
# queue + store and assigns dummy (5, 3) arrays at several 3-part keys. The
# `{N}` fields in the path template presumably pick which key components
# become path segments — confirm in paulssonlab.image_analysis.delayed.
q = delayed.DelayedQueue()
s = delayed.DelayedBatchedZarrStore(
    q, output_dir / "test_store/t={1}", write_options=dict(chunks=(2, 1))
)
shape = (5, 3)
s[1, 0, 0] = dummy_value(shape, 12)
s[0, 0, 0] = dummy_value(shape, 11)
s[2, 0, 0] = dummy_value(shape, 22)
s[2, 1, 0] = dummy_value(shape, 88)
s[3, 1, 0] = dummy_value(shape, 99)

In [None]:
# Values may also be dicts keyed by an extra index component (here 2);
# presumably that component fills the `{3}` field of the template.
q = delayed.DelayedQueue()
s = delayed.DelayedBatchedZarrStore(
    q, output_dir / "test_store/t={1}/roi={3}", write_options=dict(chunks=(2, 1, 7))
)
shape = (5, 3)
s[1, 0, 0] = {2: dummy_value(shape, 12)}
s[0, 0, 0] = {2: dummy_value(shape, 11)}
s[2, 0, 0] = {2: dummy_value(shape, 22)}
s[2, 1, 0] = {2: dummy_value(shape, 88)}
s[3, 1, 0] = {2: dummy_value(shape, 99)}

In [None]:
# Same, with tuple sub-keys (3, 2) and a 4-D chunk spec.
q = delayed.DelayedQueue()
s = delayed.DelayedBatchedZarrStore(
    q, output_dir / "test_store/t={1}/roi={3}", write_options=dict(chunks=(2, 1, 7, 7))
)
shape = (5, 3)
s[1, 0, 0] = {(3, 2): dummy_value(shape, 12)}
s[0, 0, 0] = {(3, 2): dummy_value(shape, 11)}
s[2, 0, 0] = {(3, 2): dummy_value(shape, 22)}
s[2, 1, 0] = {(3, 2): dummy_value(shape, 88)}
s[3, 1, 0] = {(3, 2): dummy_value(shape, 99)}

In [None]:
q = delayed.DelayedQueue()
s = delayed.DelayedBatchedZarrStore(
    q, output_dir / "test_store/t={}", write_options=dict(chunks=(2, 1))
)
shape = (5, 3)
# shape = None
s[0, 1, 0] = dummy_value(shape, 11)
s[0, 0, 0] = dummy_value(shape, 11)
s[0, 2, 0] = dummy_value(shape, 22)
s[1, 2, 0] = dummy_value(shape, 88)
s[1, 3, 0] = dummy_value(shape, 99)

In [None]:
s.write()

In [None]:
q.poll()

In [None]:
s.writers

In [None]:
s._write_queue

In [None]:
# NOTE(review): `group_by_chunks`, `stack_chunks`, `indices_for_chunk` and
# `arr` are not defined anywhere in this notebook, and `idxs`/`chunks` are
# only defined in later cells — these scratch cells fail if run top to bottom.
group_by_chunks(idxs, chunks, shape)

In [None]:
x = np.empty((2, 2), dtype=object)
x

In [None]:
y = {
    (1, 1): np.full((5, 2), 44),
    (1, 0): np.full((5, 2), 22),
    (0, 1): np.full((5, 2), 33),
    (0, 0): np.full((5, 2), 11),
}

In [None]:
# Store each (5, 2) array as one element of the (2, 2) object array.
for k, v in y.items():
    x[k] = v

In [None]:
# Stack the four arrays in sorted key order and view them as a (2, 2) grid
# of (5, 2) tiles.
z = np.stack([x[k] for k in sorted(y.keys())]).reshape(2, 2, 5, 2)

In [None]:
z[1, 0].shape

In [None]:
stack_chunks(
    {
        (1, 1): np.full((5, 2), 44),
        (1, 0): np.full((5, 2), 22),
        (0, 1): np.full((5, 2), 33),
        (0, 0): np.full((5, 2), 11),
    }
)

In [None]:
# NOTE(review): written for shape = (4, 3, 7) from a later cell; with the
# (5, 3) shape set above, the 3-D indexing raises IndexError.
a = np.zeros(shape)
a[:3, :2, :7] = 1

In [None]:
a[3:, :2, :7]

In [None]:
indices_for_chunk((2, 2), chunks)

In [None]:
chunks = (3, 1, 2, 5, None, None)

In [None]:
chunks = (3, 2, None)
shape = (4, 3, 7)

In [None]:
idxs = {
    (0, 0): "a",
    (0, 1): "a",
    (0, 2): "b",
    (1, 0): "a",
    (1, 1): "a",
    (1, 2): "b",
    (2, 0): "a",
    (3, 0): "c",
}

In [None]:
arr[1, 0, :] = np.arange(3)

# Pipeline

In [None]:
# k1 = 1e-9
# center = image.center_from_shape(nd2.get_frame_2D().shape) - np.array([0, -500])

In [None]:
# def get_frame_func(
#     filename, position, channel, t, k1=None, center=None, dark=None, flat=None
# ):
#     img = np.asarray(
#         workflow.get_nd2_frame(
#             filename, position=position, channel=channel, t=t, dark=dark, flat=flat
#         )
#     )
#     if k1 is not None:
#         img = image.correct_radial_distortion(img, k1=k1, center=center)
#     # TODO
#     img = img[:, 300 : img.shape[1] - 300]
#     return img


def preprocess_func(img, y_margin=600, x_margin=1500):
    """Crop fixed margins off every side of a frame.

    Used as ``config["preprocess_func"]``; presumably the pipeline applies it
    to each raw frame before further processing.

    Parameters
    ----------
    img : array-like
        Image indexed as ``img[row, col]`` (must have at least 2 axes).
    y_margin : int, optional
        Rows removed from both the top and the bottom (default 600).
    x_margin : int, optional
        Columns removed from both the left and the right (default 1500).
        The defaults reproduce the previously hard-coded crop.

    Returns
    -------
    array-like
        ``img[y_margin : H - y_margin, x_margin : W - x_margin]``. This is
        empty if the image is not larger than twice the margins.
    """
    return img[
        y_margin : img.shape[0] - y_margin,
        x_margin : img.shape[1] - x_margin,
    ]

In [None]:
config = dict(
    # Commented-out entries are alternatives kept for reference.
    # composite_func=image.mean_composite,
    # roi_detection_func=trench_detection.find_trenches,
    # track_drift=True,
    # segmentation_func=watershed.watershed_segment,
    # segmentation_channels=["RFP-EM", "YFP-EM", "CFP-EM"],
    # Channel for trench detection; almost always the same as the
    # segmentation channel.
    trench_detection_channels=None,
    # measure_channels=["RFP-PENTA", "YFP-DUAL"],
    # crop_channels=["Phase-Fluor", "RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    # measure_channels=["RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    #
    # Per-dataset channel selections:
    # 230912/230915
    #   segmentation_channels=["RFP-EM"],
    #   crop_channels=["RFP-EM", "YFP-EM"],
    #   measure_channels=["RFP-EM", "YFP-EM"],
    # 230818/230830 (active)
    segmentation_channels=["RFP-EM", "YFP-EM", "CFP-EM"],
    crop_channels=["RFP-EM", "YFP-EM", "CFP-EM"],
    measure_channels=["RFP-EM", "YFP-EM", "CFP-EM"],
    # 231101
    #   segmentation_channels=["RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    #   crop_channels=["Phase-Fluor", "RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    #   measure_channels=["RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"],
    #   fish_crop_channels=["BF", "GFP-EM", "Cy5", "Cy7"],
    # 240612
    #   segmentation_channels=["CFP-EM"],
    #   crop_channels=["CFP-EM", "YFP-EM"],
    #   measure_channels=["CFP-EM", "YFP-EM"],
    #
    # FISH channels.
    fish_measure_channels=["GFP", "Cy5", "Cy7"],
    fish_drift_tracking_channel="BF",
    # fish_probes=hhh,
    # roi_detection_kwargs={"width_to_pitch_ratio": 1.4 / 3.5},
    roi_detection_kwargs={"width_to_pitch_ratio": 3 / 3.5},  # TODO!!!
    preprocess_func=preprocess_func,
    # preprocess_kwargs={"k1": k1, "center": center},
)

In [None]:
%%time
# The dask client is passed as `delayed` to both the pipeline and the reader
# (the commented-out False is the alternative setting).
# delayed_ = False
delayed_ = client

p = pipeline.DefaultPipeline(output_dir, config=config, delayed=delayed_)

# Feed the time-lapse frames of v=8 into the pipeline one message at a time,
# tagging each with image_type="science".
for msg in readers.send_nd2(
    nd2_filename,
    # slices=dict(v=[8], t=range(190, 192)),
    slices=dict(v=[8]),
    # slices=dict(t=[80,85,90,91], v=[8]),
    # slices=dict(t=range(88,92), v=[8]),
    # slices=dict(t=[60,61], v=[8]),
    # slices=dict(t=[60], v=range(8,18)),
    # slices=dict(t=range(62, 64)),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "science"})
# NOTE(review): `del msg` raises NameError if send_nd2 yields no messages.
del msg

In [None]:
display_image(nd2.get_frame_2D(t=61, v=8), scale=0.99, downsample=4)

In [None]:
# Power-law composite of channel indices 0 and 1 at t=61, v=8.
display_image(
    image.power_law_composite([nd2.get_frame_2D(t=61, v=8, c=c) for c in range(2)]),
    scale=0.99,
    downsample=4,
)

In [None]:
# Pixel-wise mean of channel indices 0 and 1 at t=61, v=8.
display_image(
    np.mean([nd2.get_frame_2D(t=61, v=8, c=c) for c in range(2)], axis=0),
    scale=0.99,
    downsample=4,
)

In [None]:
%%time
# Feed the FISH images into the same pipeline, tagged as "fish_barcode".
for msg in readers.send_eaton_fish(
    fish_filename,
    # slices=dict(t=None, v=range(8, 10)),
    # slices=dict(t=[1,2], v=[8,9]),
    # slices=dict(v=range(8,18)),
    delayed=delayed_,
):
    p.handle_message({**msg, "image_type": "fish_barcode"})
del msg

In [None]:
# Signal end of input; presumably lets the pipeline flush pending work.
p.handle_message({"type": "done"})

In [None]:
# Inspect measurement writers that have not finished.
# NOTE(review): f.result() is itself a future here (its .status is checked);
# calling .result() on an errored future re-raises the worker's exception, so
# this stops at the first failed writer — `.exception()` would list them all.
{
    k: f.result().result()
    for k, f in p.measurements.writers.items()
    if f.result().status != "finished"
}

In [None]:
# Hand every pending writer future to the scheduler so it keeps running after
# the local references are dropped, then forget them locally.
# A list is passed on purpose: distributed's futures_of() only descends into
# list/tuple/set/dict containers, so a bare dict_values view was silently
# ignored, and the writes could be released once store.writers was cleared.
for store in p._stores:
    dask.distributed.fire_and_forget(list(store.writers.values()))
    store.writers.clear()

In [None]:
# NOTE(review): `fish_crops` is not defined anywhere in this notebook;
# presumably a ZarrSlicer over the FISH crop store, analogous to `crops` in
# the analysis section — define it before running these cells.
fish_crops[:, 14, :, 206].squeeze().shape

In [None]:
display_image(
    image.unstack_multichannel(fish_crops[1:, 14, :, 206].squeeze().swapaxes(0, 1)),
    scale=1,
)  # , colors=fish_colors)

# Science analysis

In [None]:
%%time
# Index the crop stores under output_dir/crops by the named regex groups
# v (fov), c (channel) and t; presumably loaded lazily — confirm in readers.
crops = readers.ZarrSlicer(
    output_dir / "crops",
    r"fov=(?P<v>\d+)/channel=(?P<c>[^/]+)/t=(?P<t>\d+)",
    files=False,
    recursive=True,
)

In [None]:
%%time
# Segmentation masks, indexed by fov (v) and t with axis order "tv".
segmentation_masks = readers.ZarrSlicer(
    output_dir / "segmentation_masks",
    r"fov=(?P<v>\d+)/t=(?P<t>\d+)",
    files=False,
    recursive=True,
    axis_order="tv",
)

In [None]:
# NOTE(review): indexing below appears to be crops[t, v, channel, roi] and
# segmentation_masks[t, v, roi] — confirm against ZarrSlicer's axis order.
display_image(
    image.unstack(image.crop_to_mask(crops[:, 8, "YFP-EM", 33].squeeze())), scale=1
)

In [None]:
crops[:, 8, "YFP-EM", 3].squeeze().shape

In [None]:
crops[:, 8, :, 7].squeeze().swapaxes(0, 1).shape

In [None]:
image.unstack_multichannel(crops[:, 8, :, 7].squeeze().swapaxes(0, 1)).shape

In [None]:
display_image(
    image.unstack_multichannel(crops[:, 8, :, 7].squeeze().swapaxes(0, 1)),
    scale=1,
)

In [None]:
# (Duplicate of the previous cell.)
display_image(
    image.unstack_multichannel(crops[:, 8, :, 7].squeeze().swapaxes(0, 1)),
    scale=1,
)

In [None]:
segmentation_masks[:, 8, 7].squeeze().shape

In [None]:
# (Intentionally empty cell: it previously held an incomplete `np.`
# expression, which is a SyntaxError.)

In [None]:
# Color the label stack for v=8, index 7 across all t. nan_to_num maps any NaN
# to 0 (label2rgb's default background label) before the uint16 cast.
x = skimage.color.label2rgb(
    np.nan_to_num(segmentation_masks[:, 8, 7].squeeze()).astype(np.uint16)
)

In [None]:
x.shape

In [None]:
np.swapaxes(x, -3, -2).shape

In [None]:
y = np.swapaxes(x, -3, -2)

In [None]:
np.moveaxis(x, 0, 1).shape

In [None]:
y.shape

In [None]:
y.reshape(-1, *y.shape[2:]).shape
In [None]:
def unstack2(ary, axis=1):
    """Tile a stack of arrays by merging the leading (stack) axis into another.

    ``ary[0], ary[1], ...`` are laid end to end along axis ``axis - 1`` of
    each element. For ``(n, H, W, ...)`` input, ``axis=1`` stacks the images
    vertically and ``axis=2`` places them side by side.

    Parameters
    ----------
    ary : array-like
        Array whose first axis indexes the items to tile.
    axis : int, optional
        Axis of ``ary`` (the stack axis counts as 0) into which the stack axis
        is merged; must satisfy ``1 <= axis < ary.ndim``.

    Returns
    -------
    numpy.ndarray
        Array with one fewer dimension; the merged axis has length
        ``ary.shape[0] * ary.shape[axis]``.

    Raises
    ------
    ValueError
        If ``axis`` is out of range. Such values previously either failed
        inside ``reshape`` or silently returned an un-tiled array.
    """
    ndim = np.ndim(ary)
    if not 1 <= axis < ndim:
        raise ValueError(f"axis must satisfy 1 <= axis < {ndim}, got {axis}")
    # Put the stack axis directly before the target axis so the two are
    # adjacent; a C-order reshape then concatenates the items along it.
    moved = np.moveaxis(ary, 0, axis - 1)
    # Compute the merged length explicitly (instead of -1) so zero-size
    # stacks reshape unambiguously.
    merged = moved.shape[axis - 1] * moved.shape[axis]
    return moved.reshape(*moved.shape[: axis - 1], merged, *moved.shape[axis + 1 :])


display_image(unstack2(x, axis=2), scale=1)

In [None]:
skimage.color.label2rgb(
    np.nan_to_num(segmentation_masks[:, 8, 7].squeeze()).astype(np.uint16)
).rollaxis(1).shape

In [None]:
# Shape checks for the different ways of tiling label/crop stacks with
# image.unstack / image.unstack_multichannel (v=8, indices 7 and 417).
skimage.color.label2rgb(
    np.nan_to_num(segmentation_masks[:, 8, 7].squeeze()).astype(np.uint16)
).shape

In [None]:
image.unstack(
    skimage.color.label2rgb(
        np.nan_to_num(segmentation_masks[:, 8, 7].squeeze()).astype(np.uint16)
    ),
    axis=2,
).shape

In [None]:
image.unstack_multichannel(
    skimage.color.label2rgb(
        np.nan_to_num(segmentation_masks[:, 8, 7].squeeze()).astype(np.uint16)
    )
).shape

In [None]:
crops[:, 8, :, 417].squeeze().shape

In [None]:
image.unstack(
    crops[:, 8, :, 417].squeeze(),
    axis=3,
).shape

In [None]:
image.unstack(
    image.unstack(
        crops[:, 8, :, 417].squeeze(),
        axis=3,
    )
).shape

In [None]:
display_image(
    image.unstack(
        image.unstack(
            crops[:, 8, :, 417].squeeze(),
            axis=3,
        )
    ),
    scale=1,
    downsample=4,
)

In [None]:
crops[:, 8, :, 417].squeeze().swapaxes(0, 1).shape

In [None]:
image.unstack_multichannel(crops[:, 8, :, 417].squeeze().swapaxes(0, 1), axis=2).shape

In [None]:
image.unstack_multichannel(crops[:, 8, :, 417].squeeze().swapaxes(0, 1), axis=1).shape

In [None]:
image.unstack(crops[:, 8, :, 417].squeeze(), axis=1).shape

In [None]:
display_image(
    image.unstack_multichannel(
        crops[10:, 8, :, 417].squeeze().swapaxes(0, 1),
    ),
    downsample=4,
)

In [None]:
display_image(
    image.unstack(
        skimage.color.label2rgb(
            np.nan_to_num(segmentation_masks[:, 8, 417].squeeze()).astype(np.uint16)
        ),
        axis=2,
    ),
    scale=1,
)

In [None]:
# (Duplicate of the previous cell.)
display_image(
    image.unstack(
        skimage.color.label2rgb(
            np.nan_to_num(segmentation_masks[:, 8, 417].squeeze()).astype(np.uint16)
        ),
        axis=2,
    ),
    scale=1,
)

In [None]:
%%time
# Load all measurements from the hive-partitioned parquet dataset and sort
# rows by fov_num, roi, t, channel, label.
dataset = ds.dataset(output_dir / "measurements", format="parquet", partitioning="hive")
# df = dataset.to_table(filter=ds.field("position") == 14).to_pandas()
# df = dataset.to_table().to_pandas().sort_values("t")
df = (
    pl.scan_pyarrow_dataset(dataset)
    .sort(["fov_num", "roi", "t", "channel", "label"])
    .collect()
)

In [None]:
df

In [None]:
# Save a consolidated single-file copy next to the dataset directory.
df.write_parquet(output_dir / "measurements.parquet")

In [None]:
# Mean intensity vs t, one line per channel, for fov 8 / roi 211 / label 1.
# (Previously `.set_index("t").plot(x="t", hue="channel")`, which fails: "t"
# is no longer a column after set_index, and pandas' plot has no `hue`.
# Uses the same hvplot call as the neighbouring cells.)
df.filter(
    pl.col("fov_num") == 8,
    pl.col("roi") == 211,
    pl.col("label") == 1,
    pl.col("t").is_between(15, 200),
).to_pandas().hvplot("t", "intensity_mean", by="channel")

In [None]:
# Mean intensity vs t, one line per channel, for fov 8 / roi 212 / label 1.
df.filter(
    pl.col("fov_num") == 8,
    pl.col("roi") == 212,
    pl.col("label") == 1,
    pl.col("t").is_between(15, 200),
).to_pandas().hvplot("t", "intensity_mean", by="channel")

In [None]:
# Same for roi 209.
df.filter(
    pl.col("fov_num") == 8,
    pl.col("roi") == 209,
    pl.col("label") == 1,
    pl.col("t").is_between(15, 200),
).to_pandas().hvplot("t", "intensity_mean", by="channel")

In [None]:
display_image(
    image.unstack(
        skimage.color.label2rgb(
            np.nan_to_num(segmentation_masks[:, 8, 209].squeeze()).astype(np.uint16)
        ),
        axis=2,
    ),
    scale=1,
)

In [None]:
display_image(
    image.unstack_multichannel(
        crops[10:, 8, :, 209].squeeze().swapaxes(0, 1),
    ),
    downsample=1,
)

In [None]:
# Crops for roi 209 from the 11th timepoint on, with the first two axes
# swapped (presumably channel-first afterwards — confirm).
x = crops[10:, 8, :, 209].squeeze().swapaxes(0, 1)

In [None]:
x.shape

In [None]:
y = image.power_law_composite(x)

In [None]:
np.nanmax(x[0])

In [None]:
display_image(image.unstack(x[2], axis=2), scale=0.99)

In [None]:
display_image(image.unstack(y, axis=2), scale=0.99)

In [None]:
# NaN-ignoring sum over the first axis of x.
np.nansum(x, axis=0).shape

In [None]:
display_image(image.unstack(np.nansum(x, axis=0), axis=2), scale=0.99)

In [None]:
df.filter(
    pl.col("fov_num") == 8, pl.col("roi") == 403, pl.col("channel") == "YFP-EM"
).to_pandas()

# Export

In [None]:
%%time
# Reload all measurements (same query as the analysis section) and sort rows.
dataset = ds.dataset(output_dir / "measurements", format="parquet", partitioning="hive")
# df = dataset.to_table(filter=ds.field("position") == 14).to_pandas()
# df = dataset.to_table().to_pandas().sort_values("t")
df = (
    pl.scan_pyarrow_dataset(dataset)
    .sort(["fov_num", "roi", "t", "channel", "label"])
    .collect()
)

In [None]:
# Frequency of each label value, most common first.
df["label"].value_counts(sort=True)

In [None]:
# Keep t in [15, 200]; drop (fov_num, roi) groups whose largest label exceeds
# 12; then renumber labels: rows with odd trench_set keep their label, others
# get (max label at that fov_num/roi/t) - label + 1 — presumably so label 1
# refers to the same trench end everywhere; confirm.
# NOTE(review): uses df_rois, which is only loaded in a later cell of this
# section — run that cell first.
df2 = (
    df.filter(pl.col("t").is_between(15, 200))
    .filter(pl.col("label").max().over(["fov_num", "roi"]) <= 12)
    .with_columns(
        reversed_label=(
            pl.col("label").max().over(["fov_num", "roi", "t"]) - pl.col("label") + 1
        )
    )
    .join(df_rois.select("fov_num", "roi", "trench_set"), on=["fov_num", "roi"])
    .with_columns(
        label=pl.when(pl.col("trench_set") % 2 == 1)
        .then(pl.col("label"))
        .otherwise(pl.col("reversed_label"))
    )
    .select(pl.all().exclude("reversed_label"))
)

In [None]:
df2.filter(
    pl.col("fov_num") == 8,
    pl.col("roi") == 328,
    pl.col("label") == 1,
).to_pandas().hvplot("t", "intensity_mean", by="channel")

In [None]:
display_image(
    image.unstack(
        skimage.color.label2rgb(
            np.nan_to_num(segmentation_masks[15:200, 8, 328].squeeze()).astype(
                np.uint16
            )
        ),
        axis=2,
    ),
    scale=1,
)

In [None]:
display_image(
    image.unstack_multichannel(
        crops[10:, 8, :, 328].squeeze().swapaxes(0, 1),
    ),
    downsample=1,
)

In [None]:
df2["roi"].unique()

In [None]:
%%time
# Initial ROI table (hive-partitioned parquet); provides trench_set for df2.
df_rois = pl.scan_pyarrow_dataset(
    ds.dataset(output_dir / "initial_rois", format="parquet", partitioning="hive")
).collect()

In [None]:
df_rois

In [None]:
# Scratch: experimenting with dropping unnamed index levels in pandas.
# NOTE(review): this rebinds `df`, discarding the measurement table above.
df = pd.concat(
    {
        7: pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])),
        8: pd.DataFrame(dict(a=[3, 3, 3], b=[4, 5, 6])),
    }
)

In [None]:
df.index.names

In [None]:
# Positions of index levels that have no name.
# NOTE(review): for this df both levels are unnamed, so every level is
# selected and droplevel raises ValueError (at least one level must remain).
levels_to_drop = [idx for idx, name in enumerate(df.index.names) if name is None]
df.droplevel(level=levels_to_drop, axis=0)

In [None]:
levels_to_drop