In [3]:
from pathlib import Path
import cv2
import sys
import numpy as np
from skimage import feature, measure
import livecell_tracker
from livecell_tracker.segment import datasets
from livecell_tracker.segment.datasets import LiveCellImageDataset
import livecell_tracker.segment
from livecell_tracker import segment
import livecell_tracker.core.utils
from tqdm import tqdm
# Dataset (one stitched time-lapse well) to process. Change this single name to
# switch datasets, e.g. to "restart_day2_Group 1_wellA1_RI_MIP_stitched".
DATASET_NAME = "restart_day0_Group 1_wellA1_RI_MIP_stitched"
# Number of frames to load; passed as `num_imgs` to LiveCellImageDataset.
NUM_IMGS = 3

seg_path = r"./notebook_results/segmentation_results/detectron_model3/" + DATASET_NAME
label_seg_imgs = LiveCellImageDataset(seg_path, ext="png", num_imgs=NUM_IMGS)
dir_path = Path(r"../cxa-data/june_2022_data_8bit_png") / DATASET_NAME
raw_imgs = LiveCellImageDataset(dir_path, ext="png", num_imgs=NUM_IMGS)

3 png img file paths loaded: 
3 png img file paths loaded: 


In [9]:
def gen_SORT_detections_input_from_label_mask(label_mask):
    """
    Generate detections for the SORT tracker from a label mask.

    Parameters
    ----------
    label_mask :
        a label image, passed directly to ``skimage.measure.regionprops``
        (each labeled region becomes one detection).

    Returns
    -------
    np.ndarray
        One row per region: the 4 bbox values from
        ``livecell_tracker.core.utils.get_bbox_from_regionprops`` followed by a
        constant score of 1. The bbox is documented upstream as (x1, y1, x2, y2);
        NOTE(review): if that helper forwards skimage's ``bbox`` it is actually
        (min_row, min_col, max_row, max_col) -- confirm.
        A mask with no regions yields an empty ``(0, 5)`` array, the shape
        SORT's ``update`` expects for frames without detections (a plain
        ``np.array([])`` would be ``(0,)`` and break the tracker).
    """
    regionprops_skimage = measure.regionprops(label_mask)
    boxes = livecell_tracker.core.utils.get_bbox_from_regionprops(regionprops_skimage)
    detections = np.array([list(bbox) + [1] for bbox in boxes])
    if len(detections) == 0:
        return np.empty((0, 5))
    return detections


import livecell_tracker.track.sort_tracker
# SORT hyper-parameters, forwarded to the tracker constructor below.
MAX_AGE, MIN_HITS = 5, 1
tracker = livecell_tracker.track.sort_tracker.Sort(min_hits=MIN_HITS, max_age=MAX_AGE)


In [10]:
def track_by_label_masks(sort_tracker=None):
    """
    Run SORT over every label mask in ``label_seg_imgs``, frame by frame.

    Parameters
    ----------
    sort_tracker : optional
        object exposing ``update(detections)``; defaults to the module-level
        ``tracker``. SORT is stateful, so reusing a tracker continues its
        existing tracks.

    Returns
    -------
    list
        One entry per frame: whatever ``sort_tracker.update`` returned for it.
        (Previously the list was built and then discarded.)
    """
    if sort_tracker is None:
        sort_tracker = tracker
    all_track_bbs = []
    for img in tqdm(label_seg_imgs):
        detections = gen_SORT_detections_input_from_label_mask(img)
        all_track_bbs.append(sort_tracker.update(detections))
    return all_track_bbs

def track_by_contour_list(contours_per_frame=None, sort_tracker=None):
    """
    Run SORT frame by frame and return the per-frame tracker outputs.

    Parameters
    ----------
    contours_per_frame : optional
        iterable with one list of contours per frame; each frame is converted
        with ``gen_SORT_detections_input_from_contours`` (defined elsewhere in
        this notebook -- its cell must have been run first). When omitted, the
        original behavior is kept: detections come from the label masks in
        ``label_seg_imgs``, identical to ``track_by_label_masks``.
    sort_tracker : optional
        object exposing ``update(detections)``; defaults to the module-level
        ``tracker``. SORT is stateful, so reusing a tracker continues its tracks.

    Returns
    -------
    list
        One entry per frame: whatever ``sort_tracker.update`` returned for it.
    """
    if sort_tracker is None:
        sort_tracker = tracker
    all_track_bbs = []
    if contours_per_frame is None:
        for img in tqdm(label_seg_imgs):
            detections = gen_SORT_detections_input_from_label_mask(img)
            all_track_bbs.append(sort_tracker.update(detections))
    else:
        for contours in tqdm(contours_per_frame):
            detections = gen_SORT_detections_input_from_contours(contours)
            all_track_bbs.append(sort_tracker.update(detections))
    return all_track_bbs

In [17]:
import json
def get_bbox_from_contour(contour: list):
    """Axis-aligned bounding box of a contour.

    The contour is turned into an array of points; the result is
    ``[min(col 0), min(col 1), max(col 0), max(col 1)]``.
    """
    points = np.array(contour)
    first_coords = points[:, 0]
    second_coords = points[:, 1]
    return np.array(
        [first_coords.min(), second_coords.min(), first_coords.max(), second_coords.max()]
    )


def gen_SORT_detections_input_from_contours(contours):
    """
    Generate detections for the SORT tracker from a list of contours.

    Parameters
    ----------
    contours :
        iterable of contours; each contour is a sequence of 2-D points that
        ``get_bbox_from_contour`` reduces to its bounding box.

    Returns
    -------
    np.ndarray
        One row per contour: ``[min(col 0), min(col 1), max(col 0), max(col 1), 1]``
        (bbox plus a constant score of 1). With no contours an empty ``(0, 5)``
        array is returned, the shape SORT's ``update`` expects for frames without
        detections (a plain ``np.array([])`` would be ``(0,)``).
    """
    detections = np.array([list(get_bbox_from_contour(contour)) + [1] for contour in contours])
    if len(detections) == 0:
        return np.empty((0, 5))
    return detections

import os

segmentation_result_json_path = r"./notebook_results/seg_test_tmp_day0/segmentation_results.json"
# Context manager so the file handle is closed once the JSON is parsed.
with open(segmentation_result_json_path, "r") as seg_json_f:
    segmentation_results = json.load(seg_json_f)

# Keys are the image paths recorded at segmentation time. Separators may not match
# what `get_img_path` returns (the old TODO mentions doubled slashes on Windows),
# so also index results by a normalized path and fall back to it. The previous
# `img_path.replace("//", "////")` discarded its result and was a no-op.
normalized_segmentation_results = {
    os.path.normpath(key): value for key, value in segmentation_results.items()
}

# SORT is stateful: start from a fresh tracker so re-running this cell does not
# keep updating tracks left over from a previous run.
tracker = livecell_tracker.track.sort_tracker.Sort(max_age=MAX_AGE, min_hits=MIN_HITS)

all_track_bbs = []
for idx, _ in enumerate(raw_imgs):
    img_path = raw_imgs.get_img_path(idx)
    print("matching image path:", img_path)
    if img_path in segmentation_results:
        frame_result = segmentation_results[img_path]
    else:
        frame_result = normalized_segmentation_results[os.path.normpath(img_path)]
    contours = frame_result["contours"]
    detections = gen_SORT_detections_input_from_contours(contours)
    track_bbs_ids = tracker.update(detections)
    all_track_bbs.append(track_bbs_ids)

matching image path: ../cxa-data/june_2022_data_8bit_png/restart_day0_Group 1_wellA1_RI_MIP_stitched/T001.png
matching image path: ../cxa-data/june_2022_data_8bit_png/restart_day0_Group 1_wellA1_RI_MIP_stitched/T002.png
matching image path: ../cxa-data/june_2022_data_8bit_png/restart_day0_Group 1_wellA1_RI_MIP_stitched/T003.png


Convert the track bounding boxes to plain `int` lists so they can be serialized to JSON.

In [18]:
# Cast each frame's tracker output to int32 and convert it to plain lists for JSON.
# np.asarray makes the cell idempotent: re-running it on already-converted lists
# still works. Note that astype truncates toward zero; it does not round.
all_track_bbs = [np.asarray(track_bbs).astype(np.int32).tolist() for track_bbs in all_track_bbs]
all_track_bbs[0]

[[0, 1399, 67, 1577, 29],
 [396, 288, 456, 355, 27],
 [1385, 1308, 1443, 1363, 26],
 [777, 2162, 851, 2235, 25],
 [1394, 1276, 1449, 1331, 24],
 [1826, 1586, 1908, 1618, 23],
 [137, 1334, 225, 1391, 22],
 [284, 269, 341, 357, 21],
 [59, 1265, 148, 1328, 20],
 [90, 1416, 164, 1473, 19],
 [1847, 1366, 1906, 1419, 18],
 [16, 1249, 302, 1402, 17],
 [268, 463, 353, 531, 16],
 [6, 1232, 212, 1351, 15],
 [757, 2236, 989, 2376, 14],
 [1013, 2321, 1102, 2383, 13],
 [103, 1295, 321, 1405, 12],
 [1724, 1554, 2005, 1690, 11],
 [991, 2282, 1138, 2400, 10],
 [739, 2116, 873, 2257, 9],
 [1780, 1493, 2033, 1576, 7],
 [72, 423, 605, 693, 6],
 [357, 186, 575, 410, 5],
 [1787, 1255, 2014, 1478, 4],
 [45, 1394, 282, 1541, 3],
 [1295, 1159, 1571, 1479, 2],
 [151, 189, 363, 446, 1]]

### Save track bbox results to JSON for later development

In [None]:
import json

dest_track_bbs_path = "detectron_model3_all_track_bbs-restart_day0_Group 1_wellA1_RI_MIP_stitched_by_contours.json"
# The boxes are plain int lists at this point, so they serialize directly.
with open(dest_track_bbs_path, mode="w+") as out_f:
    out_f.write(json.dumps(all_track_bbs))


## Visualize Track results

In [None]:
from pathlib import Path
import cv2
import sys
import numpy as np
from skimage import feature, measure
import livecell_tracker
from livecell_tracker.segment import datasets
import livecell_tracker.segment
from livecell_tracker import segment
import livecell_tracker.core.utils
from tqdm import tqdm
# Same segmentation directory used in the tracking section. It was previously
# spelled "notebooks_results", which did not match the "notebook_results" path
# that loaded successfully above.
seg_path = r"""./notebook_results/segmentation_results/detectron_model3/restart_day0_Group 1_wellA1_RI_MIP_stitched"""
label_seg_imgs = segment.datasets.LiveCellImageDataset(seg_path, ext="png")
dir_path = Path(
    r"../cxa-data/june_2022_data_8bit_png/restart_day0_Group 1_wellA1_RI_MIP_stitched"
)
raw_imgs = segment.datasets.LiveCellImageDataset(dir_path, ext="png")

### Load saved track bboxes from JSON

In [None]:
import json
# Tracker settings of the saved run to load; they also appear in the
# length-histogram title further down.
MAX_AGE, MIN_HITS = 1, 1
_track_bbs_path = "../cxa-data/test_data/sort_track/max_age-{}_min_hit-{}_detectron_model3_all_track_bbs-restart_day0_Group 1_wellA1_RI_MIP_stitched.json".format(MAX_AGE, MIN_HITS)
print("loading: ", _track_bbs_path)
json_in = json.loads(Path(_track_bbs_path).read_text())
# Sanity check: Python type of the first value of the first box in frame 0.
type(json_in[0][0][0])

In [None]:
# One integer array per frame; later cells use the first four columns as the bbox and the last as the track id.
all_track_bbs = [np.asarray(frame_bbs, dtype=int) for frame_bbs in json_in]

In [None]:
np.shape(all_track_bbs[0])

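### Manually check bboxes
 - Be careful with coordinate conventions: numpy/skimage index images as (row, col), whereas cv2 points and matplotlib patches use (x, y) = (col, row). Check which convention each bbox source follows before drawing.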

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Frame to inspect.
idx = 2
bboxes = all_track_bbs[idx]
fig, ax = plt.subplots()
ax.imshow(raw_imgs[idx])
for bbox in bboxes:
    # Drawn as (row_min, col_min, row_max, col_max): columns map to x, rows to y.
    anchor_xy = (bbox[1], bbox[0])
    box_width = bbox[3] - bbox[1]
    box_height = bbox[2] - bbox[0]
    ax.add_patch(
        patches.Rectangle(anchor_xy, box_width, box_height, linewidth=1, edgecolor='r', facecolor='none')
    )


In [None]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic, single_cell

# Group per-frame tracker rows into one SingleCellTrajectory per track id.
id_to_sc_trajs = {}
for timeframe, frame_objects in enumerate(all_track_bbs):
    for obj in frame_objects:
        # Last column is the track id; the first four columns are the bbox.
        track_id = obj[-1]
        if track_id not in id_to_sc_trajs:
            id_to_sc_trajs[track_id] = SingleCellTrajectory(raw_imgs, track_id=track_id)
        sc = SingleCellStatic(timeframe, bbox=obj[:4], img_dataset=raw_imgs)
        id_to_sc_trajs[track_id].add_timeframe_data(timeframe, sc)


### Trajectory length distribution

In [None]:
%matplotlib inline
import seaborn as sns
traj_lengths = [traj.get_timeframe_span_length() for traj in id_to_sc_trajs.values()]
all_traj_lengths = np.array(traj_lengths)
sns.histplot(all_traj_lengths, bins=100)
plt.title("max_age={}, min_hits={}".format(MAX_AGE, MIN_HITS))
plt.xlabel("Traj length")
plt.ylabel("Count")
plt.show()

In [None]:
tuple((all_traj_lengths > min_length).sum() for min_length in (10, 30, 50))