# Imports

In [None]:
import glob
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import scipy.signal
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

IDX = pd.IndexSlice

In [None]:
from dask.diagnostics import ProgressBar

ProgressBar().register()

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.io.metadata as imd
from paulssonlab.image_analysis import *
from paulssonlab.image_analysis.ui import display_image

In [None]:
hv.extension("bokeh")
# hv.extension("matplotlib")

# Config

In [None]:
# Input dataset selection. Exactly one `filename` assignment is active; the
# commented-out assignments are other datasets kept for quick switching.
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
# filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"
# filename = workflow.SplitFilename(
#     sorted(
#         glob.glob(
#             "/home/jqs1/scratch/jqs1/microscopy/230619/230619_NAO745_repressilators_split.nd2*"
#         )
#     )
# )
# Active dataset: every file matching the ".nd2.split.*" glob, sorted by name
# and wrapped in workflow.SplitFilename.
filename = workflow.SplitFilename(
    sorted(
        glob.glob(
            # "/home/jqs1/scratch/jqs1/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
            "/home/jqs1/scratch/microscopy/230830/230830_repressilators.nd2.split.*"
        )
    )
)
# filename = "/home/jqs1/scratch/jqs1/microscopy/231101/231101_FP_calibration.nd2"
# NOTE(review): Path(filename) presumably relies on SplitFilename being
# path-like; the output-path cell further down uses Path(str(filename))
# instead -- confirm which form is required.
fish_filename = Path(filename).parent / "FISH/real_run"

In [None]:
# Open a reader for the selected dataset. `max_t` (the size of the "t" axis)
# is later read as a module-level global by process_fov.
nd2 = workflow.get_nd2_reader(filename)
max_t = nd2.sizes["t"]

In [None]:
nd2.sizes

In [None]:
nd2.metadata["channels"]

In [None]:
# Hex color per channel name for the main acquisition. The commented-out
# entries are an alternative palette for the same three fluorescence channels.
colors = {
    "BF": "#ffffff",
    # "CFP-EM": "#6fb2e4",
    # "YFP-EM": "#eee461",
    # "RFP-EM": "#c66526",
    "CFP-EM": "#648FFF",
    "YFP-EM": "#FFB000",
    "RFP-EM": "#DC267F",
}

# Hex color per channel name for the FISH acquisition.
fish_colors = {
    "BF": "#ffffff",
    "GFP": "#f44336",
    "Cy5": "#03a9f4",
    # "Cy7": "#ffeb3b"
    "Cy7": "#8bc34a",
}

In [None]:
# Dask cluster backed by SLURM jobs: each job requests 1 core / 2GB, runs a
# single worker process, is submitted to the "transfer" queue and is limited
# to a 2 hour walltime. No workers are requested here; scaling happens in the
# cells below.
cluster = SLURMCluster(
    queue="transfer",
    walltime="02:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
# Client used by all client.submit / client.compute calls in this notebook.
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(1)

In [None]:
cluster.adapt(maximum=20)

# Inventory

In [None]:
# Run the glob on a cluster worker instead of in the notebook process and
# block until the list of matching .nd2 paths comes back.
# NOTE(review): presumably research.files is only reachable from the worker
# nodes -- confirm.
filenames = client.submit(
    glob.glob, "/home/jqs1/research.files/Personal_Folders/Noah/23*/*.nd2"
).result()

In [None]:
filenames

In [None]:
%%time
# Collect (sizes, channels) for every discovered file, keyed by path.
# get_nd2_metadata is defined in a later cell, so that cell must be run first.
# Calling .result() inside the loop waits for each task before submitting the
# next one, so files are processed one at a time.
# NOTE(review): the loop variable rebinds the module-level `filename` chosen
# in the Config section; re-run that cell before using `filename` again.
md = {}
for filename in tqdm(filenames):
    sizes, channels = client.submit(get_nd2_metadata, filename).result()
    md[filename] = (sizes, channels)

In [None]:
selected_filenames = [
    "/home/jqs1/research.files/Personal_Folders/Noah/230127/initial_growth.nd2",
    "/home/jqs1/research.files/Personal_Folders/Noah/230203/230203_circuits.nd2",
    "/home/jqs1/research.files/Personal_Folders/Noah/230131/230131_growth_5min.nd2",
    "/home/jqs1/research.files/Personal_Folders/Noah/230131/230131_growth.nd2",
    "/home/jqs1/research.files/Personal_Folders/Noah/231101/231101_FP_calibration.nd2",
    "/home/jqs1/research.files/Personal_Folders/Noah/230915/test.nd2",
    "/home/jqs1/research.files/Personal_Folders/Noah/230125/overnight_growth.nd2",
]

In [None]:
{k: v[0] for k, v in md.items() if k in selected_filenames}

In [None]:
{k: v[1] for k, v in md.items() if "Phase-Fluor" in v[1]}

In [None]:
md = imd.parse_nd2_file_metadata(
    "/home/jqs1/scratch/microscopy/231101/231101_FP_calibration.nd2"
)

In [None]:
md.keys()

In [None]:
md["image_calibration"]["SLxCalibration"]["sObjective"]

In [None]:
%%time
metadata = {}
for filename in tqdm(selected_filenames):
    metadata[filename] = client.submit(imd.parse_nd2_file_metadata, filename).result()

In [None]:
metadata.keys()

In [None]:
{k: v["image_calibration"]["SLxCalibration"]["sObjective"] for k, v in metadata.items()}

In [None]:
metadata[""]

In [None]:
client.submit(glob.glob, "")

In [None]:
filename = ""

In [None]:
def nd2_to_zarr(fovs=None):
    """Unfinished stub: fetch a single ND2 frame and discard it.

    NOTE(review): `fovs` is never used and the fetched frame is not returned
    or written anywhere. `filename`, `position`, `channel`, `t`, `dark` and
    `flat` are not parameters, so they are looked up as globals; `channel`,
    `t`, `dark` and `flat` are not assigned at module level anywhere in the
    visible notebook, so calling this as-is would raise NameError unless they
    are defined elsewhere.
    """
    workflow.get_nd2_frame(
        filename, position=position, channel=channel, t=t, dark=dark, flat=flat
    )

In [None]:
def get_nd2_metadata(filename):
    """Return ``(sizes, channels)`` for the ND2 file at *filename*.

    Opens a reader via ``workflow.get_nd2_reader`` and returns its ``sizes``
    attribute together with the ``"channels"`` entry of its metadata.
    """
    reader = workflow.get_nd2_reader(filename)
    axis_sizes = reader.sizes
    channel_names = reader.metadata["channels"]
    return axis_sizes, channel_names

In [None]:
# Fetch (sizes, channels) for `filename` on a worker and wait for the answer.
# A distributed Future exposes .result(), not .gather() (gather is a Client
# method), so .result() is used here like in the other cells of this notebook.
sizes, channels = client.submit(get_nd2_metadata, filename).result()

In [None]:
# Hand-picked ND2 files from several people's folders, listed one per line
# relative to `filename_prefix`. Several paths contain spaces, so the block is
# split on newlines only.
filename_prefix = "/home/jqs1/research.files/Personal_Folders/"
selected_filenames = """Noah/231101/231101_FP_calibration.nd2
Noah/230131/230131_growth_5min.nd2
Noah/230125/overnight_growth.nd2
Noah/230203/230203_circuits.nd2
Daniel/FISH_Paper_Data/lDE20_Data/2023-01-14_lDE20_Run_9/Experiment.nd2
Daniel/FISH_Paper_Data/lDE15_Data/2021-10-21_lDE15_Final_1/experiment.nd2
Daniel/FISH_Paper_Data/Isolates/2023-02-11_lpxK_LpxC_AB/Experiment.nd2
Carlos/Ti5/LCS3_run1/Experiment.nd2
Carlos/Ti5/LCS3_run2/Experiment.nd2
Carlos/Ti5/08072023_lcs2/GlycerolArabinoseMedia.nd2
Carlos/Ti5/08072023_lcs2/GlycerolMedia3hr.nd2
Carlos/Ti5/2019_02_07/AndersonPromoters_Phase.nd2
Luis/Imaging_Data/Bsubtilis_DegronRapamycin/2022-05-10_AF337-AF339_Ti6/tlapse-1.nd2
Luis/Imaging_Data/Bsubtilis_BarcodesTesting/2023-10-31_lLAG2_AF555-AllCycles/Experiment.nd2
Daniel/FISH_Paper_Data/lDE26_Data/2023-03-11_lDE26_Run_1/Experiment.nd2
Luis/Imaging_Data/Ecoli_Libraries/2022-08-19_MM-Ti5-lDE24/2022-08-19_MM-Ti5-lDE24.nd2
Raquel/Results/P1 grant/Gilmore/2023_07_12 RF320 Rifam/RF320.nd2
Raquel/Results/P1 grant/Gilmore/2023_07_20 RF320 Rifam/64_128_256_512_.nd2
Raquel/Results/P1 grant/Gilmore/2021_04_20 Enterococcus RF235/RF235.nd2
Raquel/Results/P1 grant/Gilmore/2023_08_01 RF320 Dapto/Control_128_256_512.nd2
Raquel/Results/P1 grant/Gilmore/2023_08_11 Tnlibrary Dapto dyes/ND2_growth/growth.nd2
Raquel/Results/P1 grant/Gilmore/2023_08_24 EF daptomycin/growth.nd2
Raquel/Results/P1 grant/Gilmore/2023_10_4 RF320/ND2_growth/Experiment_growth.nd2
Raquel/Results/P1 grant/Gilmore/2023_12_07 RF157 MM/ND2/MHCA_Dapto.nd2
Raquel/Results/P1 grant/Gilmore/2020_09_28 Enterococcus mcherry gfp/Enterococcus mcherry gfp.nd2
Raquel/Results/PhoPQ/Mother machine experiments/2021_02_14 RF230 N10 Sin N50 Sin Snake no treatment (importante)/N10 Sin N50 Sin Snake.nd2
Raquel/Results/P1 grant/Hooper/2023_03_09 RF322/ND2/Experiment.nd2
Raquel/Results/P1 grant/Hooper/2023_03_30 RF322/ND2/Experiment.nd2
Raquel/Results/P1 grant/Hooper/2023_09_06 RF322 pyocyanin/ND2/Control_Pyo_Naf_Pyo+Naf003.nd2
Mengyu/microscopy/2022-04-15 Ti6/growth-lysogen-lambda2903_A01.nd2
Mengyu/microscopy/2022-03-21 Ti6/1 growth/growth_start_with_inducer.nd2
Mengyu/microscopy/2022-03-25 Ti6/1 growth/growth001.nd2
Mengyu/microscopy/2022-12-26 Jurkat cell growth/jurkat-growth.nd2
Mengyu/microscopy/2023-05-12 Jurkat cell growth (w 10ng per ml IL-7)/growth.nd2""".split(
    "\n"
)
# Turn the relative entries into absolute paths. The comprehension variable is
# local to the comprehension, so the module-level `filename` is not rebound
# by this line.
selected_filenames = [f"{filename_prefix}{filename}" for filename in selected_filenames]

In [None]:
selected_filenames

In [None]:
%%time
# Parse the full ND2 file metadata for each selected file on a worker, keyed
# by absolute path. Each .result() call blocks, so files are parsed one after
# another.
# NOTE(review): this loop rebinds the module-level `filename`.
metadata = {}
for filename in tqdm(selected_filenames):
    metadata[filename] = client.submit(imd.parse_nd2_file_metadata, filename).result()

In [None]:
%%time
# Same file list as above, but collecting the lighter (sizes, channels) tuple
# returned by get_nd2_metadata instead of the fully parsed file metadata.
# NOTE(review): this loop rebinds the module-level `filename`.
metadata2 = {}
for filename in tqdm(selected_filenames):
    metadata2[filename] = client.submit(get_nd2_metadata, filename).result()

In [None]:
for filename in selected_filenames:
    print(metadata2[filename][0])

In [None]:
metadata[
    "/home/jqs1/research.files/Personal_Folders/Noah/230125/overnight_growth.nd2"
].keys()

In [None]:
x = metadata[
    "/home/jqs1/research.files/Personal_Folders/Noah/230125/overnight_growth.nd2"
]["image_metadata"]

In [None]:
x

In [None]:
{k: v["image_calibration"]["SLxCalibration"]["sObjective"] for k, v in metadata.items()}

# ND2 to Zarr conversion

In [None]:
# Output path: the fixed for_janelia directory plus the final path component
# of str(filename).
# NOTE(review): `filename` is rebound by the Inventory loops above, so this
# depends on which cell last assigned it.
output_filename = (
    f"/home/jqs1/group/jqs1/microscopy/for_janelia/{Path(str(filename)).name}"
)

In [None]:
output_filename

In [None]:
%%time
# Convert part of the opened ND2 to zarr: only FOV index 11 is selected, with
# all timepoints (slice(None)).
# NOTE(review): presumably `dataset_axes` yields one dataset per
# (fov, channel) pair (the read-back cell below indexes
# "fov=11/channel=CFP-EM") and an empty `file_axes` keeps everything in one
# output -- confirm against readers.convert_nd2_to_array.
readers.convert_nd2_to_array(
    nd2,
    output_filename,
    file_axes=[],
    dataset_axes=["fov", "channel"],
    slices=dict(fov=[11], t=slice(None)),
    format="zarr",
)

In [None]:
# Open a previously written HDF5 export (FOV 22 of the .split.aa file) for
# comparison with the zarr output.
# NOTE(review): the handle is never closed in the visible cells, and this
# rebinds `x`, which the Inventory section used for image metadata.
x = h5py.File(
    "/home/jqs1/group/jqs1/microscopy/for_janelia/230830_repressilators.nd2.split.aa/fov=22.hdf5"
)

In [None]:
y = zarr.convenience.open(output_filename + ".zarr")

In [None]:
display_image(y["fov=11/channel=CFP-EM"][0, 0], scale=0.99)

In [None]:
display_image(x["channel=CFP-EM"][0, 0], scale=0.99)

# Handler

In [None]:
def crop_rois(img, rois):
    """Return ``{roi_index: crop}`` for every ROI yielded for *img*.

    The ``(index, crop)`` pairs come from ``geometry.iter_roi_crops``.
    """
    # TODO: during testing the iterator was capped with it.islice(..., 100) so
    # that each dask task stayed short; the cap is currently disabled and all
    # ROIs are cropped.
    return {
        roi_idx: roi_crop for roi_idx, roi_crop in geometry.iter_roi_crops(img, rois)
    }


def segment_crops(crops):
    """Segment every crop, returning ``{roi_index: mask}`` with the same keys.

    Each crop is passed to ``segmentation.watershed.segment``.
    """
    return {
        roi_idx: segmentation.watershed.segment(roi_crop)
        for roi_idx, roi_crop in crops.items()
    }


# TODO: this is really boilerplatey, also we want finer task granularity than doing a whole FOV at once
# def measure_crops(label_images, intensity_images):
#     keys = label_images.keys() & intensity_images.keys()
#     return {k: measure_crop(label_images[k], intensity_images[k]) for k in keys}
def measure_crops(intensity_images):
    """Apply ``measure_crop`` to every crop, keeping the input's keys."""
    return {
        roi_idx: measure_crop(crop_image)
        for roi_idx, crop_image in intensity_images.items()
    }


# def measure_crop(label_image, intensity_image):
# return pd.DataFrame(
#     skimage.measure.regionprops_table(
#         label_image,
#         intensity_image,
#         properties=(
#             "label",
#             "intensity_mean",
#         ),
#     )
# ).set_index("label")
def measure_crop(intensity_image):
    """Summarize one crop as a Series of scalar observables.

    Returns a ``pd.Series`` named ``"value"`` whose index is named
    ``"observable"``. Only the 90th percentile (``"p90"``) is currently
    enabled; the other observables below are kept disabled.
    """
    # Middle column of the crop. Only the disabled centerline_* observables
    # use it, but it is still computed, so the input must be at least 2-D.
    centerline = intensity_image[:, intensity_image.shape[1] // 2]
    observables = {
        # "p1": np.percentile(intensity_image, 1),
        # "p50": np.median(intensity_image),
        "p90": np.percentile(intensity_image, 90),
        # "p99": np.percentile(intensity_image, 99),
        # "mean": np.mean(intensity_image),
        # "centerline_mean": np.mean(centerline),
        # "centerline_median": np.median(centerline),
    }
    summary = pd.Series(observables, name="value")
    return summary.rename_axis(index="observable")


def measure_mask_crops(label_images):
    """Apply ``measure_mask_crop`` to every label image, keeping the keys."""
    tables = {}
    for roi_idx, label_image in label_images.items():
        tables[roi_idx] = measure_mask_crop(label_image)
    return tables


def measure_mask_crop(label_image):
    """Return shape properties of each labeled region as a DataFrame.

    The table is produced by ``skimage.measure.regionprops_table`` for the
    properties listed below and is indexed by region ``label``.
    """
    requested_properties = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    table = skimage.measure.regionprops_table(
        label_image, properties=requested_properties
    )
    return pd.DataFrame(table).set_index("label")


def write_parquet(output_dir, measurements, position, t):
    """Write one (position, t) batch of measurements to a parquet dataset.

    *measurements* maps channel -> {roi_idx -> pandas object}. The per-ROI
    objects are concatenated under a ``roi_idx`` level, the per-channel
    results under a ``channel`` level, and the index is then flattened into
    columns. ``position`` and ``t`` are added as uint16 columns and used as
    partition columns of the dataset at ``<output_dir>/measurements``;
    existing data in matching partitions is replaced ("delete_matching").
    """
    per_channel = {}
    for channel, channel_df in measurements.items():
        per_channel[channel] = pd.concat(channel_df, names=["roi_idx"])
    df = pd.concat(per_channel, names=["channel"]).reset_index()
    df["position"] = np.array(position).astype(np.uint16)
    df["t"] = np.array(t).astype(np.uint16)
    table = pa.Table.from_pandas(df, preserve_index=False)
    dataset_root = Path(output_dir) / "measurements"
    pq.write_to_dataset(
        table,
        dataset_root,
        partition_cols=["position", "t"],
        existing_data_behavior="delete_matching",
    )


def stack_dict(d, size=None, cval=0):
    """Turn a sparse ``{index: array}`` dict into a dense list of arrays.

    Args:
        d: Non-empty mapping from non-negative integer index to array. All
            arrays are assumed to share the shape of the first value.
        size: Length of the returned list. Defaults to ``max(d) + 1``.
        cval: Value used to fill the placeholder array for missing indices.

    Returns:
        A list of length *size* whose entry ``i`` is ``d[i]`` if present and
        otherwise a shared placeholder array filled with *cval*. Entries of
        *d* with an index ``>= size`` are dropped.

    Raises:
        ValueError: If *d* is empty, since the element shape cannot be
            inferred.
    """
    if not d:
        raise ValueError(
            "stack_dict requires at least one array to infer the element shape"
        )
    if size is None:
        size = max(d) + 1
    template = next(iter(d.values()))
    # Give the placeholder the narrowest dtype that holds both the data and
    # cval. Using cval's own default dtype (int64 for cval=0) would silently
    # upcast a stack of e.g. uint16 crops as soon as one index is missing.
    fill_dtype = np.result_type(template.dtype, np.min_scalar_type(cval))
    null = np.full(template.shape, cval, dtype=fill_dtype)
    return [d.get(idx, null) for idx in range(size)]


def _pad(ary, shape, cval=0):
    """Pad *ary* at the end of each axis with *cval* up to *shape*.

    Axes that are already at least as long as requested are left untouched
    (nothing is cropped).
    """
    pad_widths = []
    for goal, current in zip(shape, ary.shape):
        missing = max(goal - current, 0)
        pad_widths.append((0, missing))
    return np.pad(ary, pad_widths, constant_values=cval)


def write_zarr(filename, crops, t, max_t, channels, cval=0, dtype=np.uint16):
    """Write the crops of one timepoint into a zarr array on disk.

    The array is laid out as ``(roi, t, channel, *crop_shape)``. It is created
    on the first call for a given *filename*, sized from the crops of
    ``channels[0]`` in that call, and reopened on later calls.

    Args:
        filename: Directory of the zarr store (``str`` or ``Path``).
        crops: Mapping ``channel -> {roi_idx -> array}``.
        t: Timepoint index to write.
        max_t: Length of the time axis; only used when creating the array.
        channels: Channel names, in the order used for the channel axis.
        cval: Fill value of the array, also used to pad undersized crops and
            to fill in missing ROIs.
        dtype: dtype of the array; crops are cast to it before writing.
    """
    # Accept plain strings as well as Path objects for the existence check.
    filename = Path(filename)
    store = zarr.DirectoryStore(filename)  # DirectoryStoreV3(filename)
    if not filename.exists():
        reference_crops = crops[channels[0]]
        num_rois = max(reference_crops.keys()) + 1
        num_channels = len(channels)
        # Largest extent along each crop axis; smaller crops are padded to it.
        max_shape = np.max([crop.shape for crop in reference_crops.values()], axis=0)
        shape = (num_rois, max_t, num_channels, *max_shape)
        # chunks = (5, 1, num_channels, None, None)
        # One chunk = 20 ROIs x 1 timepoint x all channels x the whole crop.
        chunks = (20, 1, num_channels, None, None)
        ary = zarr.open_array(
            store,
            mode="a",
            zarr_version=2,
            shape=shape,
            chunks=chunks,
            fill_value=cval,
            dtype=dtype,
        )
    else:
        ary = zarr.open_array(store, mode="a", zarr_version=2)
        max_shape = ary.shape[-2:]
    # Build a (channel, roi, ...) stack, then swap to (roi, channel, ...) to
    # match the on-disk layout. cval is forwarded so padding and missing ROIs
    # use the same value as the array's fill_value.
    stack = np.array(
        [
            stack_dict(
                {
                    idx: _pad(crop.astype(dtype), max_shape, cval=cval)
                    for idx, crop in crops[channel].items()
                },
                size=ary.shape[0],
                cval=cval,
            )
            for channel in channels
        ]
    ).swapaxes(0, 1)
    ary[:, t, ...] = stack

In [None]:
# output_dir = Path(filename).parent / "for_ranit_3fovs_uint16_v2"
# output_dir.mkdir(exist_ok=True)

In [None]:
# #segmentation_channels = ["RFP-PENTA"]
# segmentation_channels = ["RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"]
# trench_detection_channels = segmentation_channels # channel for trench detection, almost always same as segmentation_channel
# # measure_channels = ["RFP-PENTA", "YFP-DUAL"]
# measure_channels = ["Phase-Fluor", "RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"]
# fish_channels = ["RFP-PENTA", "Cy5-PENTA", "Cy7"]

In [None]:
# Channel roles for process_fov: the first segmentation channel drives trench
# detection and drift estimation, measurement_channels are summarized to
# parquet (currently none), and crop_channels are written to zarr (currently
# every channel of the ND2).
segmentation_channels = ["RFP-EM"]
# measurement_channels = ["CFP-EM", "YFP-EM", "RFP-EM"]
measurement_channels = []
crop_channels = nd2.metadata["channels"]
# channel_colors = [colors[channel] for channel in measurement_channels]
# Passed to trench_detection.find_trenches.
width_to_pitch_ratio = 1.4 / 3.5  # for debugging: 2.2 / 3.5
# Radial distortion coefficient passed to image.correct_radial_distortion.
k1 = 8.5e-10
# Offset subtracted from the geometric image center to get the distortion
# center.
center_y = -800
center_x = 0
# NOTE(review): the shape is given as (y, x) but the offset is ordered
# [center_x, center_y]; confirm the axis order center_from_shape returns.
center = image.center_from_shape((nd2.sizes["y"], nd2.sizes["x"])) - np.array(
    [center_x, center_y]
)

In [None]:
def get_frame_func(
    filename, position, channel, t, k1=k1, center=center, dark=None, flat=None
):
    """Load one ND2 frame and return it distortion-corrected as uint16.

    The frame is read with ``workflow.get_nd2_frame`` (forwarding ``dark`` and
    ``flat``), converted to an ndarray, passed through
    ``image.correct_radial_distortion`` with coefficient *k1* and distortion
    center *center*, and cast to ``np.uint16``. The defaults for *k1* and
    *center* are captured from the module-level values when this function is
    defined.
    """
    raw_frame = workflow.get_nd2_frame(
        filename, position=position, channel=channel, t=t, dark=dark, flat=flat
    )
    corrected = image.correct_radial_distortion(
        np.asarray(raw_frame), k1=k1, input_center=center
    )
    return corrected.astype(np.uint16)

In [None]:
%%time
# Reference frame: position 11, first segmentation channel, t=0. It is reused
# below for trench detection.
img0 = get_frame_func(filename, 11, segmentation_channels[0], 0)
# TODO: replace with calculation that doesn't require processing an image
image_limits = geometry.get_image_limits(img0.shape)

In [None]:
display_image(img0, scale=0.99, downsample=4)

In [None]:
["Phase-Fluor", "RFP-EM", "GFP-EM", "YFP-EM", "CFP-EM"]

In [None]:
display_image(get_frame_func(filename, 11, "GFP-EM", 0), scale=0.99, downsample=4)

In [None]:
display_image(get_frame_func(filename, 11, "YFP-EM", 0), scale=0.99, downsample=4)

In [None]:
display_image(get_frame_func(filename, 11, "CFP-EM", 20), scale=0.99, downsample=4)

In [None]:
%%time
# Load t=0 of position 50 for every channel except the first one listed in
# the ND2 metadata.
# NOTE(review): presumably the first channel is the non-fluorescence one --
# confirm against nd2.metadata["channels"].
imgs = [
    get_frame_func(filename, 50, channel, 0)
    for channel in tqdm(nd2.metadata["channels"][1:])
]

In [None]:
%%time


def combine_channels_for_segmentation(imgs, same_dtype=True):
    """Average several channel images into one image for segmentation.

    Every image after the first is histogram-matched to the first with
    ``skimage.exposure.match_histograms``. Each image is then divided by the
    number of images and the results are summed. With ``same_dtype=True`` the
    result is cast back to the dtype of the first image; otherwise the
    floating-point result is returned as is.
    """
    reference = imgs[0]
    original_dtype = reference.dtype
    matched = [reference]
    for img in imgs[1:]:
        matched.append(skimage.exposure.match_histograms(img, reference))
    combined = np.sum(np.stack(matched) / len(matched), axis=0)
    if not same_dtype:
        return combined
    return combined.astype(original_dtype)


combined = combine_channels_for_segmentation(imgs)

In [None]:
combined.max()

In [None]:
display_image(combined, scale=0.99, downsample=1)

In [None]:
%%time
# Detect trenches once on the reference frame. `diag` is filled with
# intermediate diagnostics that the following cells display. The detected
# angle and pitch are reused as fixed inputs to find_trenches for every FOV
# in the processing loop below.
diag = util.tree()
rois, info = trench_detection.find_trenches(
    img0,
    width_to_pitch_ratio=width_to_pitch_ratio,
    join_info=False,
    diagnostics=diag,
    # pitch=16.173741362290226
)
angle = info["angle"]
pitch = info["pitch"]
# Displayed as the cell output.
(angle, pitch)

In [None]:
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
diag["bboxes"]

In [None]:
def process_fov(
    get_frame_func,
    position,
    ts,
    output_dir,
    segmentation_channels,
    measurement_channels,
    crop_channels,
    image_limits,
    write_full_frames=True,
    find_trenches_kwargs=None,
    dark=None,
    flats=None,
    delayed=True,
):
    """Build the write tasks that process one field of view over time.

    Trenches (ROIs) are detected on the first timepoint of
    ``segmentation_channels[0]``. For each later timepoint the drift relative
    to that first frame is estimated, seeded with the previous timepoint's
    shift, and the ROIs are shifted accordingly and filtered against
    *image_limits*. Every channel is then cropped with the shifted ROIs.

    Args:
        get_frame_func: Callable ``(position, channel, t) -> image``.
        position: FOV index, also used in the output file names.
        ts: Sequence of timepoints to process, in order.
        output_dir: Directory receiving the outputs (``str`` or ``Path``).
        segmentation_channels: Only the first entry is used.
        measurement_channels: Channels summarized with ``measure_crops`` and
            written with ``write_parquet``.
        crop_channels: Channels written to ``crops_v=<position>.zarr`` and
            ``full_frames_v=<position>.zarr``.
        image_limits: Passed to ``geometry.filter_rois``.
        write_full_frames: Whether to also write the uncropped frames.
        find_trenches_kwargs: Extra keyword arguments for
            ``trench_detection.find_trenches``; ``join_info=True`` is the
            default unless overridden here.
        dark, flats: Currently unused.
        delayed: Forwarded to ``util.get_delayed`` to obtain the task wrapper.

    Returns:
        A list of write tasks (parquet and zarr writes), one batch per
        timepoint.

    NOTE(review): the length of the zarr time axis comes from the
    module-level ``max_t``, not from ``ts``.
    """
    delayed = util.get_delayed(delayed)
    output_dir = Path(output_dir)
    find_trenches_kwargs = {"join_info": True, **(find_trenches_kwargs or {})}
    segmentation_channel = segmentation_channels[0]
    # Segmentation channel first, then every other requested channel exactly
    # once, in a deterministic order (dict.fromkeys keeps insertion order).
    channels = [
        segmentation_channel,
        *(
            channel
            for channel in dict.fromkeys([*measurement_channels, *crop_channels])
            if channel != segmentation_channel
        ),
    ]
    rois = None
    shifts = {}
    write_tasks = []
    for prev_t, t in tqdm(list(zip(it.chain([None], ts[:-1]), ts))):
        segmentation_img = delayed(get_frame_func)(position, segmentation_channel, t)
        full_frames = {segmentation_channel: segmentation_img}
        if rois is None:
            # First timepoint: detect trenches and record the reference
            # features that later timepoints are aligned against.
            rois = delayed(trench_detection.find_trenches)(
                segmentation_img, **find_trenches_kwargs
            )
            shifts[t] = np.array([0, 0])
            initial_drift_features = delayed(drift.get_drift_features)(
                segmentation_img, rois, shifts[t]
            )
        else:
            shifts[t] = delayed(drift.find_feature_drift)(
                initial_drift_features,
                segmentation_img,
                rois,
                initial_shift2=shifts[prev_t],
            )
        shifted_rois = delayed(geometry.filter_rois)(
            delayed(geometry.shift_rois)(rois, shifts[t]), image_limits
        )
        crops = {}
        measurements = {}
        for channel in channels:
            if channel == segmentation_channel:
                # The segmentation frame is already loaded; do not read it
                # a second time.
                crops[channel] = delayed(crop_rois)(segmentation_img, shifted_rois)
                # mask_crops = delayed(segment_crops)(crops[channel])
                # mask_measurements = delayed(measure_mask_crops)(mask_crops)
            else:
                img = delayed(get_frame_func)(position, channel, t)
                full_frames[channel] = img
                crops[channel] = delayed(crop_rois)(img, shifted_rois)
            if channel in measurement_channels:
                # measurements[channel] = delayed(measure_crops)(mask_crops, crops[channel])
                measurements[channel] = delayed(measure_crops)(crops[channel])
        if measurements:
            write_tasks.append(
                delayed(write_parquet)(output_dir, measurements, position, t)
            )
        crops_to_write = {
            channel: channel_crops
            for channel, channel_crops in crops.items()
            if channel in crop_channels
        }
        write_tasks.append(
            delayed(write_zarr)(
                output_dir / f"crops_v={position}.zarr",
                crops_to_write,
                t,
                max_t,
                crop_channels,
            )
        )
        if write_full_frames:
            # Each full frame is stored as a single "ROI" with index 0 so it
            # can go through the same writer as the crops.
            full_frames_to_write = {
                channel: {0: frame} for channel, frame in full_frames.items()
            }
            write_tasks.append(
                delayed(write_zarr)(
                    output_dir / f"full_frames_v={position}.zarr",
                    full_frames_to_write,
                    t,
                    max_t,
                    crop_channels,
                )
            )
        # TODO: rois, metadata (e.g. the per-timepoint shifts) are not saved yet
    return write_tasks

In [None]:
%%time
# %%pyinstrument
# Build the task lists for positions 11, 12 and 13 over all timepoints.
# Nothing is executed here; the tasks are submitted in the next cell.
# NOTE(review): `output_dir` is only assigned in a commented-out cell above,
# so it must be defined before running this cell.
ts = np.arange(max_t)
# ts = np.arange(2)
res = []
for position in trange(11, 14):
    res.append(
        process_fov(
            # Bind the dataset so process_fov can call (position, channel, t).
            partial(get_frame_func, filename),
            position,
            ts,
            output_dir,
            segmentation_channels,
            measurement_channels,
            crop_channels,
            image_limits,
            # Reuse the angle/pitch detected on the reference frame.
            find_trenches_kwargs=dict(
                angle=angle, pitch=pitch, width_to_pitch_ratio=width_to_pitch_ratio
            ),
            delayed=True,
        )
    )

In [None]:
%%time
futures = [client.compute(x) for x in tqdm(res)]

In [None]:
del futures

In [None]:
client.gather(futures)

In [None]:
# For each FOV, collect the futures whose status is "error"; FOVs without any
# failed future are dropped (the walrus binds the per-FOV list so it is only
# kept when non-empty).
# NOTE(review): an earlier cell runs `del futures`; this needs `futures` to
# still be bound.
errored = [e for fov in futures if (e := [f for f in fov if f.status == "error"])]