In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
from livecell_tracker import segment
from livecell_tracker import core
from livecell_tracker.core import datasets
from livecell_tracker.core.datasets import LiveCellImageDataset
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic
import detectron2
from detectron2.utils.logger import setup_logger
import tqdm
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import cv2

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from livecell_tracker.segment.detectron_utils import gen_cfg

from livecell_tracker.segment.detectron_utils import (
    segment_detectron_wrapper,
    segment_images_by_detectron,
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
)
from livecell_tracker.segment.detectron_utils import (
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
    segment_images_by_detectron,
    segment_single_img_by_detectron_wrapper,
)


dataset_dir_path = Path(
    "../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/XY1"
)

mask_dataset_path = Path("../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/out/XY9/seg")

In [None]:
def get_time_from_path(path):
    """example path: STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21_T287_XY09_TRITC.tif"""
    idx = 0
    strs = path.split("_")
    while idx < len(strs) - 1:
        if strs[idx][:2] == "XY":
            break
        idx += 1
    idx -= 1
    return int(strs[idx][1:])
get_time_from_path("example path: STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21_T287_XY09_DIC.tif")

In [None]:
mask_time2url = {}
mask_paths = sorted(glob.glob(str(mask_dataset_path / "*.png")))
for mask_path in mask_paths:
    mask_time2url[get_time_from_path(mask_path)] = mask_path
mask_dataset = LiveCellImageDataset(ext="png", time2url=mask_time2url)
len(mask_dataset)

In [None]:

time2url = {}
img_paths = sorted(glob.glob(str(dataset_dir_path / "*_DIC.tif")))

for img_path in img_paths:
    time = get_time_from_path(img_path)
    time2url[time] = img_path

dic_dataset = LiveCellImageDataset(dataset_dir_path, time2url=time2url, ext="tif")

In [None]:
list(dic_dataset.time2url.items())[:5]

In [None]:
len(dic_dataset.time2url)

check co-existence of times

In [None]:
for time in mask_dataset.time2url:
    assert time in dic_dataset.time2url

for time in dic_dataset.time2url:
    assert time in mask_dataset.time2url

Convert label masks to single objects

In [None]:
from multiprocessing import Pool
from skimage.measure import regionprops


def process_mask(mask_dataset, dic_dataset, time):
    seg_mask = mask_dataset.get_img_by_time(time)
    props_list = regionprops(seg_mask)
    _scs = []
    for prop in props_list:
        _scs.append(
            SingleCellStatic(
                timeframe=time,
                img_dataset = dic_dataset,
                mask_dataset = mask_dataset,
                bbox=prop.bbox,
                contour=prop.coords,
            )
        )
    return _scs

def process_mask_wrapper(args):
    return process_mask(*args)

def prep_scs_from_mask_dataset(mask_dataset, dic_dataset, cores=None):
    scs = []

    inputs = [(mask_dataset, dic_dataset, time) for time in mask_dataset.time2url.keys()]
    pool = Pool(processes=cores)
    for _scs in tqdm.tqdm(pool.imap_unordered(process_mask_wrapper, inputs), total=len(inputs)):
        scs.extend(_scs)
    pool.close()
    pool.join()
    return scs

single_cells = prep_scs_from_mask_dataset(mask_dataset, dic_dataset, cores=None)

In [None]:
# from skimage.measure import regionprops
# single_cells = []

# for time in tqdm.tqdm(mask_dataset.time2url):
#     img = dic_dataset.get_img_by_time(time)
#     seg_mask = mask_dataset.get_img_by_time(time)
#     props_list = regionprops(seg_mask)
#     for prop in props_list:
#         single_cells.append(
#             SingleCellStatic(
#                 timeframe=time,
#                 img_dataset = dic_dataset,
#                 mask_dataset = mask_dataset,
#                 bbox=prop.bbox,
#                 contour=prop.coords,
#             )
#         )


check the number of single cells 

In [None]:
len(single_cells)

In [None]:
for i in range(10):
    sc = random.sample(single_cells, 1)[0]
    print("sc timeframe: ", sc.timeframe)

In [None]:
dic_dataset[6]

In [None]:
for i, _ in enumerate(single_cells):
    sc = random.sample(single_cells, 1)[0]
    print("sc time: ", sc.timeframe)
    sc.show_panel(padding=50)
    plt.show()
    if i >= 1:
        break

In [None]:
# for testing
# single_cells = single_cells[:10]

In [None]:
single_cells_by_time = {}
for cell in single_cells:
    if cell.timeframe not in single_cells_by_time:
        single_cells_by_time[cell.timeframe] = []
    single_cells_by_time[cell.timeframe].append(cell)

In [None]:
times = sorted(single_cells_by_time.keys())
for time in times[:5]:
    print(time, len(single_cells_by_time[time]))

Visualize one single cell

In [None]:
sc = single_cells[0]

fig, axes = plt.subplots(1, 4, figsize=(10, 5))
sc.show(ax=axes[0])
sc.show_mask(ax=axes[1])
sc.show_contour_img(ax=axes[2])
sc.show_contour_mask(ax=axes[3])

In [None]:
from livecell_tracker.segment.utils import match_mask_labels_by_iou
def match_mask_labels_by_iou(seg_label_mask, gt_label_mask, bg_label=0, return_all=False):
    """compute the similarity between ground truth mask and segmentation mask by intersection over union

    Parameters
    ----------
    seg_label_mask : _type_
        _description_
    gt_label_mask : _type_
        _description_
    bg_label : int, optional
        _description_, by default 0
    return_all : bool, optional
        _description_, by default False
    Returns
    -------
        A <gt2seg_map>, mapping ground truth keys to a dictionary of the best matching segmentation label and its iou
    """
    gt2seg_map = {}
    all_gt2seg_iou__map = {}
    # gets all the unique labels in the labeled_seg_mask and gtly_curated_mask
    seg_labels = np.unique(seg_label_mask)
    gt_labels = np.unique(gt_label_mask)

    temp_seg_mask = seg_label_mask.copy()
    temp_gt_mask = gt_label_mask.copy()

    for gt_label in gt_labels:
        if gt_label == bg_label:
            continue
        gt_label_key = gt_label
        all_gt2seg_iou__map[gt_label_key] = []
        gt2seg_map[gt_label_key] = {}
        temp_gt_mask = gt_label_mask.copy()
        # isolates the current cell in the temp gtly_curated_mask and gets its pixels to 1
        temp_gt_mask[temp_gt_mask != gt_label] = 0
        temp_gt_mask[temp_gt_mask != 0] = 1

        best_iou = 0
        for seg_label in seg_labels:
            if seg_label == bg_label:
                continue
            temp_seg_mask = seg_label_mask.copy()

            # isolate the current cell in the temp_seg_mask and set its pixels to 1
            temp_seg_mask[temp_seg_mask != seg_label] = 0
            temp_seg_mask[temp_seg_mask != 0] = 1

            matching_rows, matching_columns = np.where(temp_seg_mask == 1)
            intersection_area = (temp_gt_mask[matching_rows, matching_columns] == 1).sum()
            union_area = temp_gt_mask.sum() + temp_seg_mask.sum() - intersection_area
            iou = intersection_area / union_area
            io_gt = intersection_area / temp_gt_mask.sum()
            io_seg = intersection_area / temp_seg_mask.sum()
            all_gt2seg_iou__map[gt_label_key].append({
                "seg_label": seg_label,
                "iou": iou,
                "io_gt": io_gt,
                "io_seg": io_seg,
            })

            if iou > best_iou:
                best_iou = iou
                gt2seg_map[gt_label_key]["best_iou"] = iou
                gt2seg_map[gt_label_key]["seg_label"] = seg_label
    if return_all:
        return gt2seg_map, all_gt2seg_iou__map
    else:
        return gt2seg_map

In [None]:
sc1 = single_cells_by_time[1][0]
sc2 = single_cells_by_time[2][0]
match_mask_labels_by_iou(sc1.mask_dataset.get_img_by_time(1), sc1.mask_dataset.get_img_by_time(2), return_all=True)

In [43]:
t1, t2 = 1, 2
mask1 = sc1.mask_dataset.get_img_by_time(t1)
mask2 = sc1.mask_dataset.get_img_by_time(t2)

def compute_match_label_map(t1, t2, mask_dataset, iou_threshold=0.2):
    mask1 = mask_dataset.get_img_by_time(t1)
    mask2 = mask_dataset.get_img_by_time(t2)
    _, score_dict = match_mask_labels_by_iou(mask1, mask2, return_all=True)
    iou_threshold = 0.2
    label_map = {}
    for label_1 in score_dict:
        label_map[label_1] = {}
        for score_info in score_dict[label_1]:
            if score_info["iou"] > iou_threshold:
                label_map[label_1][score_info["seg_label"]] = {
                    "iou": score_info["iou"]
                }
    return t1, t2, label_map

In [44]:
from functools import partial
def wrap_func(func, args):
    return func(*args)

def parallelize(func, inputs, cores=None):
    pool = Pool(processes=cores)
    outputs = []
    for output in tqdm.tqdm(pool.imap_unordered(partial(wrap_func, func), inputs), total=len(inputs)):
        outputs.append(output)
    pool.close()
    pool.join()
    return outputs

times = sorted(mask_dataset.times)
inputs = []
for idx in range(len(times) - 1):
    t1, t2 = times[idx], times[idx + 1]
    inputs.append((t1, t2, mask_dataset))

outputs = parallelize(compute_match_label_map, inputs, None)

  1%|          | 3/290 [02:28<2:39:48, 33.41s/it]  