# 2. Track objects in simulated data using ARCOS.px

In [9]:
import os
import glob
import numpy as np
import pandas as pd
from arcos4py.tools import track_events_image
from arcospx import utils

from skimage.measure import regionprops, regionprops_table
from skimage.util import map_array

# TIFF stacks I/O
from tifffile import imread, imwrite

import napari
from napari_timestamper import render_as_rgb

In [2]:
# Master switch for the napari cells: when True the notebook opens/uses a
# napari viewer to display images; set to False to skip those calls.
IMG_NAPARI = True

def remap_segmentation(
    df: pd.DataFrame,
    segmentation: list,
    timepoint_column: str = "timepoint",
    label_column: str = "label",
    measure_column: str = "lineage",
) -> list:
    """Replace the labels of every segmentation frame with values from ``df``.

    For each frame, the rows of ``df`` belonging to that frame provide a
    ``label -> measure`` lookup that is applied to the frame with
    ``skimage.util.map_array``.

    Pairing of frames and rows:

    * If the number of distinct timepoints equals the number of frames, the
      i-th distinct timepoint (in sorted order) is paired with the i-th frame.
    * Otherwise at least one frame has no rows. Timepoint values are then
      interpreted as 0-based frame indices, and frames without rows are
      returned as all-zero arrays. Pairing positionally in this situation
      would shift every later group onto the wrong frame and drop frames
      from the end of the result.

    Args:
        df: Table with one row per (timepoint, label) pair.
        segmentation: Sequence of integer label images, one per frame.
        timepoint_column: Column of ``df`` holding the frame a row belongs to.
        label_column: Column of ``df`` holding the label found in the image.
        measure_column: Column of ``df`` holding the value written in place
            of the label.

    Returns:
        A list with one remapped array per frame of ``segmentation``.
    """
    tracked = df[[timepoint_column, label_column, measure_column]].sort_values(timepoint_column).to_numpy()

    # Empty table: there is nothing to map, so every frame is pure background.
    if tracked.shape[0] == 0:
        return [np.zeros_like(img) for img in segmentation]

    # Split the sorted rows into one chunk per distinct timepoint.
    timepoints, first_rows = np.unique(tracked[:, 0], return_index=True)
    groups = np.split(tracked, first_rows[1:])

    if len(groups) == len(segmentation):
        # Every frame has rows: positional pairing is unambiguous.
        frame_groups = groups
    else:
        # Some frames have no rows: look the chunk up by frame index instead.
        by_timepoint = {int(t): grp for t, grp in zip(timepoints, groups)}
        frame_groups = [by_timepoint.get(i) for i in range(len(segmentation))]

    remapped = []
    for img, grp in zip(segmentation, frame_groups):
        if grp is None:
            # Keep the dtype of the mapped frames so the result stacks cleanly.
            remapped.append(np.zeros(img.shape, dtype=tracked.dtype))
        else:
            remapped.append(map_array(img, grp[:, 1], grp[:, 2]))
    return remapped

def track_events(
    bin_mask,
    min_clustersize=3,
    eps=3,
    downsample=1,
    n_prev=3,
    stability_threshold=1,
    min_lineage_duration=3,
):
    """Track events in a binary mask stack and relabel them by lineage ID.

    The stack is tracked with ``arcos4py.tools.track_events_image`` (DBSCAN
    clustering, merges and splits allowed, small clusters removed). The
    resulting lineage tracker is then filtered on ``lineage_duration`` and the
    filtering is reflected onto the event stack. Finally every event label is
    replaced by the lineage ID of its node in the filtered tracker.

    Args:
        bin_mask: Stack passed as the image argument of ``track_events_image``.
        min_clustersize: Forwarded to ``track_events_image``.
        eps: Forwarded to ``track_events_image``.
        downsample: Forwarded to ``track_events_image``.
        n_prev: Forwarded to ``track_events_image``.
        stability_threshold: Forwarded to ``track_events_image``.
        min_lineage_duration: ``min_value`` of the ``lineage_duration`` filter
            applied to the tracker. Defaults to 3.

    Returns:
        A tuple ``(ndarr_events_filt, obj_lin_filt, l_remap)``: the event
        stack after filtering, the filtered lineage tracker, and a list with
        one array per frame in which event labels are replaced by lineage IDs.
    """
    # Track events
    ndarr_events, obj_lin = track_events_image(
        bin_mask,
        clustering_method="dbscan",
        min_clustersize=min_clustersize,
        eps=eps,
        downsample=downsample,
        n_prev=n_prev,
        show_progress=False,
        allow_merges=True,
        allow_splits=True,
        remove_small_clusters=True,
        stability_threshold=stability_threshold,
    )

    # Drop short lineages and reflect the filtering on the tracked events
    obj_lin_filt = obj_lin.filter(criteria="lineage_duration", min_value=min_lineage_duration)
    ndarr_events_filt = obj_lin_filt.reflect(ndarr_events)

    # Build the (timepoint, event, lineage) lookup used for the remapping.
    # regionprops skips label 0, so only real events are visited.
    records = []
    for i, frame_data in enumerate(ndarr_events_filt):
        for event in regionprops(frame_data):
            records.append(
                {
                    "timepoint": i,
                    "event": event.label,
                    "lineage": obj_lin_filt.nodes[event.label].lineage_id,
                }
            )

    if records:
        # Replace event labels with lineage IDs frame by frame
        l_remap = remap_segmentation(
            pd.DataFrame(records),
            ndarr_events_filt,
            timepoint_column="timepoint",
            label_column="event",
            measure_column="lineage",
        )
    else:
        # No event survived the filtering: an empty table has no columns to
        # select, so return background-only frames directly.
        l_remap = [np.zeros_like(frame) for frame in ndarr_events_filt]

    return ndarr_events_filt, obj_lin_filt, l_remap

## Track objects in a single simulation

In [3]:
# Locate the ground-truth (GT) and prediction (PRED) folders of one simulation
core_out_dir = "../../data/1_wave_split_merge_sim/output-data"
sim_id = "sim_seed017"
data_dir, pred_dir = (os.path.join(core_out_dir, sim_id, sub) for sub in ("GT", "PRED"))

# Read the lineage, object and binary mask stacks from the GT folder
lin_mask, obj_mask, bin_mask = (
    np.load(os.path.join(data_dir, fname))
    for fname in ("lineage_masks.npy", "object_masks.npy", "binary_masks.npy")
)

In [4]:
# Open the napari image viewer and add the binary, object and lineage masks
if IMG_NAPARI:
    # Reuse the active viewer if there is one, otherwise open a new window.
    # `viewer` is bound before its first use so the cell also works when a
    # viewer is already open but the name has not been assigned yet.
    viewer = napari.current_viewer()
    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(bin_mask, name='bin_mask', blending='additive', contrast_limits=[0, 1])
    viewer.add_labels(obj_mask, name='obj_mask', blending='additive')
    viewer.add_labels(lin_mask, name='lin_mask', blending='additive')

In [12]:
# Stability threshold used for this single-simulation run
stab_thr = 16

# Track the binary masks; returns the filtered event stack, the filtered
# lineage tracker and the per-frame lineage-ID images
ndarr_events_filt, obj_lin_filt, lin_remap = track_events(
    bin_mask,
    min_clustersize=3,
    eps=3,
    downsample=1,
    n_prev=3,
    stability_threshold=stab_thr,
)

# Convert the tracker into the (data, properties, graph) triple that is
# later handed to the napari tracks layer
data, properties, graph = utils.tracker_to_napari_tracks(
    obj_lin_filt,
    label_stack=ndarr_events_filt,
)

In [13]:
# Overlay the tracking results in napari: event IDs, lineage IDs and tracks
if IMG_NAPARI:
    # Labels layer with the raw event IDs
    layer_events_id = viewer.add_labels(
        ndarr_events_filt,
        name="events_ids",
        blending="translucent",
        opacity=1,
    )

    # Labels layer with events coloured by lineage ID, drawn as contours
    layer_lin_id = viewer.add_labels(
        np.array(lin_remap),
        name="lineage_ids",
        blending="translucent",
        opacity=1,
    )
    layer_lin_id.new_colormap(seed=17)
    layer_lin_id.contour = 1

    # Tracks layer built from the tracker conversion
    viewer.add_tracks(
        data,
        properties=properties,
        graph=graph,
        name="tracks",
    )

In [None]:
# Make sure the prediction folder exists, then store the lineage-ID stack
os.makedirs(pred_dir, exist_ok=True)
np.save(
    file=os.path.join(pred_dir, "lineage_masks.npy"),
    arr=np.array(lin_remap),
)

In [208]:
# Export the current napari view as an RGB TIFF stack.
# The video folder mirrors the data folder: "output-data" -> "output-video".
core_video_dir = core_out_dir.replace("output-data", "output-video")
out_video_dir = os.path.join(core_video_dir, sim_id)
os.makedirs(out_video_dir, exist_ok=True)

if IMG_NAPARI:
    # NOTE(review): the positional arguments 0 and 4 are presumably the axis
    # to step through and an upsampling factor - confirm against
    # napari_timestamper.render_as_rgb.
    rgb_stack = render_as_rgb(viewer, 0, 4)

    # Alternative file name that encodes the stability threshold:
    #   f"{out_video_dir}/napariview_stabthr{stab_thr}.tif"
    imwrite(f"{out_video_dir}/napariview_obj.tif", rgb_stack)

## Process all simulations in the `output-data` folder

Load `binary_masks.npy`, process with ARCOS.px and save:

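- `events_masks_stabthr<N>.npy` with predicted events,
- `lineage_masks_stabthr<N>.npy` with lineage predictions,

where `<N>` is the stability threshold used for tracking (1, 2, 4, 8 and 16). The files are written to the `PRED` folder next to each `GT` folder.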

In [None]:
core_out_dir = "../../data/1_wave_split_merge_sim/output-data"

# Ground-truth folders of all simulations, processed in sorted order
dirs = sorted(glob.glob(os.path.join(core_out_dir, "sim_seed*/GT")))

# The loop variable is not called `dir`, which would shadow the builtin for
# the rest of the session.
for gt_dir in dirs:
    print(gt_dir)
    bin_mask = np.load(os.path.join(gt_dir, "binary_masks.npy"))

    # Write next to the GT folder. Building the path from the parent folder
    # only swaps the last path component; a plain str.replace("GT", "PRED")
    # would also rewrite any other "GT" occurring elsewhere in the path.
    dir_out = os.path.join(os.path.dirname(gt_dir), "PRED")
    os.makedirs(dir_out, exist_ok=True)

    # Run the tracking once per stability threshold and save both outputs
    for stab_thr in [1, 2, 4, 8, 16]:
        ndarr_events_filt, obj_lin_filt, lin_remap = track_events(
            bin_mask,
            min_clustersize=3,
            eps=3,
            downsample=1,
            n_prev=3,
            stability_threshold=stab_thr,
        )

        np.save(os.path.join(dir_out, f"events_masks_stabthr{stab_thr}.npy"), ndarr_events_filt)
        np.save(os.path.join(dir_out, f"lineage_masks_stabthr{stab_thr}.npy"), np.array(lin_remap))

../../data/1_split_merge_sim/output-data/sim_seed007/GT
../../data/1_split_merge_sim/output-data/sim_seed011/GT
../../data/1_split_merge_sim/output-data/sim_seed013/GT
../../data/1_split_merge_sim/output-data/sim_seed017/GT
../../data/1_split_merge_sim/output-data/sim_seed019/GT
../../data/1_split_merge_sim/output-data/sim_seed023/GT
../../data/1_split_merge_sim/output-data/sim_seed031/GT
../../data/1_split_merge_sim/output-data/sim_seed037/GT
../../data/1_split_merge_sim/output-data/sim_seed041/GT
../../data/1_split_merge_sim/output-data/sim_seed042/GT
../../data/1_split_merge_sim/output-data/sim_seed043/GT
../../data/1_split_merge_sim/output-data/sim_seed047/GT
../../data/1_split_merge_sim/output-data/sim_seed101/GT
../../data/1_split_merge_sim/output-data/sim_seed123/GT
../../data/1_split_merge_sim/output-data/sim_seed202/GT
../../data/1_split_merge_sim/output-data/sim_seed303/GT
../../data/1_split_merge_sim/output-data/sim_seed404/GT
../../data/1_split_merge_sim/output-data/sim_see

# Next step
Run `tcompare_sim_track.ipynb` to calculate tracking metrics