# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path
import pickle
import dask
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.image_analysis.new as new
from paulssonlab.image_analysis import *

In [None]:
# %load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Dask cluster backed by SLURM. Each worker is its own SLURM job on the
# "short" queue with 1 core, 1 worker process, 2GB memory and a 6-hour
# walltime. Worker logs go to /home/jqs1/log and local scratch to /tmp.
# NOTE(review): none of the later cells visible here submit work to this
# cluster/client explicitly — confirm it is still needed.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
# Show the cluster object (notebook repr).
cluster

In [None]:
# Request one worker job.
cluster.scale(1)

In [None]:
# Let Dask add/remove worker jobs automatically, up to 20.
cluster.adapt(maximum=20)

# Handler

In [None]:
# Channel roles for this experiment (names presumably match the acquisition
# metadata — confirm).
# NOTE(review): none of these four variables is referenced by the later cells
# of this notebook as shown.
segmentation_channel = "RFP-Penta"
trench_detection_channel = segmentation_channel  # channel for trench detection, almost always same as segmentation_channel
measure_channels = ["RFP-Penta", "YFP-DUAL"]
fish_channels = ["RFP-Penta", "Cy5-PENTA", "Cy7"]

## Load outputs from pickle

In [None]:
# Previous input file, kept for reference; the active path is the next line.
# pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1.pickle"
pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1_table.pickle"

In [None]:
%%time
# Load the pipeline outputs as a (table, array) pair. `table` is consumed by
# reformat_table (tuple keys whose fields it treats as prefix, fov, t,
# channel). `array` is consumed by stack_crops.
# Unpickling can execute arbitrary code — only load files from a trusted
# location.
with open(pickle_filename, "rb") as f:
    table, array = pickle.load(f)

# Helper functions

In [None]:
def _rename_column(col):
    if col[0] == "mask_measurements":
        return col[1]
    elif col[0] == "measurements":
        return "/".join(col[1:])
    else:
        return "/".join(col)


def reformat_table(table, flatten_column_names=False):
    """Merge per-key measurement tables into one wide DataFrame.

    ``table`` maps tuple keys ``(prefix, k1, k2, k3)`` to something
    ``pd.concat`` accepts (e.g. a dict of per-roi frames). The three fields
    after the prefix are labelled fov / t / channel.

    For every prefix, in sorted order:

    1. The values are concatenated with an outer "roi" level.
    2. Those results are concatenated with outer fov / t / channel levels.
    3. ``channel`` is unstacked into the columns.

    The per-prefix frames are then joined column-wise. With single-level
    input columns, the result's columns are ``(prefix, column, channel)``.

    If ``flatten_column_names`` is true, the columns are replaced by strings
    from ``_rename_column``. For example, a "measurements" column becomes
    ``"<column>/<channel>"``, and a "mask_measurements" column becomes just
    ``"<column>"``.
    """
    per_prefix = {}
    for prefix in sorted({key[0] for key in table.keys()}):
        frames = {
            key[1:]: pd.concat(rois, names=["roi"])
            for key, rois in table.items()
            if key[0] == prefix
        }
        stacked = pd.concat(frames, names=["fov", "t", "channel"])
        per_prefix[prefix] = stacked.unstack("channel")
    df = pd.concat(per_prefix, axis=1)
    if flatten_column_names:
        df.columns = [_rename_column(col) for col in df.columns.values]
    return df

In [None]:
%%pyinstrument
# Profile reformat_table on a small subset. It keeps only the
# "measurements" and "mask_measurements" tables whose key field k[2]
# (labelled "t" by reformat_table) is < 3.
# Requires the pyinstrument IPython extension (`%load_ext pyinstrument`).
reformat_table(
    {
        k: v
        for k, v in table.items()
        if k[0] in ("measurements", "mask_measurements") and k[2] < 3
    },
    flatten_column_names=True,
)

In [None]:
def stack_crops(array, prefix, fov, channel):
    """Stack each trench's crops across the selected keys into one array.

    Selects the entries of ``array`` whose key is a 4-tuple
    ``(prefix, fov, <third field>, channel)`` and orders them by sorting the
    keys. The third field is presumably the timepoint — TODO confirm. Each
    selected entry maps trench id -> crop array.

    Returns ``{trench: np.stack([crop for each selected key, in key order])}``
    for the trenches present in *every* selected entry. Trenches appear in
    the order of the first entry. Returns ``{}`` when no key matches.
    """
    keys = sorted(
        k
        for k in array.keys()
        if len(k) == 4 and k[:2] == (prefix, fov) and k[3] == channel
    )
    if not keys:
        # no matching entries: nothing to stack
        return {}
    # Only trenches present at every selected key can be stacked. Follow the
    # first entry's order so the result is deterministic.
    later_keys = [array[k].keys() for k in keys[1:]]
    trenches = [
        trench
        for trench in array[keys[0]].keys()
        if all(trench in other for other in later_keys)
    ]
    return {trench: np.stack([array[k][trench] for k in keys]) for trench in trenches}


def unstack(ary):
    """Lay the leading-axis slices of ``ary`` side by side.

    For a stack of shape ``(N, H, W)`` the result has shape ``(H, N * W)``.
    Slice 0 fills columns ``0:W``, slice 1 fills ``W:2W``, and so on. In
    general, axis 1 becomes the rows and all other axes are flattened (C
    order) into the columns.
    """
    # Moving axis 1 to the front is the same permutation as swapping axes 0
    # and 1, for any ndim >= 2.
    rows_first = np.moveaxis(ary, 1, 0)
    return rows_first.reshape(ary.shape[1], -1)


def pad_and_stack(arys, fill_value=0):
    """Pad arrays to a common shape and stack them along a new leading axis.

    Each array is padded with ``fill_value`` on the *leading* side of every
    axis (top/left for 2-D images). All arrays therefore end up aligned to
    their trailing (bottom-right) corner. The target shape is the per-axis
    maximum, so all inputs must have the same number of dimensions.

    Parameters
    ----------
    arys : iterable of np.ndarray
        Arrays to stack. Any iterable is accepted, including a generator.
    fill_value : scalar, optional
        Constant used for the padding (default 0).

    Returns
    -------
    np.ndarray
        Array of shape ``(number of arrays, *max_shape)``.
    """
    # Materialize once: the arrays are traversed twice below, and a generator
    # would be exhausted by the shape scan.
    arys = list(arys)
    target = np.max([ary.shape for ary in arys], axis=0)
    return np.stack(
        [
            np.pad(
                ary,
                # pad before each axis only
                [(size - dim, 0) for size, dim in zip(target, ary.shape)],
                constant_values=fill_value,
            )
            for ary in arys
        ]
    )


def pad_unstack(arys):
    """Pad ``arys`` to a common shape, then lay them out side by side.

    This is ``pad_and_stack`` (default fill value) followed by ``unstack``.
    For 2-D inputs the result is ``(max_H, N * max_W)``.
    """
    stacked = pad_and_stack(arys)
    return unstack(stacked)

# Reformat table

In [None]:
%%time
# Flatten all "measurements" and "mask_measurements" tables into one
# DataFrame. All timepoints are included; the k[2] < 3 filter used for
# profiling is commented out.
all_measurements = reformat_table(
    {
        k: v
        for k, v in table.items()
        if k[0] in ("measurements", "mask_measurements")  # and k[2] < 3
    },
    flatten_column_names=True,
)

In [None]:
all_measurements

In [None]:
# Cache file for the reformatted table (rebinds pickle_filename).
pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1_table_reformatted.pickle"

In [None]:
%%time
# NOTE(review): `trenches` is not assigned by any cell of this notebook as
# shown; it only exists as a local inside stack_crops. This line therefore
# raises NameError unless `trenches` was set in an earlier session or by the
# star import from paulssonlab.image_analysis. Confirm what should be saved
# alongside all_measurements.
with open(pickle_filename, "wb") as f:
    pickle.dump((all_measurements, trenches), f)

In [None]:
%%time
# Reload the cached (all_measurements, trenches) pair. Unpickling can
# execute arbitrary code — only load trusted files.
with open(pickle_filename, "rb") as f:
    all_measurements, trenches = pickle.load(f)

# Growth rates

In [None]:
%%time
# All rows whose third index level (roi, per reformat_table) equals 1000,
# across every fov and timepoint. All index levels are kept.
one_trench = all_measurements.xs(IDX[:, :, 1000, :] ,drop_level=False)

In [None]:
%%time
# Rows whose fourth index level (a per-row label) equals 1. This is
# presumably the mother cell — TODO confirm. `mothers` is not used by later
# cells as shown.
mothers = one_trench.xs(IDX[:, :, :, 1], drop_level=False)

In [None]:
def track_mother_cell(df):
    """Assign lineage IDs to the label-1 ("mother") rows of one trench.

    ``df`` is expected to have an ``"axis_major_length"`` column and a row
    MultiIndex whose level 3 is the per-cell label. With reformat_table
    output the levels are fov, t, roi, label. Label 1 is presumably the
    mother cell — TODO confirm.

    The label-1 rows are taken, in row order, as successive observations of
    the mother cell. Whenever its major-axis length decreases from one
    observation to the next, a division is assumed and the ID increments.

    Returns a copy of ``df`` with a new uint64 ``cell_id`` column. Mother
    rows get IDs starting at 1; every other row gets 0 (not tracked). If
    ``df`` has no label-1 rows, every row gets 0.
    """
    # Use a boolean mask on the label level, not xs/get_locs. This keeps row
    # order, needs no global IndexSlice, and tolerates groups without a
    # label-1 row.
    is_mother = np.asarray(df.index.get_level_values(3) == 1)
    lengths = df["axis_major_length"].to_numpy()[is_mother]
    # 0 marks rows that are not part of the tracked mother lineage
    cell_ids = np.zeros(len(df), dtype=np.uint64)
    if lengths.size:
        # first cell ID is 1; each length drop starts a new cell
        mother_cell_ids = np.concatenate(
            ([1], 1 + np.cumsum(lengths[1:] < lengths[:-1]))
        )
        cell_ids[is_mother] = mother_cell_ids
    return df.assign(cell_id=cell_ids)

In [None]:
# Assign mother-cell lineage IDs for roi 1000.
one_trench_tracked = track_mother_cell(one_trench)
one_trench_tracked

In [None]:
# Mother-cell length over time, coloured by cell_id. Rows with cell_id 0
# (not tracked by track_mother_cell) are dropped.
one_trench_tracked[one_trench_tracked["cell_id"] != 0].hvplot.scatter("t", "axis_major_length", by="cell_id", cmap="Category20", legend=False)

In [None]:
%%time
# Track the mother cell in every roi.
# NOTE(review): grouping by "roi" alone puts rows from different fovs that
# share a roi number into one group, so their time series get concatenated
# into a single lineage. If roi ids are only unique within an fov, group by
# ["fov", "roi"] instead — confirm.
all_tracked = all_measurements.groupby("roi", group_keys=False).apply(track_mother_cell)

In [None]:
all_tracked

In [None]:
# Length trace of the label-1 rows of roi 0 (roi and label levels dropped),
# coloured by cell_id.
all_tracked.xs(IDX[:,:,0,1]).hvplot.scatter("t", "axis_major_length", by="cell_id", cmap="Category20", legend=False)