In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
from livecell_tracker import segment
from livecell_tracker import core
from livecell_tracker.core import datasets
from livecell_tracker.core.datasets import LiveCellImageDataset
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic
import detectron2
from detectron2.utils.logger import setup_logger

setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import cv2

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from livecell_tracker.segment.detectron_utils import gen_cfg

from livecell_tracker.segment.detectron_utils import (
    segment_detectron_wrapper,
    segment_images_by_detectron,
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
)
from livecell_tracker.segment.detectron_utils import (
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
    segment_images_by_detectron,
    segment_single_img_by_detectron_wrapper,
)


# Small bundled test set: DIC frames (tif) and their masks (png).
dataset_dir_path = Path("../datasets/test_data_STAV-A549/DIC_data")
mask_dataset_path = Path("../datasets/test_data_STAV-A549/mask_data")

# Masks are loaded first, then the raw images.
mask_dataset = LiveCellImageDataset(mask_dataset_path, ext="png")
dic_dataset = LiveCellImageDataset(dataset_dir_path, ext="tif")

In [None]:
# EBSS starvation experiment, position XY16: raw frames and the segmentation output folder.
dataset_dir_path = Path("../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/XY16/")
mask_dataset_path = Path(
    "../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/out/XY16/seg"
)

mask_dataset = LiveCellImageDataset(mask_dataset_path, ext="png")

# Only the *_DIC.tif files are used; their sorted order defines the time index 0, 1, 2, ...
time2url = dict(enumerate(sorted(glob.glob(str(dataset_dir_path / "*_DIC.tif")))))
dic_dataset = LiveCellImageDataset(time2url=time2url, ext="tif")

Convert label masks to single-cell objects

In [None]:
from livecell_tracker.core.io_sc import process_scs_from_label_mask, prep_scs_from_mask_dataset

# Build the list of single-cell objects from the mask dataset, paired with the DIC images.
# presumably one object per labelled region per frame — confirm against prep_scs_from_mask_dataset
single_cells = prep_scs_from_mask_dataset(mask_dataset, dic_dataset)

In [None]:
# for testing
# single_cells = single_cells[:10]

In [None]:
# Sanity check: how many single-cell objects were extracted.
len(single_cells)

Check the cells visually

In [None]:
# for time in overlap_map_by_time:
#     overlap_map = overlap_map_by_time[time]
#     for sc_tmp1, sc_tmp2 in overlap_map:
#         if sc_tmp1 == sc_tmp2:
#             continue
#         if overlap_map[(sc_tmp1, sc_tmp2)][0] > 0:
#             print(sc_tmp1.timeframe, sc_tmp2.timeframe, overlap_map[(sc_tmp1, sc_tmp2)])
#             fig, axes = plt.subplots(1, 6, figsize=(15, 5))
#             padding=50
#             sc_tmp1.show_contour_mask(crop=False, ax = axes[0])
#             sc_tmp2.show_contour_mask(crop=False, ax = axes[1])
#             sc_tmp1.show(crop=True, ax = axes[2], padding=padding)
#             sc_tmp2.show(crop=True, ax = axes[3], padding=padding)
#             sc_tmp1.show_contour_mask(crop=True, ax = axes[4], padding=padding)
#             sc_tmp2.show_contour_mask(crop=True, ax = axes[5], padding=padding)
#             plt.show()


```
sc.datasets["img"]
sc.datasets["mask"]
sc.datasets["label"]
sc.datasets["TRITC"]
```

In [None]:
from livecell_tracker.segment.utils import match_mask_labels_by_iou

## Track cells

In [None]:
from typing import List
from livecell_tracker.track.sort_tracker_utils import (
    gen_SORT_detections_input_from_contours,
    update_traj_collection_by_SORT_tracker_detection,
    track_SORT_bbox_from_contours,
    track_SORT_bbox_from_scs
)
# Link the per-frame single cells into trajectories with the SORT bounding-box tracker.
# NOTE(review): max_age / min_hits are presumably the usual SORT settings (frames a track may
# go unmatched / detections needed before a track is reported) — confirm in sort_tracker_utils.
traj_collection = track_SORT_bbox_from_scs(
    single_cells,
    dic_dataset,
    mask_dataset=mask_dataset,
    max_age=1,
    min_hits=1,
)

Within each trajectory, check whether any multiple-mapping issue occurs within a time interval (sliding window).

In [None]:
from livecell_tracker.segment.utils import filter_labels_match_map
def filter_labels_match_map(gt2seg_iou__map, iou_threshold):
    """Keep only the label matches whose IoU is strictly above ``iou_threshold``.

    Note: this definition shadows the ``filter_labels_match_map`` imported from
    ``livecell_tracker.segment.utils`` on the line above it.

    Args:
        gt2seg_iou__map: mapping ``label -> iterable of score dicts``; each score dict is read
            through its ``"seg_label"`` and ``"iou"`` keys.
        iou_threshold: matches with ``iou <= iou_threshold`` are dropped.

    Returns:
        dict ``label -> {seg_label: {"iou": iou}}``. Every input label is kept as a key, even
        when none of its matches passes the threshold (its value is then an empty dict).
    """
    return {
        gt_label: {
            entry["seg_label"]: {"iou": entry["iou"]}
            for entry in gt2seg_iou__map[gt_label]
            if entry["iou"] > iou_threshold
        }
        for gt_label in gt2seg_iou__map
    }

In [None]:
from livecell_tracker.model_zoo.segmentation.sc_correction import CorrectSegNet
# ckpt = r"/home/ken67/LiveCellTracker-dev/notebooks/lightning_logs/version_real-02/checkpoints/epoch=3720-test_loss=0.0085.ckpt"
# ckpt = r"/home/ken67/LiveCellTracker-dev/notebooks/lightning_logs/version_802/checkpoints/epoch=2570-test_out_matched_num_gt_iou_0.5_percent_real_underseg_cases=0.8548.ckpt"
# ckpt = r"/home/ken67/LiveCellTracker-dev/notebooks/lightning_logs/version_v10_02/checkpoints/epoch=2999-global_step=0.ckpt"
# Checkpoint of the correction network currently in use (alternatives are commented out above).
ckpt = r"/home/ken67/LiveCellTracker-dev/notebooks/lightning_logs/version_v10_02/checkpoints/epoch=5999-global_step=0.ckpt"

# Restore the weights, move the model to the GPU and switch it to inference mode.
model = CorrectSegNet.load_from_checkpoint(ckpt).cuda().eval()

Pseudocode  

Start at each timepoint  
    from t to t + window_size  
    count if segmentation at time t conforms with the majority of the segmentation results  
    try using correction CNN to correct if not
    

In [None]:
from livecell_tracker.segment.ou_viz import viz_ou_sc_outputs
from livecell_tracker.core.parallel import parallelize
from torchvision import transforms

# Example trajectory; the name `trajectory` is rebound by the loop that builds `inputs` below.
trajectory = traj_collection.get_trajectory(5)

# Resize applied to inputs handed to viz_ou_sc_outputs.
input_transforms = transforms.Compose([transforms.Resize(size=(412, 412))])

# Folder receiving the figures saved by the visualization cell.
fig_out_dir = Path("./tmp_csn_temporal_correct")
fig_out_dir.mkdir(parents=True, exist_ok=True)

# Options forwarded to viz_ou_sc_outputs (`one_object` is not referenced in the cells shown here).
padding_pixels = 50
one_object = True
out_threshold = 4
remove_bg = False

def consensus_trajectory(trajectory: SingleCellTrajectory, sliding_window=3, iou_threshold=0.3):
    """Flag time points whose segmentation disagrees with the following frames of the same track.

    Every cell of the trajectory is used in turn as the pivot. The pivot mask is compared, inside
    the pivot's bounding box, with the masks of up to ``sliding_window`` subsequent cells, found
    by repeatedly calling ``trajectory.next_time``. A comparison agrees when, after dropping
    matches whose IoU is not above ``iou_threshold``, exactly one matched label remains.

    Args:
        trajectory: the track to check; iterated as ``(time, single_cell)`` pairs.
        sliding_window: maximum number of subsequent time points compared with each pivot.
        iou_threshold: matches with an IoU at or below this value are ignored.

    Returns:
        ``(failed, conflicts)`` where ``failed`` holds ``(track_id, time)`` for each pivot that
        did not reach the vote threshold, and ``conflicts`` holds
        ``(track_id, time, later_time)`` for every single disagreeing comparison.

    Raises:
        AssertionError: if a comparison does not produce exactly one outer label.
    """
    track_id = trajectory.track_id
    failed = []
    conflicts = []

    for pivot_time, pivot_sc in trajectory:
        pivot_bbox = pivot_sc.get_bbox()
        pivot_mask = pivot_sc.get_sc_mask()  # a mask containing one label
        votes = []
        later_time = pivot_time
        for _ in range(sliding_window):
            # Follow the trajectory's own time points; None means nothing comes after.
            later_time = trajectory.next_time(later_time)
            if later_time is None:
                break
            later_mask = trajectory[later_time].get_mask_crop(bbox=pivot_bbox, dtype=int)
            _, iou_map = match_mask_labels_by_iou(later_mask, pivot_mask, return_all=True)
            matches = filter_labels_match_map(iou_map, iou_threshold=iou_threshold)
            assert len(matches) == 1, "only one label should be matched"
            sole_label = list(matches)[0]
            agrees = len(matches[sole_label]) == 1
            votes.append(agrees)
            if not agrees:
                conflicts.append((track_id, pivot_time, later_time))

        # Pivots with no later time point cannot be voted on.
        if not votes:
            continue
        # The "- 1" is the original author's allowance for counting the pivot itself.
        if sum(votes) <= len(votes) / 2 - 1:
            failed.append((track_id, pivot_time))

    return failed, conflicts

# Build one (trajectory, sliding_window, iou_threshold) argument tuple per trajectory and hand
# them to parallelize together with consensus_trajectory.
# NOTE: `sliding_window` is read again by the visualization cell further down, and this loop
# leaves `track_id` / `trajectory` bound to the last item yielded by traj_collection.
inputs = []
iou_threshold = 0.3
sliding_window = 10
for track_id, trajectory in traj_collection:
    inputs.append((trajectory, sliding_window, iou_threshold))

# presumably one (failed, conflicts) result per input tuple — confirm against parallelize
results = parallelize(consensus_trajectory, inputs)

In [None]:
# Merge the per-trajectory outputs into two flat lists.
# result[0]: (track_id, time) pivots that failed the vote.
failed_consensus_track_times = [entry for result in results for entry in result[0]]
# result[1]: (track_id, time, later_time) individual disagreements.
conflict_track_id_and_time_pairs = [entry for result in results for entry in result[1]]

In [None]:
# Overview: size of the trajectory collection, failed pivots, individual conflicts.
len(traj_collection), len(failed_consensus_track_times), len(conflict_track_id_and_time_pairs)

In [None]:
# Sum of len(trajectory) over the whole collection, as a denominator for the counts above
# (presumably one entry per tracked time point — confirm against SingleCellTrajectory.__len__).
# NOTE: the loop rebinds the notebook-level name `trajectory` to the last trajectory yielded.
total_track_time_pairs = 0
for _, trajectory in traj_collection:
    total_track_time_pairs += len(trajectory)
total_track_time_pairs

In [None]:
# Peek at the first few (track_id, time) pivots that failed the consensus vote.
failed_consensus_track_times[:5]

In [None]:
# Peek at the first (track_id, time, later_time) conflict records.
conflict_track_id_and_time_pairs[:2]

In [None]:
# NOTE(review): `trajectory` is whatever an earlier loop left bound (the last item of
# traj_collection when cells run top to bottom), not a trajectory selected in this cell.
trajectory.timeframe_set

In [None]:
# Look up the trajectory of the first recorded conflict and show its time points.
# The track id must be bound here: the bare expression previously on this line discarded its
# value, so the lookup silently used the `track_id` left over from an earlier loop.
track_id = conflict_track_id_and_time_pairs[0][0]
trajectory = traj_collection.get_trajectory(track_id)
trajectory.timeframe_set

In [None]:
# Visualize one recorded conflict: the pivot cell first, then the later cells of its window.
conflict_track_idx = 120
# Each record is (track_id, pivot time, conflicting later time).
track_id, time, cur_time = conflict_track_id_and_time_pairs[conflict_track_idx]

trajectory = traj_collection.get_trajectory(track_id)

cur_sc = trajectory[time]
next_sc = trajectory[cur_time]
cur_bbox = cur_sc.get_bbox()
cur_label_mask = cur_sc.get_sc_mask()
fig_out_dir = Path("./tmp_csn_temporal_correct")
os.makedirs(fig_out_dir, exist_ok=True)

print(">" * 80)
print("track id:", track_id)
print("time, next_time:", time, cur_time)
print("current sc:")
# viz cells
cur_sc.show_panel(padding=50)
# plt.savefig(fig_out_dir / f"{track_id}_{time}_{next_time}_first.png")
viz_ou_sc_outputs(
    cur_sc,
    model,
    transforms=input_transforms,
    save_path=fig_out_dir / f"{track_id}_{time}_{cur_time}_first_csn.png",
    show=True,
    remove_bg=remove_bg,
    padding_pixels=padding_pixels,
    out_threshold=out_threshold,
)

# Walk the window with trajectory.next_time, the same way consensus_trajectory does, so the
# frames shown are exactly the frames that were compared. Stepping by `time + i + 1` would skip
# or miss frames whenever the trajectory has gaps in its time points.
cur_time = time
for i in range(sliding_window):
    cur_time = trajectory.next_time(cur_time)
    if cur_time is None:
        print("no further time point in trajectory")
        break
    next_sc = trajectory[cur_time]
    next_label_mask = next_sc.get_mask_crop(bbox=cur_bbox, dtype=int)

    print("sc at time:", cur_time)
    next_sc.show_panel(padding=50)
    # plt.savefig(fig_out_dir / f"{track_id}_{time}_{next_time}_second.png")

    viz_ou_sc_outputs(
        next_sc,
        model,
        transforms=input_transforms,
        save_path=fig_out_dir / f"{track_id}_{time}_{cur_time}_second_csn.png",
        show=True,
        remove_bg=remove_bg,
        padding_pixels=padding_pixels,
        out_threshold=out_threshold,
    )