In [None]:
!pip install -r napari_interactive_example_requirements.txt

# Initialize napari

In [None]:
import napari
import numpy as np
import pandas as pd

from IPython.display import display
from skimage.io import imread
from skimage.measure import regionprops_table
from laptrack import LapTrack
from laptrack.data_conversion import convert_dataframe_to_coords, convert_tree_to_dataframe

In [None]:
# Raw image stack and the matching segmentation (label) stack.
images, labels = (
    imread("interactive_example_data/demo_image.tif"),
    np.load("interactive_example_data/labels3.npy"),
)

In [None]:
# Open a viewer showing the raw images and the segmentation as two layers;
# later cells refer to these layers by the names given here.
viewer = napari.Viewer()
viewer.add_image(images, name="images")
viewer.add_labels(labels, name="labels")

# Calculate properties of the segmentation

In [None]:
def calc_frame_regionprops(labels):  # noqa: E302
    """Measure region properties frame by frame and stack them into one table.

    Parameters
    ----------
    labels : numpy.ndarray
        Label image stack. The first axis is iterated as the frame axis and each
        ``labels[frame]`` is passed to ``regionprops_table``.

    Returns
    -------
    pandas.DataFrame
        One row per labelled region per frame, with the ``label``, ``area`` and
        ``centroid-*`` columns from ``regionprops_table`` plus a ``frame`` column.
        The index is a fresh ``RangeIndex``. Without ``ignore_index``, every frame
        would restart its index at 0 and the result would have duplicate index
        labels. With zero frames, an empty table with the same columns is returned.
    """
    frame_tables = []
    for frame in range(labels.shape[0]):
        frame_df = pd.DataFrame(
            regionprops_table(labels[frame], properties=["label", "area", "centroid"])
        )
        frame_df["frame"] = frame
        frame_tables.append(frame_df)
    if not frame_tables:
        # pd.concat([]) raises, so build the empty table explicitly.
        centroid_cols = [f"centroid-{axis}" for axis in range(labels.ndim - 1)]
        return pd.DataFrame(columns=["label", "area", *centroid_cols, "frame"])
    return pd.concat(frame_tables, ignore_index=True)
# Region properties of the raw segmentation: one row per labelled region per frame.
regionprops_df = calc_frame_regionprops(labels)
display(regionprops_df)

# Tracking by LapTrack

## Preparing the coordinates for tracking

In [None]:
# Each per-frame coordinate row is (centroid-0, centroid-1, label). Track on the
# spatial part only, but keep the label so that tracked points can be mapped
# back to their regions.
_coords = convert_dataframe_to_coords(
    regionprops_df, ["centroid-0", "centroid-1", "label"]
)
coords, coord_labels = [], []
for frame_coords in _coords:
    coords.append(frame_coords[:, :-1])
    coord_labels.append(frame_coords[:, -1])

## Tracking

In [None]:
# First tracking pass with the default metrics. Splitting is allowed up to a
# cost cutoff of 20**2, presumably a squared distance of 20 (pixel) units.
lt = LapTrack(
    splitting_cost_cutoff=20**2,
)
tree = lt.predict(coords)
# Only the track table is used; the other two returned tables are discarded.
tracked_df, _, _ = convert_tree_to_dataframe(tree)

## Adding the tracked data to the viewer

In [None]:
# Attach each region's track id to its row in the region table. Then paint a
# label image in which every region carries (track_id + 1), so 0 stays background.
_regionprops_df = regionprops_df.set_index(["frame", "label"])
for (frame, index), row in tracked_df.iterrows():
    # coord_labels come from the same numeric array as the centroids (presumably
    # float), so cast back to int to match the "label" index level exactly.
    label = int(coord_labels[frame][index])
    _regionprops_df.loc[(frame, label), "track_id"] = row["track_id"]
track_label_image = np.zeros_like(labels)
for (frame, label), row in _regionprops_df.iterrows():
    # Regions that never received a track id (or an empty tracking result with
    # no "track_id" column at all) must not write NaN into the integer image.
    track_id = row.get("track_id", np.nan)
    if pd.isna(track_id):
        continue  # leave untracked regions as background
    track_label_image[frame][labels[frame] == label] = track_id + 1

In [None]:
# Hide the raw segmentation and show the regions coloured by track instead.
viewer.layers["labels"].visible = False
# The explicit name matters: later cells look this layer up as
# viewer.layers["track_label_image"]. Without it, the name would depend on napari
# guessing it from the caller's variable name.
viewer.add_labels(track_label_image, name="track_label_image")

# Manual correction

Add points marking the cells that were validated manually (emulated here by loading pre-saved points).

In [None]:
# Pre-saved points that stand in for cells a user would validate by clicking.
manual_corrected = np.load("interactive_example_data/manual_corrected.npy")
viewer.add_points(
    manual_corrected,
    name="manually_validated_tracks",
)

In [None]:
# Read the (possibly hand-edited) validated points back from the viewer.
# napari point coordinates are floats with pixel centres at integer values, so
# round to the nearest pixel instead of truncating. np.intp is the native index
# type and, unlike int16, cannot overflow for large images.
manual_corrected = np.round(
    viewer.layers["manually_validated_tracks"].data
).astype(np.intp)
# you can also redraw the labels
new_labels = viewer.layers["track_label_image"].data
# get label values at the placed points (points are (frame, row, col))
validated_track_labels = new_labels[tuple(manual_corrected.T)]
validated_frames = manual_corrected[:, 0]

In [None]:
# Pair up consecutive validated points, ordered by frame. Each pair is stored as
# an unordered set of (frame, label) keys, which is the form the
# validation-aware distance metric compares against.
# NOTE(review): all placed points are assumed to belong to ONE track. Points from
# several tracks would be interleaved by frame and paired across tracks.
validated_points = np.array(
    [
        (int(frame), int(label))
        for frame, label in zip(validated_frames, validated_track_labels)
    ],
    dtype=int,
).reshape(-1, 2)  # keeps a (0, 2) shape when no points were placed
# A stable sort keeps the placement order of points that share a frame.
validated_points = validated_points[
    np.argsort(validated_points[:, 0], kind="stable")
]
validated_ind_pairs = [
    {(int(frame1), int(label1)), (int(frame2), int(label2))}
    for (frame1, label1), (frame2, label2) in zip(
        validated_points[:-1], validated_points[1:]
    )
]
validated_ind_pairs

# Second tracking preserving manually corrected data

In [None]:
# Re-measure regions on the (possibly redrawn) track label image. The
# set_index + reset_index pair moves "frame"/"label" to the front and replaces
# the per-frame index with a fresh RangeIndex.
new_regionprops_df = (
    calc_frame_regionprops(new_labels).set_index(["frame", "label"]).reset_index()
)
# Frame and label are kept as trailing coordinate columns so the custom distance
# metric can recognise validated (frame, label) pairs.
new_coords = convert_dataframe_to_coords(
    new_regionprops_df,
    ["centroid-0", "centroid-1", "frame", "label"],
)
coord_labels = [c[:, -1] for c in new_coords]
# Show the regions under the validated points. Hard-coded row numbers would
# break (KeyError / wrong rows) as soon as the segmentation changes.
_validated_keys = set(zip(validated_frames.tolist(), validated_track_labels.tolist()))
new_regionprops_df.loc[
    np.array(
        [
            key in _validated_keys
            for key in zip(new_regionprops_df["frame"], new_regionprops_df["label"])
        ],
        dtype=bool,
    )
]

Track again with a distance metric that gives a near-zero cost to links between consecutive validated points, so that those links are kept.

In [None]:
def manual_corrected_aware_dist(c1,c2):
    """Squared Euclidean distance that short-circuits for manually validated links.

    Each coordinate row is laid out as ``(*position, frame, label)``. If the
    unordered pair of ``(frame, label)`` keys of the two rows is listed in the
    module-level ``validated_ind_pairs``, a near-zero cost of 0.0001 is returned
    so the tracker strongly prefers that link. Otherwise the squared Euclidean
    distance between the two positions is returned.
    """
    *pos_a, frame_a, label_a = c1
    *pos_b, frame_b, label_b = c2
    key_pair = {(int(frame_a), int(label_a)), (int(frame_b), int(label_b))}
    if key_pair not in validated_ind_pairs:
        return np.linalg.norm(np.array(pos_a) - np.array(pos_b)) ** 2
    return 0.0001

# Second tracking pass: the validation-aware metric is used both for
# frame-to-frame linking and for splitting, so validated links are nearly free.
lt = LapTrack(
    splitting_cost_cutoff=20**2,
    track_dist_metric=manual_corrected_aware_dist,
    splitting_dist_metric=manual_corrected_aware_dist,
)
new_tree = lt.predict(new_coords)
# Only the track table is used; the other two returned tables are discarded.
new_tracked_df, _, _ = convert_tree_to_dataframe(new_tree)

In [None]:
# Paint every region of the redrawn label image with (tree_id + 1), so 0 stays
# background, and show the result in place of the first-pass track layer.
new_track_label_image = np.zeros_like(new_labels)
for (frame, ind), row in new_tracked_df.iterrows():
    region_mask = new_labels[frame] == coord_labels[frame][ind]
    # NOTE(review): the first pass colours regions by "track_id", but this pass
    # uses "tree_id", so daughters after a split presumably share a colour here.
    # Confirm which is intended.
    new_track_label_image[frame][region_mask] = row["tree_id"] + 1
viewer.layers["track_label_image"].visible = False
# Explicit name, consistent with the other layers, instead of relying on napari
# inferring it from the variable name.
viewer.add_labels(new_track_label_image, name="new_track_label_image")