In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
from livecell_tracker import segment
from livecell_tracker import core
from livecell_tracker.core import datasets
from livecell_tracker.core.datasets import LiveCellImageDataset
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic
import detectron2
from detectron2.utils.logger import setup_logger

setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import cv2

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from livecell_tracker.segment.detectron_utils import gen_cfg

from livecell_tracker.segment.detectron_utils import (
    segment_detectron_wrapper,
    segment_images_by_detectron,
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
)
from livecell_tracker.segment.detectron_utils import (
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
    segment_images_by_detectron,
    segment_single_img_by_detectron_wrapper,
)


# Small bundled test set: DIC frames (.tif) plus their label masks (.png).
# NOTE: the next cell rebinds all four of these names to the EBSS XY16 data.
dataset_dir_path = Path("../datasets/test_data_STAV-A549/DIC_data")
mask_dataset_path = Path("../datasets/test_data_STAV-A549/mask_data")

mask_dataset = LiveCellImageDataset(mask_dataset_path, ext="png")
dic_dataset = LiveCellImageDataset(dataset_dir_path, ext="tif")

In [None]:
# EBSS starvation run, field of view XY16: DIC frames and their segmentation masks.
dataset_dir_path = Path(
    "../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/XY16/"
)
mask_dataset_path = Path(
    "../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/out/XY16/seg"
)
mask_dataset = LiveCellImageDataset(mask_dataset_path, ext="png")

# Frame index -> DIC image path; the sorted filename order defines the time axis.
time2url = dict(enumerate(sorted(glob.glob(str(dataset_dir_path / "*_DIC.tif")))))
dic_dataset = LiveCellImageDataset(time2url=time2url, ext="tif")

Convert the label masks into single-cell objects

In [None]:
from livecell_tracker.core.io_sc import process_scs_from_label_mask, prep_scs_from_mask_dataset

# Turn the labelled objects in mask_dataset into single-cell objects paired with dic_dataset.
# Downstream cells rely on each object exposing .timeframe and .meta["label_in_mask"].
single_cells = prep_scs_from_mask_dataset(mask_dataset, dic_dataset)

In [None]:
# for testing
# single_cells = single_cells[:10]

In [None]:
# Total number of single cells across all frames.
len(single_cells)

In [None]:
# Group cells by frame index: {timeframe: [cells in that frame, in input order]}.
single_cells_by_time = {}
for cell in single_cells:
    single_cells_by_time.setdefault(cell.timeframe, []).append(cell)

In [None]:
# Number of cells found in each frame.
for time, cells_at_time in single_cells_by_time.items():
    print(time, len(cells_at_time))

Visualize a single cell

In [None]:
# First cell, drawn with show / show_mask / show_contour_img / show_contour_mask side by side.
sc = single_cells[0]
fig, axes = plt.subplots(1, 4, figsize=(10, 5))
for ax, show_fn in zip(axes, [sc.show, sc.show_mask, sc.show_contour_img, sc.show_contour_mask]):
    show_fn(ax=ax)

Calculate the overlap between two single cells

In [None]:
# Mask of the same cell, shown with padding=200 for surrounding context.
sc.show_mask(padding=200)

In [None]:
# Contour mask of the same cell, shown with padding=200.
sc.show_contour_mask(padding=200)

In [None]:
# Two cells from frame 0.
sc1, sc2 = single_cells_by_time[0][0], single_cells_by_time[0][1]

def compute_overlap_bf(sc1, sc2):
    """Brute-force pixel overlap of two single cells (reference implementation).

    Both contour masks are requested with a padding equal to the largest
    dimension of ``sc1``'s image (presumably so both masks span the same
    full-frame coordinates -- TODO confirm against ``get_contour_mask``) and
    are compared pixel by pixel as boolean arrays.

    Returns
    -------
    (overlap_area, iou)
        Count of pixels in both masks, and that count divided by the count of
        pixels in either mask.
    """
    pad = np.max(sc1.get_img().shape)
    # TODO: add a helper function in single_cell to return a mask with only the current cell in it.
    first, second = (cell.get_contour_mask(padding=pad).astype(bool) for cell in (sc1, sc2))
    intersection = (first & second).sum()
    union = np.logical_or(first, second).sum()
    return intersection, intersection / union

def bbox_overlap(bbox1, bbox2):
    """Intersection area and IoU of two axis-aligned bounding boxes.

    Each box is a 4-tuple ``(min_a, min_b, max_a, max_b)`` over two axes; the
    extent along an axis is ``max - min``. ``compute_overlap`` slices masks as
    ``[bbox[0]:bbox[2], bbox[1]:bbox[3]]``, i.e. axis a = rows, axis b = columns.

    Returns
    -------
    (overlap_area, iou)
        ``iou`` is 0.0 when the union area is not positive (e.g. both boxes are
        degenerate) instead of raising ZeroDivisionError.
    """
    a_min1, b_min1, a_max1, b_max1 = bbox1
    a_min2, b_min2, a_max2, b_max2 = bbox2
    overlap_a = max(0, min(a_max1, a_max2) - max(a_min1, a_min2))
    overlap_b = max(0, min(b_max1, b_max2) - max(b_min1, b_min2))
    overlap_area = overlap_a * overlap_b
    bbox1_area = (a_max1 - a_min1) * (b_max1 - b_min1)
    bbox2_area = (a_max2 - a_min2) * (b_max2 - b_min2)
    union_area = bbox1_area + bbox2_area - overlap_area
    if union_area <= 0:
        return overlap_area, 0.0
    return overlap_area, overlap_area / union_area

def compute_overlap(sc1: SingleCellStatic, sc2: SingleCellStatic):
    """Pixel overlap area and IoU of two single cells, restricted to their bboxes.

    If the two bounding boxes do not overlap, ``(0, 0)`` is returned without
    reading any mask. Otherwise both full-frame contour masks are cropped to
    the box enclosing both bboxes and compared there. This matches
    ``compute_overlap_bf`` as long as each cell's mask lies inside its own bbox.

    Returns
    -------
    (overlap_area, iou)
        ``(0, 0)`` when the bboxes are disjoint or the cropped masks are empty.
    """
    bbox1, bbox2 = sc1.get_bbox(), sc2.get_bbox()
    _, bbox_iou = bbox_overlap(bbox1, bbox2)
    if bbox_iou <= 0:
        return 0, 0

    # Enclosing box; index 0/2 slice rows and 1/3 slice columns.
    row_start, col_start = min(bbox1[0], bbox2[0]), min(bbox1[1], bbox2[1])
    row_stop, col_stop = max(bbox1[2], bbox2[2]), max(bbox1[3], bbox2[3])
    # TODO: add a helper function in single_cell to return a mask with only the current cell in it.
    # Cast to bool (as compute_overlap_bf does) so `|`/sum count pixels even when
    # masks are stored as 0/255 or as label values rather than 0/1.
    mask1 = sc1.get_contour_mask(crop=False)[row_start:row_stop, col_start:col_stop].astype(bool)
    mask2 = sc2.get_contour_mask(crop=False)[row_start:row_stop, col_start:col_stop].astype(bool)
    overlap_area = np.logical_and(mask1, mask2).sum()
    union_area = np.logical_or(mask1, mask2).sum()
    if union_area == 0:
        return 0, 0
    return overlap_area, overlap_area / union_area
# Leftover frame pair; both names are rebound by later cells before being read.
t1 = 0
t2 = t1 + 1

In [None]:
def test_compute_overlap():
    """Cross-check ``compute_overlap`` against ``compute_overlap_bf``.

    Compares both (area, iou) results for every ordered pair of cells in the
    global ``single_cells``; quadratic in the number of cells, so slow.
    """
    for cell_a in single_cells:
        for cell_b in single_cells:
            fast = compute_overlap(cell_a, cell_b)
            slow = compute_overlap_bf(cell_a, cell_b)
            assert fast[0] == slow[0]
            assert fast[1] == slow[1]
# test_compute_overlap()

In [None]:
# Same binding as the top-level `from tqdm import tqdm`; `import tqdm` here would
# replace that function with the module for every later cell.
from tqdm import tqdm


def compute_overlaps(sc_list1, sc_list2):
    """Overlap area and IoU for every (cell in sc_list1, cell in sc_list2) pair.

    Returns
    -------
    dict
        ``{(sc1, sc2): (overlap_area, iou)}`` as produced by ``compute_overlap``.
    """
    overlap_map = {}
    for sc1 in tqdm(sc_list1, desc="Computing overlaps"):
        for sc2 in sc_list2:
            overlap_map[(sc1, sc2)] = compute_overlap(sc1, sc2)
    return overlap_map


# Overlaps between each frame and the frame right after it, visiting frames in
# ascending order (wrapping the sorted keys in a set would not keep that order).
overlap_map_by_time = {}
times = sorted(single_cells_by_time)
for time in times:
    if time + 1 not in single_cells_by_time:
        print(f"Time {time} is the last time point, skipping")
        continue
    overlap_map_by_time[time] = compute_overlaps(single_cells_by_time[time], single_cells_by_time[time + 1])

In [None]:
# (overlap_area, iou) between the first cell of frame 0 and the first cell of frame 1.
sc1 = single_cells_by_time[0][0]
sc2 = single_cells_by_time[1][0]
overlap_map_by_time[0][(sc1, sc2)]

In [None]:
# Number of cells in frame 2 (cells are picked from this list by index further below).
len(single_cells_by_time[2])

Check the cells visually

In [None]:
# for time in overlap_map_by_time:
#     overlap_map = overlap_map_by_time[time]
#     for sc_tmp1, sc_tmp2 in overlap_map:
#         if sc_tmp1 == sc_tmp2:
#             continue
#         if overlap_map[(sc_tmp1, sc_tmp2)][0] > 0:
#             print(sc_tmp1.timeframe, sc_tmp2.timeframe, overlap_map[(sc_tmp1, sc_tmp2)])
#             fig, axes = plt.subplots(1, 6, figsize=(15, 5))
#             padding=50
#             sc_tmp1.show_contour_mask(crop=False, ax = axes[0])
#             sc_tmp2.show_contour_mask(crop=False, ax = axes[1])
#             sc_tmp1.show(crop=True, ax = axes[2], padding=padding)
#             sc_tmp2.show(crop=True, ax = axes[3], padding=padding)
#             sc_tmp1.show_contour_mask(crop=True, ax = axes[4], padding=padding)
#             sc_tmp2.show_contour_mask(crop=True, ax = axes[5], padding=padding)
#             plt.show()


In [None]:
# Distinct label values in the frame-0 mask of sc1's mask dataset.
np.unique(sc1.mask_dataset.get_img_by_time(0))

In [None]:
# Distinct label values in the frame-1 mask of sc1's mask dataset.
np.unique(sc1.mask_dataset.get_img_by_time(1))

```
sc.datasets["img"]
sc.datasets["mask"]
sc.datasets["label"]
sc.datasets["TRITC"]
```

In [None]:
from livecell_tracker.segment.utils import match_mask_labels_by_iou
def match_mask_labels_by_iou(seg_label_mask, gt_label_mask, bg_label=0, return_all=False):
    """Match every ground-truth label to the segmentation label with the highest IoU.

    NOTE: this local definition shadows the ``match_mask_labels_by_iou``
    imported from ``livecell_tracker.segment.utils`` in this cell.

    Parameters
    ----------
    seg_label_mask : np.ndarray
        Label mask from segmentation; each object carries its own label value.
    gt_label_mask : np.ndarray
        Reference label mask, same shape as ``seg_label_mask``.
    bg_label : int, optional
        Background label, skipped in both masks, by default 0.
    return_all : bool, optional
        If True, also return every gt/seg comparison, by default False.

    Returns
    -------
    gt2seg_map : dict
        ``{gt_label: {"best_iou": iou, "seg_label": seg_label}}``. The inner dict
        is empty when no segmentation object overlaps the ground-truth object.
    all_gt2seg_iou__map : dict
        Only when ``return_all`` is True: ``{gt_label: [{"seg_label", "iou",
        "io_gt", "io_seg"}, ...]}`` with one entry per non-background seg label
        in ascending label order. ``io_gt`` / ``io_seg`` are the intersection
        area divided by the gt / seg object area.
    """
    gt2seg_map = {}
    all_gt2seg_iou__map = {}

    # Area of every segmentation label, computed once rather than per gt label.
    seg_labels, seg_areas = np.unique(seg_label_mask, return_counts=True)
    is_fg = seg_labels != bg_label
    seg_labels, seg_areas = seg_labels[is_fg], seg_areas[is_fg]

    for gt_label in np.unique(gt_label_mask):
        if gt_label == bg_label:
            continue
        # Select the object by equality. Zeroing the other labels and then
        # binarizing would make label 0 vanish whenever bg_label != 0.
        gt_binary = gt_label_mask == gt_label
        gt_area = gt_binary.sum()
        # Intersection with every seg label in one pass over this object's pixels.
        hit_labels, hit_counts = np.unique(seg_label_mask[gt_binary], return_counts=True)
        intersections = dict(zip(hit_labels.tolist(), hit_counts.tolist()))

        all_gt2seg_iou__map[gt_label] = []
        gt2seg_map[gt_label] = {}
        best_iou = 0
        for seg_label, seg_area in zip(seg_labels, seg_areas):
            intersection_area = intersections.get(seg_label.item(), 0)
            union_area = gt_area + seg_area - intersection_area
            iou = intersection_area / union_area
            all_gt2seg_iou__map[gt_label].append({
                "seg_label": seg_label,
                "iou": iou,
                "io_gt": intersection_area / gt_area,
                "io_seg": intersection_area / seg_area,
            })
            # Strict '>' keeps the lowest seg label on ties.
            if iou > best_iou:
                best_iou = iou
                gt2seg_map[gt_label]["best_iou"] = iou
                gt2seg_map[gt_label]["seg_label"] = seg_label

    if return_all:
        return gt2seg_map, all_gt2seg_iou__map
    return gt2seg_map
    
# Treat the frame-2 mask as the segmentation and the frame-1 mask as the reference;
# return_all=True also returns the per-pair comparison lists.
match_mask_labels_by_iou(sc1.mask_dataset.get_img_by_time(2), sc1.mask_dataset.get_img_by_time(1), return_all=True)

## Apply the correction CNN to single-frame cases

In [None]:
from livecell_tracker.model_zoo.segmentation.sc_correction import CorrectSegNet
# ckpt = r"/home/ken67/LiveCellTracker-dev/notebooks/lightning_logs/version_real-02/checkpoints/epoch=3720-test_loss=0.0085.ckpt"
# ckpt = r"/home/ken67/LiveCellTracker-dev/notebooks/lightning_logs/version_802/checkpoints/epoch=2570-test_out_matched_num_gt_iou_0.5_percent_real_underseg_cases=0.8548.ckpt"
# Checkpoint path: set CORRECTSEGNET_CKPT to override this machine-specific default
# without editing the notebook.
ckpt = os.environ.get(
    "CORRECTSEGNET_CKPT",
    r"/home/ken67/LiveCellTracker-dev/notebooks/lightning_logs/version_v10_02/checkpoints/epoch=2999-global_step=0.ckpt",
)

model = CorrectSegNet.load_from_checkpoint(ckpt)
model = model.cuda()
model = model.eval()

In [None]:
from livecell_tracker.segment.ou_utils import create_ou_input_from_sc
from torchvision import transforms
from livecell_tracker.preprocess.utils import normalize_img_to_uint8
import torch
# Resize every correction-network input to 412x412. Bound to its own name so the
# torchvision `transforms` module stays importable and this cell can be re-run.
ou_input_transform = transforms.Compose(
    [
        transforms.Resize(size=(412, 412)),
    ]
)
from livecell_tracker.preprocess.utils import dilate_or_erode_mask


def create_ou_input_from_sc(sc: SingleCellStatic, padding_pixels: int = 0, dtype=float, remove_bg=True, one_object=True, scale=0):
    """Build the correction-network input image for one single cell.

    NOTE: shadows ``create_ou_input_from_sc`` imported from
    ``livecell_tracker.segment.ou_utils`` in this cell.

    The crop is normalized with ``normalize_img_to_uint8`` and cast to
    ``dtype``. Pixels outside the foreground mask are then multiplied by -1, so
    the foreground and the rest of the crop differ in sign.

    Parameters
    ----------
    sc : SingleCellStatic
        Cell to crop.
    padding_pixels : int
        Padding passed to every crop/mask getter.
    dtype : type
        Working dtype; should be signed or floating for the sign flip to make sense.
    remove_bg : bool
        If True, crop with ``sc.get_contour_img``; otherwise with ``sc.get_img_crop``.
    one_object : bool
        If True, the foreground is this cell's contour mask passed through
        ``dilate_or_erode_mask(scale_factor=scale)``. Otherwise it is every
        nonzero pixel of ``sc.get_mask_crop`` (presumably including other cells
        in the crop).
    scale : float
        ``scale_factor`` forwarded to ``dilate_or_erode_mask``.
    """
    if remove_bg:
        img_crop = sc.get_contour_img(padding=padding_pixels).astype(dtype)
    else:
        img_crop = sc.get_img_crop(padding=padding_pixels).astype(dtype)
    img_crop = normalize_img_to_uint8(img_crop).astype(dtype)
    if one_object:
        sc_mask = sc.get_contour_mask(padding=padding_pixels)
        sc_mask = dilate_or_erode_mask(sc_mask.astype(np.uint8), scale_factor=scale).astype(bool)
        img_crop[~sc_mask] *= -1
    else:
        img_crop[sc.get_mask_crop(padding=padding_pixels) == 0] *= -1
    return img_crop


def viz_ou_sc_outputs(sc: SingleCellStatic, padding_pixels: int = 0, dtype=float, remove_bg=True, one_object=True, scale=0):
    """Run the global correction ``model`` on one cell and plot input and outputs.

    Draws five panels: the network input, output channels 0-2, and the cell's
    original mask crop. The mask crop uses the same ``padding_pixels`` as the
    input, so the panels are comparable. Requires CUDA. All arguments other
    than ``sc`` are forwarded to ``create_ou_input_from_sc``.
    """
    ou_input = create_ou_input_from_sc(sc, padding_pixels=padding_pixels, dtype=dtype, remove_bg=remove_bg, one_object=one_object, scale=scale)
    # Add a leading dim (assumes a 2-D crop -> (1, H, W)). as_tensor avoids torch's
    # slow list-of-ndarrays construction path.
    ou_input = ou_input_transform(torch.as_tensor(ou_input).unsqueeze(0))
    # Repeat the single channel three times -> (1, 3, 412, 412).
    ou_input = torch.stack([ou_input, ou_input, ou_input], dim=1)
    ou_input = ou_input.float().cuda()
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = model(ou_input)

    # visualize the input and all 3 output channels
    fig, axes = plt.subplots(1, 5, figsize=(15, 5))
    panels = [(ou_input[0, 0], "input")] + [(output[0, channel], f"output c{channel}") for channel in range(3)]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image.cpu().numpy())
        ax.set_title(title)
    axes[4].imshow(sc.get_mask_crop(padding=padding_pixels, dtype=int))
    axes[4].set_title("original mask")
    plt.show()


selected_sc_list = [single_cells_by_time[2][12], single_cells_by_time[2][13]]
for sc in selected_sc_list:
    viz_ou_sc_outputs(sc, padding_pixels=50, dtype=float, remove_bg=True, one_object=True, scale=0)

In [None]:
# Same cells with the background kept (remove_bg=False) and scale=0.1 applied to the cell mask.
for sc in selected_sc_list:
    viz_ou_sc_outputs(sc, padding_pixels=50, dtype=float, remove_bg=False, one_object=True, scale=0.1)

In [None]:
# As above, but the foreground is the whole mask crop (one_object=False) instead of this cell only.
for sc in selected_sc_list:
    viz_ou_sc_outputs(sc, padding_pixels=50, dtype=float, remove_bg=False, one_object=False, scale=0.1)

**TODO:** determine whether a case is over-segmentation.

Inputs:
- an input mask that also includes the other cells
- a corrected label mask
    


## Apply the correction CNN to cells matched across two consecutive time frames

In [None]:
from livecell_tracker.segment.utils import compute_match_label_map
from livecell_tracker.core.parallel import parallelize

times = sorted(mask_dataset.times)
# One job per pair of consecutive frames (t, t + 1); frames without a successor are skipped.
inputs = [(t, t + 1, mask_dataset) for t in times if t + 1 in times]
label_match_outputs = parallelize(compute_match_label_map, inputs, None)

In [None]:
# Labels in frame t1 with more than one match in frame t2: (t1, t2, label, matches).
multiple_maps = [
    (t1, t2, label, label_map[label])
    for t1, t2, label_map in label_match_outputs
    for label in label_map
    if len(label_map[label]) > 1
]

In [None]:
# Entries are (t1, t2, label, matches) for labels with more than one match in t2.
multiple_maps

In [None]:
# Lookup table {timeframe: {label_in_mask: cell}}.
time2label2sc = {}
for sc in single_cells:
    time2label2sc.setdefault(sc.timeframe, {})[sc.meta["label_in_mask"]] = sc

In [None]:
# Hand-picked entry 11 of multiple_maps: the frame-t1 cell and its first two frame-t2 matches.
t1, t2, label, label_map = multiple_maps[11]
sc1 = time2label2sc[t1][label]
sc2_label, sc3_label = list(label_map)[0], list(label_map)[1]
sc2 = time2label2sc[t2][sc2_label]
sc3 = time2label2sc[t2][sc3_label]

viz_ou_sc_outputs(sc1, padding_pixels=50, dtype=float, remove_bg=True, one_object=True, scale=0)

In [None]:
# Inspect 10 randomly drawn entries of multiple_maps (labels with more than one
# match in the next frame). A seeded generator makes the draw reproducible on
# Restart & Run All.
rng = np.random.default_rng(0)
for _ in range(10):
    print(">" * 80)
    rand_map_idx = rng.integers(len(multiple_maps))
    t1, t2, label, label_map = multiple_maps[rand_map_idx]
    sc1 = time2label2sc[t1][label]
    sc2_label, sc3_label = list(label_map)[0], list(label_map)[1]
    sc2 = time2label2sc[t2][sc2_label]
    sc3 = time2label2sc[t2][sc3_label]
    for cell in (sc1, sc2, sc3):
        cell.show_panel()
    viz_ou_sc_outputs(sc1)