# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path
import pickle
import dask
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import pyarrow as pa
import pyarrow.parquet as pq
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm

# Shorthand for building MultiIndex slicers such as IDX[:, :, :, 1].
IDX = pd.IndexSlice
# Register tqdm's progress_* variants (e.g. progress_apply) on pandas objects.
tqdm.pandas()

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.image_analysis.new as new
from paulssonlab.image_analysis import *

In [None]:
# %load_ext pyinstrument

In [None]:
# Load the bokeh plotting extension for HoloViews/hvplot output in this notebook.
hv.extension("bokeh")

# Helper functions

This renames columns with a slightly different convention than my previous `reformat_table` function (mask_measurement column names look like `centroid-0` whereas measurement column names look like `GFP-PENTA/mean_intensity`). It also joins `mask_measurements` together with `measurements`.

In [None]:
def _rename_column(col):
    """Flatten one MultiIndex column tuple into a single string name.

    * ``("mask_measurements", name, ...)`` -> ``name`` (second element only)
    * ``("measurements", a, b, ...)``      -> ``"a/b/..."`` (prefix dropped)
    * anything else                        -> all elements joined with ``"/"``
    """
    head = col[0]
    if head == "mask_measurements":
        return col[1]
    # Only the "measurements" prefix is stripped; unknown prefixes are kept.
    return "/".join(col[1:] if head == "measurements" else col)


def reformat_table(table, flatten_column_names=False):
    """Combine a keyed collection of per-ROI tables into one wide DataFrame.

    Parameters
    ----------
    table : mapping
        Keys are tuples whose first element is a prefix (the helper
        ``_rename_column`` knows ``"mask_measurements"`` and ``"measurements"``)
        and whose remaining elements are labelled ``fov``, ``t`` and
        ``channel``. Each value is passed to ``pd.concat`` with its keys
        forming a ``roi`` index level.
    flatten_column_names : bool, default False
        If true, replace the column MultiIndex with flat names produced by
        ``_rename_column``.

    Returns
    -------
    pandas.DataFrame
        One column group per prefix, with ``channel`` pivoted into the columns.
    """

    def _frame_for(prefix):
        # Stack every entry stored under this prefix, then move "channel"
        # from the row index into the columns.
        per_key = {
            key[1:]: pd.concat(table[key], names=["roi"])
            for key in table.keys()
            if key[0] == prefix
        }
        stacked = pd.concat(per_key, names=["fov", "t", "channel"])
        return stacked.unstack("channel")

    prefixes = sorted({key[0] for key in table.keys()})
    combined = pd.concat(
        {prefix: _frame_for(prefix) for prefix in prefixes},
        axis=1,
    )
    if flatten_column_names:
        # replace MultiIndex with Index of flat names built by _rename_column
        combined.columns = [_rename_column(col) for col in combined.columns.values]
    return combined

# Fix ROI orientation

Then we need to adjust the `label` index. This step depends on the `trenches` dataframe. We look up each `(fov, roi)` key in the `trenches` dataframe, see which `trench_set` the roi belongs to, and reverse the ordering of the labels for odd-numbered trench sets (e.g., `labels=[1,2,3]` -> `labels=[3,2,1]`).

In [None]:
%%time
# Labels in odd-numbered trench sets need to be inverted.
def fix_label_order(df, trench_sets=None):
    """Reverse the label numbering of one group if its trench set is odd.

    Parameters
    ----------
    df : pandas.DataFrame
        One group of measurements with a ``"label"`` column. The first two
        entries of its first index tuple are used as the lookup key
        (expected to be ``(fov, roi)`` -- confirm against the index order).
    trench_sets : pandas.Series, optional
        Maps that lookup key to a trench-set number. Defaults to the
        notebook-level ``trenches["trench_set"]``.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself for even trench sets; otherwise a new frame with
        ``label`` replaced by ``max - label + min`` and sorted by ``label``.
        The input frame is never modified.
    """
    if trench_sets is None:
        trench_sets = trenches["trench_set"]
    if trench_sets.loc[df.index[0][:2]] % 2 == 0:
        return df
    labels = df["label"]
    flipped = labels.max() - labels + labels.min()
    # assign() returns a new frame: mutating the group object handed in by
    # groupby is not supported by pandas and can give unexpected results.
    return df.assign(label=flipped).sort_values("label")


# fix_label_order needs the whole group frame (it reads the "label" column),
# so use apply: transform requires the function to also work column-by-column.
all_measurements_reordered = (
    all_measurements.reset_index(["label"])
    .groupby(["fov", "roi", "t"], group_keys=False)
    .progress_apply(fix_label_order)
).set_index("label", append=True)

# Load data

If you want to play around with real data, you can load it from the pickle file below. This pickled dataset already has the above steps (`reformat_table` and `fix_label_order`) applied.

In [None]:
# Location of the pickled (all_measurements, trenches) pair; adjust for your environment.
pickle_filename = "/home/jqs1/group/221108rbsdeglibrary_1_table_reformatted2.pickle"

In [None]:
%%time
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(pickle_filename, "rb") as f:
    # The file holds a 2-tuple: the measurements table and the trenches table.
    all_measurements, trenches = pickle.load(f)

In [None]:
all_measurements

# Growth rates

Tracking associates segmentation mask labels at one timepoint `t` with labels at `t+1`. Each label at `t` either corresponds to one label at `t+1` (representing the same cell at a later time), corresponds to two labels at `t+1` (representing its two daughter cells arising from a cell division event), or is marked as a cell that died/went out of frame. In production we use a more sophisticated linear-programming-based tracking algorithm; I am working on rewriting it and will integrate it with our codebase soon. For testing we will use this simple mock tracking algorithm that assigns `cell_id=1` to the mother cell (the cell with `label=1`) at `t=0`; every time the mother cell shrinks, it increments the mother cell `cell_id` by one. It assigns `cell_id=0` to all non-mother cells. `cell_id=0` is used as a sentinel for untracked cells, and they are filtered out for downstream processing. The reason we needed `fix_label_order` above is so that all segmentation labeling is standardized: the mother cell (at the dead end of each trench) for each `roi` and timepoint `t` gets `label=1`.

The output of any tracking algorithm, including `track_mother_cell`, is to add a `cell_id` column to the measurements dataframe where each unique non-zero positive integer represents the same cell identity across time.

**A note about uniqueness of keys:** ROI numbers are only unique within an fov (so rois are keyed by `(fov, roi)`), and cell_ids are only unique within an roi (so cell_ids are keyed by `(fov, roi, cell_id)`).

In [None]:
def track_mother_cell(df):
    """Mock tracker: assign a ``cell_id`` to the rows of a single ROI.

    Rows whose fourth index level equals 1 (the mother-cell segments) get a
    running ID that starts at 1 and is incremented every time
    ``axis_major_length`` is smaller than in the previous mother row. All
    other rows get ``cell_id`` 0, the marker for untracked segments.

    Returns a copy of ``df`` with the extra ``cell_id`` column (uint64).
    """
    mother_key = IDX[:, :, :, 1]
    mother_rows = df.xs(mother_key, drop_level=False)
    mother_lengths = mother_rows["axis_major_length"].values
    # A drop in length between consecutive mother rows counts as a division.
    shrank = mother_lengths[1:] < mother_lengths[:-1]
    # First cell ID is 1; 0 stays reserved for non-tracked segments.
    mother_ids = np.concatenate(([1], 1 + np.cumsum(shrank)))
    assigned = np.zeros(len(df), dtype=np.uint64)
    mother_positions = df.index.get_locs(mother_key)
    assigned[mother_positions] = mother_ids
    return df.assign(cell_id=assigned)

In [None]:
%%time
# Run the tracker once per (fov, roi) group; group_keys=False keeps the original index.
all_tracked = all_measurements.groupby(["fov", "roi"], group_keys=False).progress_apply(
    track_mother_cell
)

Here we plot cell lengths (`axis_major_length`) colored by `cell_id`. You can see that each time the cell shrinks, it means the cell has divided and so is assigned a new `cell_id`.

In [None]:
# Pick the rows with third index level 3000 and fourth level (label) 1, then
# scatter length against t, one color per cell_id.
# NOTE(review): presumably 3000 is an roi number -- confirm against the index level order.
all_tracked.xs(IDX[:, :, 3000, 1]).hvplot.scatter(
    "t", "axis_major_length", by="cell_id", cmap="Category20", legend=False
)

You can see that each colored sequence of points looks approximately linear. (Given that cells grow exponentially, we actually fit a line to `log(axis_major_length)`.) We thus fit a separate line segment (using ordinary least squares) to each `cell_id`, resulting in a dataframe with y-intercept (`alpha`), slope/growth rate parameter (`beta`), and fit quality (`r2`).

In this particular case of growth rate estimation using OLS, this could almost certainly be sped up a lot by batching and vectorizing OLS fits and/or using `numpy.linalg.lstsq` (which calls LAPACK). I keep the slow implementation here because this is representative of the kind of custom computations we want to be able to run.

In [None]:
def ols(x, y):
    """Fit ``y ~ alpha + beta * x`` by ordinary least squares.

    Parameters
    ----------
    x, y : numpy.ndarray
        Equal-length 1-D arrays of observations.

    Returns
    -------
    tuple
        ``(alpha, beta, r2)``: intercept, slope, and the ratio of explained
        to total sum of squares.
    """
    n = len(x)
    sum_x = x.sum()
    sum_y = y.sum()
    mean_x = sum_x / n
    mean_y = sum_y / n
    # Closed-form slope: scaled covariance over scaled variance of x.
    covariance_term = n * (x * y).sum() - sum_x * sum_y
    variance_term = n * (x**2).sum() - sum_x**2
    beta = covariance_term / variance_term
    alpha = mean_y - beta * mean_x
    fitted = alpha + beta * x
    explained = np.sum((fitted - mean_y) ** 2)
    total = np.sum((y - mean_y) ** 2)
    return alpha, beta, explained / total


def lineage_growth_rate(df, min_obs=3):
    """Fit a line to log(axis_major_length) versus time for one lineage.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single lineage. Needs a ``"t"`` index level and the columns
        ``"axis_major_length"`` and ``"cell_id"``.
    min_obs : int, default 3
        Lineages with fewer rows than this are skipped.

    Returns
    -------
    pandas.DataFrame or None
        ``None`` for skipped lineages. Otherwise a frame with columns
        ``alpha``, ``beta``, ``r2`` (the same fit repeated on every row),
        indexed like ``df`` with ``cell_id`` appended as an extra level.
    """
    if len(df) < min_obs:
        return None
    ts = df.index.get_level_values("t").values
    # Shift time so each lineage starts at 0; not required for the slope, but
    # it makes alpha (the y-intercept) comparable between lineages.
    # Subtract out of place: an in-place `-=` would write into the array
    # returned by Index.values, which may be the index's own data.
    ts = ts - ts.min()
    log_length = np.log(df["axis_major_length"].values)
    fit = ols(ts, log_length)
    new_df = pd.DataFrame(
        np.repeat([fit], len(df), axis=0),
        columns=["alpha", "beta", "r2"],
        index=df.index,
    )
    return new_df.assign(cell_id=df["cell_id"]).set_index("cell_id", append=True)

In [None]:
%%time
# cell_id 0 marks untracked segments, so drop those rows, then fit one line per
# (fov, roi, cell_id) lineage.
all_growth_rates = all_tracked[all_tracked["cell_id"] != 0].groupby(
    ["fov", "roi", "cell_id"], group_keys=False
).progress_apply(lineage_growth_rate)

Here we check that our $R^2$ values are close to 1.

In [None]:
all_growth_rates

In [None]:
all_growth_rates["r2"].hvplot.hist()

# 2D Heatmap

In [None]:
import hvplot.xarray
import xarray as xr

In [None]:
%%time
observable = "beta"
num_bins = 100
# Start from every fitted lineage so this cell runs on a fresh kernel. To keep
# only well-fit lineages, filter first, e.g. on all_growth_rates["r2"] > 0.9.
measurements_subset = all_growth_rates.reset_index()[["t", "alpha", "beta", "r2"]]
# num_bins histogram bins need num_bins + 1 edges.
bins = np.linspace(
    measurements_subset[observable].min(),
    measurements_subset[observable].max(),
    num_bins + 1,
)
# One histogram per timepoint, indexed by the left edge of each bin.
heatmap = measurements_subset.groupby(["t"]).apply(
    lambda x: pd.Series(np.histogram(x[observable], bins=bins)[0], index=bins[:-1])
)
heatmap.columns.name = observable
# Transpose so the observable is the first dimension and t the second.
heatmap = xr.DataArray(heatmap.T)

Using the exact same 2D heatmap code you've seen before, we can plot the result. This is the kind of plot we want to see update in real-time as new timepoints roll in.

In [None]:
# Draw the per-timepoint histograms as a 2D mesh with a log-scaled color axis.
heatmap.hvplot.quadmesh(
    cmap="blues",
    # logy=True,
    logz=True,
    # clim=(1, 1e4),
)