# Imports

In [None]:
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import zarr
import dask
from dask import delayed
import distributed
from distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
import matplotlib.pyplot as plt
import holoviews as hv
import hvplot.pandas
from tqdm.auto import tqdm
from functools import partial, reduce
import operator
import itertools as it
from collections import namedtuple
import nd2reader
import re
import os
from pathlib import Path
import skimage.measure
import pickle

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from paulssonlab.image_analysis import *
import paulssonlab.image_analysis.new as new

In [None]:
#%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Start a dask cluster whose workers are submitted as SLURM jobs (dask_jobqueue).
# Nothing is submitted until the cluster is scaled in a later cell.
cluster = SLURMCluster(
    queue="short",  # SLURM partition the worker jobs are submitted to
    walltime="06:00:00",  # per-job wall-clock limit
    memory="4GB",  # memory requested per job
    local_directory="/tmp",  # worker scratch/spill directory on the compute node
    log_directory="/home/jqs1/log",  # where the job scheduler logs are written
    cores=1,  # one core per job ...
    processes=1,  # ... run as a single worker process
)
# Client bound to this cluster; later cells use it for compute() and gather().
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(20)

In [None]:
cluster.adapt(maximum=20)

# Handler

In [None]:
# Channel names used by the handlers below; they are compared for equality against
# metadata["channel"] of incoming frames, so they must match those names exactly.
drift_correction_channel = "Phase-Fluor"  # not referenced by the handlers visible in this notebook
segmentation_channel = "RFP-PENTA"  # frames of this channel are segmented
trench_detection_channel = segmentation_channel  # channel for trench detection, almost always same as segmentation_channel
measure_channels = ["RFP-PENTA", "GFP-PENTA"]  # channels whose crops get intensity measurements
# NOTE(review): "RFP-Penta" differs in case from "RFP-PENTA" used above; channel
# comparisons are case-sensitive -- confirm against the acquisition's channel names.
fish_channels = ["RFP-Penta", "Cy5-PENTA", "Cy7"]

In [None]:
import logging


class Pipeline:
    """Shared bookkeeping object passed to every message handler.

    Attributes:
        logger: the ``logging`` logger named ``"Pipeline"``.
        output_dir: ``output_dir`` converted to a ``pathlib.Path``.
        state: scratch dict for handler bookkeeping (e.g. the per-FOV
            ``("keys_to_crop", fov)`` lists kept by ``handle_image``).
        array: dict keyed by tuples such as ``("raw", fov, t, channel)``
            holding images or delayed image results.
        table: dict keyed by tuples such as ``("trenches", fov)`` holding
            tables or delayed table results.
    """

    def __init__(self, output_dir):
        """Create an empty pipeline that records ``output_dir`` as a Path."""
        self.output_dir = Path(output_dir)
        self.logger = logging.getLogger("Pipeline")
        # three independent, initially empty stores
        self.state, self.array, self.table = {}, {}, {}

    def delayed(self, func, *args, **kwargs):
        """Wrap ``func`` with ``dask.delayed``, forwarding any extra arguments.

        Currently a thin pass-through; it exists as the single hook where
        the features listed below can later be added.
        """
        # TODO:
        # log exceptions
        # log warnings (deduplicated, count instances)
        # optionally retry with diag if func takes "diagnostics" argument
        # log benchmarking/profiling? or collect stats, only log outliers (+ call arguments)
        lazy = dask.delayed(func, *args, **kwargs)
        return lazy

def crop_trenches(img, trenches):
    """Collect the crops yielded by ``new.image.iter_crops`` into a dict.

    Returns a dict mapping each yielded index to its crop.
    """
    # TODO: for quick tests, wrap the iterator in it.islice(..., 3) so each FOV only
    # handles three trenches; otherwise every dask task takes a long time
    return {idx: cropped for idx, cropped in new.image.iter_crops(img, trenches)}


def segment_trenches(crops):
    """Segment every trench crop, skipping crops whose segmentation fails.

    Args:
        crops: dict mapping a trench index to its cropped image.

    Returns:
        dict mapping trench index to the result of
        ``trench_segmentation.segment`` for that crop. Trenches whose
        segmentation raised are omitted, so the result may have fewer keys
        than ``crops``.
    """
    masks = {}
    for i, crop in crops.items():
        try:
            masks[i] = trench_segmentation.segment(crop)
        except Exception:
            # Deliberately best-effort: one bad trench must not fail the whole FOV.
            # Only Exception is caught (so KeyboardInterrupt/SystemExit still
            # propagate), and the failure is logged rather than silently dropped.
            logging.getLogger("Pipeline").warning(
                "segmentation failed for trench %r; skipping", i, exc_info=True
            )
    return masks

# TODO: this is really boilerplatey, also we want finer task granularity than doing a whole FOV at once
def measure_crops(label_images, intensity_images):
    """Run ``measure_crop`` on every key present in both input dicts.

    Keys found in only one of the two dicts are ignored. Returns a dict
    mapping each shared key to its measurement table.
    """
    shared = label_images.keys() & intensity_images.keys()
    measurements = {}
    for key in shared:
        measurements[key] = measure_crop(label_images[key], intensity_images[key])
    return measurements

def measure_crop(label_image, intensity_image):
    """Tabulate the mean intensity of each labeled region.

    Calls ``skimage.measure.regionprops_table`` requesting the ``label`` and
    ``intensity_mean`` properties and returns the result as a DataFrame
    indexed by ``label``.
    """
    wanted = ("label", "intensity_mean")
    props = skimage.measure.regionprops_table(
        label_image,
        intensity_image,
        properties=wanted,
    )
    frame = pd.DataFrame(props)
    return frame.set_index("label")

def measure_mask_crops(label_images):
    """Run ``measure_mask_crop`` on every label image, preserving the keys."""
    tables = {}
    for key, label_image in label_images.items():
        tables[key] = measure_mask_crop(label_image)
    return tables

def measure_mask_crop(label_image):
    """Tabulate shape properties of each labeled region.

    Calls ``skimage.measure.regionprops_table`` without an intensity image,
    requesting label, area, major/minor axis lengths, orientation and
    centroid, and returns a DataFrame indexed by ``label``.
    """
    shape_properties = (
        "label",
        "area",
        "axis_major_length",
        "axis_minor_length",
        "orientation",
        "centroid",
    )
    props = skimage.measure.regionprops_table(label_image, properties=shape_properties)
    frame = pd.DataFrame(props)
    return frame.set_index("label")


# TODO: use a namedtuple (or typing.NamedTuple, or dataclass) for keys so that fields are named
def handle_image(pipeline, msg):
    """Register one incoming frame and schedule the work that it unblocks.

    Reads ``msg["image"]`` and ``msg["metadata"]`` (keys ``fov_num``, ``t``,
    ``channel``), stores the raw frame in ``pipeline.array`` and, via
    ``pipeline.delayed``, schedules trench detection, cropping, segmentation
    and measurement as soon as their inputs exist. Results are stored in
    ``pipeline.array`` / ``pipeline.table`` under tuple keys whose first
    element names the kind of result ("raw", "crops", "segmentation",
    "trenches", "mask_measurements", "measurements"). Pending work is tracked
    in ``pipeline.state``.

    Returns None; all effects are mutations of ``pipeline``.
    """
    image = msg["image"]
    metadata = msg["metadata"]
    fov_num = metadata["fov_num"]
    t = metadata["t"]
    channel = metadata["channel"]
    raw_key = ("raw", fov_num, t, channel)
    # store raw image (in production, we won't do this, we will only store crops as we do below)
    pipeline.array[raw_key] = image
    # TODO: we need a way to store per-frame metadata and write it to disk
    trenches_key = (
        "trenches",
        fov_num,
    )
    trenches = pipeline.table.get(trenches_key)
    # check if we have done trench detection for this FOV
    if trenches is None and channel == trench_detection_channel:
        # if not, find trenches and save the resulting table
        trenches = pipeline.delayed(new.image.find_trench_bboxes)(image, peak_func=trench_detection.peaks.find_peaks)
        pipeline.table[trenches_key] = trenches
    # this list keeps track of all the raw frames that need to be cropped
    # frames for multiple channels will accumulate in this list until we get a frame for trench_detection_channel
    # if we have already processed such a frame, then keys_to_crop will contain only the current frame (raw_key)
    keys_to_crop = pipeline.state.setdefault(("keys_to_crop", fov_num), [])
    keys_to_crop.append(raw_key)
    # we only can do further processing if we have already detected trenches for this FOV
    if trenches is not None:
        for raw_to_crop in keys_to_crop:
            # crop_key is raw_to_crop with "raw" replaced by "crops": ("crops", fov, t, channel)
            crop_key = ("crops", *raw_to_crop[1:])
            # save trench crops for every frame in keys_to_crop
            pipeline.array[crop_key] = pipeline.delayed(crop_trenches)(
                pipeline.array[raw_to_crop], trenches
            )
            # NOTE(review): this key (and the "keys_to_measure" state key below) uses `t`
            # of the *current* message, not raw_to_crop[2]. They coincide only if every
            # frame accumulated in keys_to_crop has the same timepoint -- confirm.
            segmentation_key = ("segmentation", fov_num, t, segmentation_channel)
            segmentation = pipeline.array.get(segmentation_key)
            if segmentation is not None:
                if crop_key[-1] in measure_channels:
                    # we already have segmentation masks for this timepoint, so we can immediately measure only this frame
                    keys_to_measure = [crop_key]
                else:
                    keys_to_measure = []
            else:
                # we don't have a segmentation mask yet, so we need to add to the keys_to_measure list
                keys_to_measure = pipeline.state.setdefault(("keys_to_measure", fov_num, t), [])
                if crop_key[-1] in measure_channels:
                    # we want to measure this frame
                    keys_to_measure.append(crop_key)
                if crop_key[-1] == segmentation_channel:
                    # if this frame's channel is the segmentation channel, run segmentation
                    segmentation = pipeline.delayed(segment_trenches)(
                        pipeline.array[crop_key]
                    )
                    pipeline.array[segmentation_key] = segmentation
                    # once we have the segmentation mask, get measurements for the mask
                    pipeline.table[("mask_measurements", *crop_key[1:],)] = pipeline.delayed(
                        measure_mask_crops
                    )(segmentation)
            # re-read: the branch above may have just stored the segmentation
            segmentation = pipeline.array.get(segmentation_key)
            # if we now have the segmentation mask, try measuring all frames in the keys_to_measure list
            if segmentation is not None:
                for crop_to_measure in keys_to_measure:
                    measurements_key = ("measurements", *crop_to_measure[1:])
                    pipeline.table[measurements_key] = pipeline.delayed(measure_crops)(
                        segmentation, pipeline.array[crop_to_measure]
                    )
                # everything pending for this (fov, t) has been scheduled; drop the pending list
                pipeline.state.pop(("keys_to_measure", fov_num, t), None)
        # all accumulated frames for this FOV have been cropped; drop the pending list
        pipeline.state.pop(("keys_to_crop", fov_num), None)


def handle_fish_barcode(pipeline, msg):
    """Placeholder for FISH-barcode image messages; currently does nothing."""
    # TODO
    return None


# we should pick a name that's better/more intuitive than handle_message
def handle_message(pipeline, msg):
    match msg:
        case {"type": "image", **info}:
            match info:
                case {"image_type": "fish_barcode"}:
                    handle_fish_barcode(pipeline, msg)
                case other:
                    handle_image(pipeline, msg)
        case {"type": "nd2_metadata"}:
            print("got metadata") # TODO
        case {"type": "event", **info}:
            print("event", info)
        case {"type": "done"}:
            print("DONE")
        case _:
            # this exception should be caught, we don't want malformed messages to crash the pipeline
            raise ValueError("cannot handle message", msg)

In [None]:
%%time
# Feed every message produced by the ND2 reader through the dispatcher. This only
# builds up the delayed tasks held in pipeline.table / pipeline.array; they are
# computed on the cluster in the following cells.
filename = "/home/jqs1/scratch/jqs1/microscopy/210511/RBS_ramp.nd2"
pipeline = Pipeline("/home/jqs1/scratch/jqs1/microscopy/220718/new_architecture/test1")
for msg in new.readers.send_nd2(
    filename,
    # restrict the read to v=30 and t >= 40 (v is presumably the FOV index -- confirm)
    slices=dict(v=[30], t=slice(40,None)),
):
    handle_message(pipeline, msg)

In [None]:
%%time
futures = util.apply_map_futures(client.compute, (pipeline.table, pipeline.array))

In [None]:
%%time
table, array = client.gather(futures)

In [None]:
pickle_filename = "/home/jqs1/group/221102plasmidloss_1.pickle"

In [None]:
%%time
with open(pickle_filename, "wb") as f:
    pickle.dump((table, array), f)

In [None]:
%%time
# pickle.load needs an open binary file object, not a path string.
# Only unpickle files this notebook wrote itself: unpickling can execute arbitrary code.
with open(pickle_filename, "rb") as f:
    table2, array2 = pickle.load(f)

# Reformat outputs

In [None]:
def reformat_tables(table, prefix):
    """Merge the per-frame, per-trench tables under ``prefix`` into one wide DataFrame.

    Args:
        table: dict keyed by tuples ``(*prefix, fov, t, channel)``; each value
            is itself a dict mapping a trench index to a DataFrame.
        prefix: leading key element(s) to select, e.g. ``"measurements"``.
            A non-tuple value is treated as a one-element tuple.

    Returns:
        DataFrame indexed by ``fov``, ``t``, ``trench`` plus the index levels
        of the per-trench tables, with one column per (channel, property)
        pair named ``"<channel>/<property>"``.

    Raises:
        ValueError: raised by ``pd.concat`` when no key matches ``prefix``.
    """
    if not isinstance(prefix, tuple):
        prefix = (prefix,)
    n = len(prefix)
    # the body previously read an undefined name `tables`; use the `table` argument
    keys = sorted(k for k in table if k[:n] == prefix)
    df = pd.concat(
        {k[n:]: pd.concat(table[k], names=["trench"]) for k in keys},
        names=["fov", "t", "channel"],
    )
    df = df.unstack("channel")
    # replace MultiIndex with Index of slash-separated names like "GFP-PENTA/intensity_mean"
    df.columns = ["/".join(col[::-1]) for col in df.columns.values]
    return df

# the function defined above is reformat_tables (plural)
d = reformat_tables(table, "measurements")

In [None]:
reduce?

In [None]:
def stack_crops(array, prefix, fov, channel):
    """Stack each trench's crops across all matching frames into one array.

    Args:
        array: dict keyed by 4-tuples ``(prefix, fov, t, channel)`` whose
            values are dicts mapping trench index to a crop; keys of any
            other shape are ignored.
        prefix: first key element to select (e.g. ``"crops"``).
        fov: second key element to select.
        channel: fourth key element to select.

    Returns:
        dict mapping trench index to ``np.stack`` of that trench's crops,
        ordered by sorted key (i.e. by the third key element). Only trenches
        present in *every* matching frame are included. Returns an empty
        dict when no key matches.
    """
    keys = sorted(
        k
        for k in array
        if len(k) == 4 and k[:2] == (prefix, fov) and k[3] == channel
    )
    if not keys:
        # reduce() over an empty sequence without an initial value raises TypeError
        return {}
    # keep only the trenches that appear in every selected frame
    trenches = reduce(operator.and_, (array[k].keys() for k in keys))
    return {trench: np.stack([array[k][trench] for k in keys]) for trench in trenches}

In [None]:
%%time
d = stack_crops(array, "crops", 30, "RFP-PENTA")

In [None]:
a = d[33]
plt.figure(figsize=(20,20))
plt.imshow(np.swapaxes(a, 0, 1).reshape(a.shape[1], -1))

In [None]:
a.reshape(a.shape[1], -1).shape

In [None]:
plt.imshow(a.reshape(a.shape[0] * a.shape[1], a.shape[2]))

In [None]:
plt.imshow(array[('crops', 30, 48, 'RFP-PENTA')][100])

In [None]:
d

In [None]:
d.loc[IDX[:,:,5]]

In [None]:
d[IDX[:,:,0,:]].hvplot.scatter("t", "GFP-PENTA/intensity_mean")

In [None]:
d.reset_index().hvplot("GFP-PENTA/intensity_mean")

In [None]:
d.reset_index("channel")

In [None]:
dd = d.unstack("channel")

In [None]:
dd.columns

In [None]:
dd.columns = ["/".join(col[::-1]) for col in dd.columns.values]

In [None]:
dd

In [None]:
pd.melt(d.reset_index("channel"), id_vars=["channel"])

# Drift correction test

In [None]:
nd2 = nd2reader.ND2Reader("/home/jqs1/scratch/jqs1/microscopy/210511/RBS_ramp.nd2")

In [None]:
nd2.sizes

In [None]:
f1 = nd2.get_frame_2D(v=30,t=0, c=0)
f2 = nd2.get_frame_2D(v=30,t=150, c=0)

In [None]:
from skimage.registration import optical_flow_ilk, optical_flow_tvl1, phase_cross_correlation

In [None]:
phase_cross_correlation(f1, f2, return_error=False)

In [None]:
plt.figure(figsize=(30,30))
plt.imshow(f1-f2)

In [None]:
plt.imshow(f1)