In [None]:
import numpy as np
import json
from livecell_tracker.core import (
    SingleCellTrajectory,
    SingleCellStatic,
    SingleCellTrajectoryCollection,
)
from livecell_tracker.segment.detectron_utils import (
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
    segment_images_by_detectron,
    segment_single_img_by_detectron_wrapper,
)
from livecell_tracker.core.datasets import LiveCellImageDataset
from livecell_tracker.preprocess.utils import (
    overlay,
    enhance_contrast,
    normalize_img_to_uint8,
)
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from pathlib import Path
import pandas as pd

## Convert Labelme Json to COCO Json

Run the following code once to generate the COCO JSON files:
```
import livecell_tracker.segment
import livecell_tracker.annotation
import livecell_tracker.annotation.labelme2coco
import os
labelme_json_folder = r"""../datasets/a549_ccnn/annotation_data"""
dataset_folder_path = r"""../datasets/a549_ccnn/original_data"""
export_dir = "./notebook_results/correction_cnn_v0.0.0/"
os.makedirs(export_dir, exist_ok=True)
livecell_tracker.annotation.labelme2coco.convert(
    labelme_json_folder,
    export_dir,
    train_split_rate=0.9,
    dataset_folder_path=dataset_folder_path,
    # is_image_in_json_folder=True,
    image_file_ext="tif",
    # image_file_ext="png",
)
```

## Load COCO into SingleCell Objects

In [None]:
from pycocotools.coco import COCO

# Load the COCO-format training annotations and pick the root folder for generated data.
coco_data = COCO("../datasets/a549_ccnn/a549_ccnn_coco_v0.0.0/train.json")
out_dir = Path("./notebook_results/a549_ccp_vim/train_data_v6/")

# Peek at the annotation table: its keys, the fields of annotation 1, and the
# first five entries of the first element of that annotation's segmentation.
(
    coco_data.anns.keys(),
    coco_data.anns[1].keys(),
    coco_data.anns[1]["segmentation"][0][:5],
)


In [None]:
# Peek at the image table: its keys, the fields of image 1, and that image's file name.
coco_data.imgs.keys(), coco_data.imgs[1].keys(), coco_data.imgs[1]["file_name"],


In [None]:
from typing import List

from livecell_tracker.annotation.coco_utils import coco_to_sc

# Convert the COCO annotations into single-cell objects.
# For testing, only the first 20 cells are kept.
single_cells = coco_to_sc(coco_data)[:20]

# Preview one cell three ways (all with 100 px padding): image crop,
# closed-form contour mask, and contour mask.
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
cell_id = 10
preview_sc = single_cells[cell_id]
crop_ax, closed_form_ax, mask_ax = axes
crop_ax.imshow(preview_sc.get_img_crop(padding=100))
closed_form_ax.imshow(preview_sc.get_contour_mask_closed_form(padding=100))
mask_ax.imshow(preview_sc.get_contour_mask(padding=100))


In [None]:
# Number of single-cell objects currently held in `single_cells`.
len(single_cells)

Save the list of single-cell objects to a JSON file.

In [None]:
# Write the current `single_cells` list to a JSON file next to the dataset.
SingleCellStatic.write_single_cells_json(
    single_cells, "../datasets/a549_ccnn/single_cells.json"
)


In [None]:
# Contour mask of the preview cell with 100 px padding; kept for the next cell.
contour_mask = single_cells[cell_id].get_contour_mask(padding=100)
# Displayed only: the uint8 view is not stored back into `contour_mask`.
contour_mask.astype(np.uint8)


In [None]:
import cv2 as cv
from livecell_tracker.preprocess.utils import dilate_or_erode_mask

# Show the result of dilate_or_erode_mask on the uint8 contour mask.
# NOTE(review): the second argument (1) is presumably a scale factor — confirm against the helper's signature.
plt.imshow(dilate_or_erode_mask(contour_mask.astype(np.uint8), 1))


In [None]:
# Keep one cell as a handle on the shared datasets, and display its metadata.
sample_sc = single_cells[cell_id]
sample_sc.meta

In [None]:
import glob

# Pair every segmentation PNG in `seg_data_dir` with an entry of the raw image dataset,
# then attach the resulting mask dataset to `sample_sc`.
raw_img_dataset = sample_sc.img_dataset
seg_data_dir = "../datasets/a549_ccnn/seg_tiles_CCP_A549-VIM_lessThan24hr_Calcein_1mg-ml_DP_Ti2e_2022-9-11"
seg_paths = glob.glob(os.path.join(seg_data_dir, "*.png"))
print("sample seg paths:", seg_paths[:2])
# Maps the index returned by raw_img_dataset.get_img_by_url to the matching segmentation path.
matched_time2seg = {}
# for time, img_path in raw_img_dataset.time2url.items():
#     substr = os.path.basename(img_path).split(".")[0]
#     print("substr:", substr)
#     for seg_path
#     break
# NOTE(review): corrected_indices is never filled in this cell — confirm whether a later cell still needs it.
corrected_indices = []
for seg_path in seg_paths:
    substr = os.path.basename(seg_path).split(".")[0] # keep the text before the first "." (drops the extension)
    substr = substr[4:]  # drop the first 4 characters, assumed to be a "seg_" prefix (not checked)
    # NOTE(review): presumably looks up the raw image whose URL matches `substr`; with
    # ignore_missing=True a failed lookup is signalled by `path` being None — confirm.
    img, path, index = raw_img_dataset.get_img_by_url(
        substr, return_path_and_time=True, ignore_missing=True
    )
    if path is None:
        print("skip due to substr not found:", substr)
        continue
    matched_time2seg[index] = seg_path

seg_data = LiveCellImageDataset(time2url=matched_time2seg, ext="png")
sample_sc.mask_dataset = seg_data
# The mask dataset must have as many entries as the raw image dataset.
assert len(seg_data) == len(raw_img_dataset)


## Generate Synthetic Undersegmentation data

In [None]:
# Two example cells that are reused by the synthetic-data cells below.
sc1, sc2 = single_cells[10], single_cells[1]

# One row per cell: image crop, contour image, contour mask.
padding = 20
fig, axes = plt.subplots(2, 3, figsize=(10, 5))
sc1_axes, sc2_axes = axes
sc1.show(padding=padding, ax=sc1_axes[0])
sc1.show_contour_img(padding=padding, ax=sc1_axes[1])
sc1.show_contour_mask(padding=padding, ax=sc1_axes[2])
sc2.show(padding=padding, ax=sc2_axes[0])
sc2.show_contour_img(padding=padding, ax=sc2_axes[1])
sc2.show_contour_mask(padding=padding, ax=sc2_axes[2])

In [None]:
from livecell_tracker.core.datasets import SingleImageDataset


def check_contour_in_boundary(contour, boundary):
    """Tell whether every contour coordinate lies in ``[0, boundary)``.

    Args:
        contour: array of point coordinates.
        boundary: exclusive upper limit(s); broadcast against ``contour``.

    Returns:
        A truthy value when all coordinates are >= 0 and < ``boundary``.
    """
    non_negative = np.all(contour >= 0)
    if not non_negative:
        return non_negative
    return np.all(contour < boundary)


def adjust_contour_to_bounds(contour, bounds):
    """Clamp contour coordinates into ``[0, bounds - 1]``.

    A contour that is already inside the bounds is returned unchanged (same
    object).  Otherwise a new array is returned in which negative coordinates
    become 0 and coordinates >= ``bounds`` become ``bounds - 1``.
    """
    if check_contour_in_boundary(contour, bounds):
        return contour

    clamped = contour.copy()
    clamped[clamped < 0] = 0
    return np.where(clamped >= bounds, bounds - 1, clamped)

def shift_contour_randomly(sc_center, contour, bounds):
    """Translate ``contour`` so that ``sc_center`` lands on a random point.

    The target point is drawn with ``np.random.randint`` from ``[0, bounds)``
    (two coordinates).  The translation is truncated to integers before it is
    applied.

    Returns:
        ``(random_center, contour_shifted, shift)``.
    """
    random_center = np.random.randint(low=0, high=bounds, size=2)
    shift = (random_center - sc_center).astype(int)
    return random_center, contour + shift, shift

def compute_two_contours_min_distance(contour1, contour2):
    """Return the smallest Euclidean distance between two point sets.

    Args:
        contour1: array-like of points, shape ``(N, D)``.
        contour2: array-like of points, shape ``(M, D)``.

    Returns:
        The minimum distance over all point pairs, or ``np.inf`` when either
        contour is empty.
    """
    # Work in float so unsigned-integer contours cannot wrap around on subtraction.
    pts1 = np.atleast_2d(np.asarray(contour1, dtype=float))
    pts2 = np.atleast_2d(np.asarray(contour2, dtype=float))
    if pts1.size == 0 or pts2.size == 0:
        return np.inf

    # Vectorised pairwise distances, processed in chunks of contour1 so the
    # temporary (chunk, M, D) array stays small even for long contours.
    chunk_size = 1024
    min_sq_dist = np.inf
    for start in range(0, len(pts1), chunk_size):
        diffs = pts1[start : start + chunk_size, None, :] - pts2[None, :, :]
        min_sq_dist = min(min_sq_dist, (diffs**2).sum(axis=-1).min())
    return np.sqrt(min_sq_dist)

def combine_two_scs_monte_carlo(sc1, sc2, bg_img=None, bg_scale=1.5, fix_sc1=False):
    """Paste two cells onto one background image and return them as two new cells.

    Each cell's contour (in crop coordinates) is translated to a random position
    inside the background, clamped to the background bounds and rasterised into
    its own mask; the cell's pixels are then copied into the shared image.  No
    overlap constraint is enforced here — callers that need a given overlap
    must resample.

    Args:
        sc1, sc2: cells to combine.  They must provide ``get_img_crop``,
            ``get_contour_img``, ``get_contour_coords_on_crop`` and
            ``compute_regionprops``.
        bg_img: background image.  When None, a zero image is created whose
            height/width are the larger of the two crops' dimensions times
            ``bg_scale``.
        bg_scale: scale factor used only when ``bg_img`` is None.
        fix_sc1: when True, ``sc1`` keeps its crop coordinates (no random shift).

    Returns:
        ``(new_sc1, new_sc2)``: two ``SingleCellStatic`` objects built on the
        same combined image array, each with a mask containing only that cell.
    """

    def _gen_empty_bg_img():
        # Background large enough for the bigger crop in each dimension, scaled up.
        sc1_shape = sc1.get_img_crop().shape
        sc2_shape = sc2.get_img_crop().shape
        bg_shape = np.array([max(sc1_shape[0], sc2_shape[0]), max(sc1_shape[1], sc2_shape[1])])
        bg_shape = (bg_shape * bg_scale).astype(int)
        return np.zeros(shape=bg_shape)

    if bg_img is None:
        bg_img = _gen_empty_bg_img()

    bg_shape = np.array(bg_img.shape)
    new_img = bg_img.copy()
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    new_mask = np.zeros(shape=bg_shape, dtype=bool)

    def _add_sc_to_img_helper(sc, new_img, new_sc_mask, shift, in_place=False):
        """Copy the pixels of ``sc`` into ``new_img`` wherever ``new_sc_mask`` is True."""
        if not in_place:
            new_img = new_img.copy()
        src_img = sc.get_contour_img()
        # Map every target pixel back to the cell's original (crop) coordinates.
        src_coords = np.array(new_sc_mask.nonzero()).T - shift
        # Clamp on both sides so pixels created by contour clamping cannot index
        # outside the source crop.
        src_coords[src_coords < 0] = 0
        src_coords = np.minimum(src_coords, np.array(src_img.shape[:2]) - 1)
        new_img[new_sc_mask] = src_img[src_coords[:, 0], src_coords[:, 1]]
        return new_img

    def add_sc_to_img(sc, new_img, mask, in_place=False, copy_mask=True, fix_sc_pos=False):
        """Place one cell on ``new_img`` and mark its pixels in ``mask``.

        With ``copy_mask=True`` (default) the incoming ``mask`` is left untouched
        and a marked copy is returned, so each cell gets its own mask.
        """
        sc_prop = sc.compute_regionprops()
        sc_contour_coords = sc.get_contour_coords_on_crop().astype(int)
        if fix_sc_pos:
            sc_new_contour = sc_contour_coords
            shift = 0
        else:
            # NOTE(review): assumes the regionprops centroid is expressed in the
            # same crop coordinates as the contour — confirm.
            _, sc_new_contour, shift = shift_contour_randomly(
                sc_prop.centroid, sc_contour_coords, bounds=bg_shape
            )
        sc_new_contour = adjust_contour_to_bounds(sc_new_contour, bg_shape)
        new_sc_mask = SingleCellStatic.gen_contour_mask(sc_new_contour, bg_img, bbox=None)
        # Index with a real boolean mask; a 0/1 integer mask would be treated
        # as integer (fancy) indices instead.
        new_sc_mask_bool = new_sc_mask > 0
        new_img = _add_sc_to_img_helper(sc, new_img, new_sc_mask_bool, shift, in_place=in_place)

        if copy_mask:
            mask = mask.copy()
        mask[new_sc_mask_bool] = True
        return new_img, sc_new_contour, mask, shift

    # in_place=True: both cells are drawn onto the same `new_img` array.
    _, sc1_new_contour, sc1_new_mask, shift1 = add_sc_to_img(
        sc1, new_img, mask=new_mask, in_place=True, fix_sc_pos=fix_sc1
    )
    _, sc2_new_contour, sc2_new_mask, shift2 = add_sc_to_img(sc2, new_img, mask=new_mask, in_place=True)

    new_sc1 = SingleCellStatic(
        timeframe=SingleImageDataset.DEFAULT_TIME,
        contour=sc1_new_contour,
        img_dataset=SingleImageDataset(new_img),
        mask_dataset=SingleImageDataset(sc1_new_mask),
    )
    new_sc2 = SingleCellStatic(
        timeframe=SingleImageDataset.DEFAULT_TIME,
        contour=sc2_new_contour,
        img_dataset=SingleImageDataset(new_img),
        mask_dataset=SingleImageDataset(sc2_new_mask),
    )
    return new_sc1, new_sc2

# has_overlap = np.any(new_sc1_mask & new_sc2_mask)
# print(has_overlap)
# fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# ax = axes[0][0]
# ax.imshow(new_sc1_mask | new_sc2_mask)
# ax = axes[0][1]
# ax.imshow(new_img)
# NOTE(review): none of the calls visible in this notebook pass this module-level value
# (they rely on each function's own bg_scale default) — confirm whether it is still needed.
bg_scale = 3.0

def viz_check_combined_sc_result(sc1, sc2):
    """Plot an eight-panel visual check of two combined cells.

    For each cell, four panels are drawn side by side: the whole image, the
    cell view from ``show``, the mask (20 px padding) and the contour image
    (20 px padding).  ``sc1`` fills the first four axes and ``sc2`` the last
    four.
    """
    fig, axes = plt.subplots(1, 8, figsize=(18, 5))
    panels_per_cell = 4
    for cell_idx, (label, cell) in enumerate((("sc1", sc1), ("sc2", sc2))):
        first = cell_idx * panels_per_cell
        whole_ax, crop_ax, mask_ax, contour_ax = axes[first : first + panels_per_cell]

        cell.show_whole_img(ax=whole_ax)
        whole_ax.set_title(f"{label} whole img")

        cell.show(ax=crop_ax)
        cell.show_mask(ax=mask_ax, padding=20)

        cell.show_contour_img(ax=contour_ax, padding=20)
        contour_ax.set_title(f"{label} contour img")

    plt.show()

# Visual sanity check: four independent random placements of sc1 and sc2 on an
# auto-generated background (no overlap constraint is applied here).
for i in range(4):
    new_sc1, new_sc2 = combine_two_scs_monte_carlo(sc1, sc2, bg_img=None)
    viz_check_combined_sc_result(new_sc1, new_sc2)

# for i in range(4):
#     new_sc1, new_sc2 = combine_two_scs_monte_carlo(sc1, sc2, bg_img=None, fix_sc1=True)
#     viz_check_combined_sc_result(new_sc1, new_sc2)

In [None]:
# Re-plot the most recently generated pair (new_sc1/new_sc2 come from the loop in the previous cell).
viz_check_combined_sc_result(new_sc1, new_sc2)

In [None]:
def gen_synthetic_overlap_scs(
    sc1, sc2, max_overlap_percent=0.3, bg_scale=2.0, fix_sc1=False, max_attempts=None
):
    """Randomly place two cells until they overlap by an acceptable amount.

    Placements are drawn with ``combine_two_scs_monte_carlo`` and rejected until
    the overlap ratio — overlapping pixels divided by the area of the smaller
    mask — is strictly between 0 and ``max_overlap_percent``.

    Args:
        sc1, sc2: cells to combine.
        max_overlap_percent: exclusive upper limit of the accepted overlap ratio.
        bg_scale: forwarded to ``combine_two_scs_monte_carlo``.
        fix_sc1: forwarded to ``combine_two_scs_monte_carlo``.
        max_attempts: optional cap on the number of placements tried.  None
            (default) retries without limit.

    Returns:
        ``(new_sc1, new_sc2, overlap_percent)``.

    Raises:
        RuntimeError: if ``max_attempts`` placements were tried without success.
    """
    # TODO: optimize in the future via computational geometry; now simply use monte carlo for generating required synthetic data
    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1
        new_sc1, new_sc2 = combine_two_scs_monte_carlo(
            sc1, sc2, bg_img=None, bg_scale=bg_scale, fix_sc1=fix_sc1
        )
        # Fetch each mask once and force it to boolean so that areas are pixel
        # counts whatever dtype the mask is stored in.
        mask1 = new_sc1.get_mask().astype(bool)
        mask2 = new_sc2.get_mask().astype(bool)
        smaller_area = min(np.sum(mask1), np.sum(mask2))
        if smaller_area == 0:
            # An empty mask cannot overlap anything; resample instead of dividing by zero.
            continue
        overlap_percent = float(np.sum(np.logical_and(mask1, mask2))) / smaller_area
        if 0 < overlap_percent < max_overlap_percent:
            return new_sc1, new_sc2, overlap_percent
    raise RuntimeError(
        "no placement with overlap in (0, %s) found after %d attempts" % (max_overlap_percent, attempts)
    )


def gen_synthetic_nonoverlap_scs(min_dist=10, max_dist=100, n=10):
    """Placeholder for generating non-overlapping synthetic cell pairs.

    Not implemented yet: the arguments are accepted but ignored and the
    function returns None.
    """
    return None


# Check results: generate four overlapping placements of sc1/sc2 with the
# default overlap limit and plot each one; the returned overlap ratio is discarded.
for i in tqdm(range(4)):
    new_sc1, new_sc2, _ = gen_synthetic_overlap_scs(sc1, sc2)
    viz_check_combined_sc_result(new_sc1, new_sc2)

In [None]:
import itertools
from skimage.measure import find_contours
import cv2

def show_cv2_contours(contours, img):
    """Draw OpenCV contours on top of a single-channel image and show the result.

    Args:
        contours: sequence of contours in the layout produced by ``cv2.findContours``.
        img: 2-D array; it is cast to uint8 and replicated to three channels so
            that the contours can be drawn in colour.
    """
    im = np.expand_dims(img.astype(np.uint8), axis=2).repeat(3, axis=2)
    # Use the `cv2` name imported in this cell rather than the `cv` alias that
    # is only imported in an earlier, unrelated cell.
    for contour_idx in range(len(contours)):
        im = cv2.drawContours(im, contours, contour_idx, (0, 230, 255), 6)
    plt.imshow(im)
    plt.show()

def merge_two_scs(sc1: SingleCellStatic, sc2: SingleCellStatic) -> SingleCellStatic:
    """Merge two cells that live on the same image into a single cell.

    The masks of both cells are OR-ed together and the external contour of the
    union becomes the contour of the merged cell.  When the union consists of
    several disconnected pieces, the contour enclosing the largest area is used
    (the stored mask still contains every piece).  Intermediate results are
    plotted and printed for visual debugging.

    Args:
        sc1: first cell; its ``img_dataset`` is reused for the merged cell.
        sc2: second cell; its mask must have the same shape as ``sc1``'s.

    Returns:
        A new ``SingleCellStatic`` with the merged contour and mask.

    Raises:
        AssertionError: if no contour is found in the merged mask.
    """
    new_mask = np.logical_or(sc1.get_mask().astype(bool), sc2.get_mask().astype(bool))
    plt.imshow(new_mask)
    plt.show()
    print(np.unique(new_mask))

    # Two-value return: assumes the OpenCV 4.x signature of findContours.
    contours, hierarchy = cv2.findContours(new_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    show_cv2_contours(contours, new_mask)
    print("#contours found: ", len(contours))
    assert len(contours) != 0, "must contain at least one contour"
    if len(contours) > 1:
        print("WARNING: more than one contour found, use the largest one")
    # Pick by area: the order returned by findContours is not related to size,
    # so contours[0] could be a tiny fragment.
    new_contour = np.array(max(contours, key=cv2.contourArea))
    # OpenCV returns points as (x, y) with shape (N, 1, 2); swap to (row, col)
    # and flatten to (N, 2).
    new_contour = new_contour[:, :, ::-1]
    new_contour = new_contour.reshape(-1, 2)
    print(np.array(new_contour).shape)
    res_sc = SingleCellStatic(
        timeframe=SingleImageDataset.DEFAULT_TIME,
        contour=new_contour,
        img_dataset=sc1.img_dataset,
        mask_dataset=SingleImageDataset(new_mask),
    )
    return res_sc


# Build synthetic under-segmented objects by merging `num_cells` cells at a time:
# start from the first cell of each combination and repeatedly attach the next
# cell with a small overlap, merging after every step.
num_cells = 3
for i, tmp_scs in enumerate(itertools.combinations(single_cells, num_cells)):
    # Seed the merge with this combination's first cell (not with a `new_sc1`
    # left over from an earlier cell, which ignored tmp_scs[0] entirely).
    cur_merged_sc = tmp_scs[0].copy()
    for j in range(1, num_cells):
        cur_sc = tmp_scs[j]
        # fix_sc1=True keeps the already-merged object in place; only cur_sc is moved.
        cur_merged_sc, new_sc2, _ = gen_synthetic_overlap_scs(cur_merged_sc, cur_sc, fix_sc1=True)
        cur_merged_sc = merge_two_scs(cur_merged_sc, new_sc2)
        print("j=", j)
        viz_check_combined_sc_result(cur_merged_sc, new_sc2)
    

In [None]:
# Output layout for this synthetic dataset; everything lives under out_dir / subdir.
subdir = Path("synthetic_underseg_overlap")
overseg_out_dir = out_dir / subdir
raw_out_dir = overseg_out_dir / "raw"

# seg_out_dir is the directory containing all raw segmentation masks for training
# e.g. the eroded raw segmentation masks
seg_out_dir = overseg_out_dir / "seg"

# raw_seg_dir is the directory containing all raw segmentation masks for recording purposes
raw_seg_dir = overseg_out_dir / "raw_seg_crop"
gt_out_dir = overseg_out_dir / "gt"
gt_label_out_dir = overseg_out_dir / "gt_label_mask"
augmented_seg_dir = overseg_out_dir / "augmented_seg"
raw_transformed_img_dir = overseg_out_dir / "raw_transformed_img"
augmented_diff_seg_dir = overseg_out_dir / "augmented_diff_seg"
meta_path = overseg_out_dir / "metadata.csv"

# Create every output folder up front (existing folders are left alone).
for output_folder in (
    raw_out_dir,
    seg_out_dir,
    raw_seg_dir,
    gt_out_dir,
    augmented_seg_dir,
    gt_label_out_dir,
    raw_transformed_img_dir,
    augmented_diff_seg_dir,
):
    os.makedirs(output_folder, exist_ok=True)


# Export one training sample per (cell, synthetic row) through csn_augment_helper.
# NOTE(review): `overseg_uns_key` and `csn_augment_helper` are not defined in any cell
# visible here — confirm they are defined/imported before this cell runs.
# NOTE(review): names say "overseg" although this section is about under-segmentation — confirm.
overseg_train_path_tuples = []
augmented_overseg_data = []
# File name pattern: image id (the cell's timeframe) and synthetic sample id.
filename_pattern = "img-%d_syn-%d.tif"
# NOTE(review): collected below but not written to `meta_path` in this cell — confirm it is saved elsewhere.
overseg_metadata = []
underseg_erosion_scale_factors = np.linspace(0, 0.1, 10)
for sc in tqdm(single_cells):
    img_id = sc.timeframe
    # Each row is indexed as [0] raw segmentation crop, [1] eroded segmentation crop, [2] metadata dict.
    for syn_id, overseg_datarow in enumerate(sc.uns[overseg_uns_key]):
        # NOTE(review): `params` reads the same element as `eroded_seg_crop` and is not used below.
        params = overseg_datarow[1]
        img_crop = sc.get_contour_img()
        raw_seg_crop = overseg_datarow[0]
        eroded_seg_crop = overseg_datarow[1]

        combined_gt_label_mask = sc.get_contour_mask()
        # Image, raw segmentation and ground-truth mask must all be crops of the same shape.
        assert img_crop.shape == raw_seg_crop.shape == combined_gt_label_mask.shape
        raw_img_path = raw_out_dir / (filename_pattern % (img_id, syn_id))
        seg_img_path = seg_out_dir / (filename_pattern % (img_id, syn_id))
        raw_seg_img_path = raw_seg_dir / (filename_pattern % (img_id, syn_id))
        gt_img_path = gt_out_dir / (filename_pattern % (img_id, syn_id))
        gt_label_img_path = gt_label_out_dir / (filename_pattern % (img_id, syn_id))

        # metadata is a dict, containing params used to generate our synthetic overseg data
        # The path keys are added to the object stored in the row itself, not to a copy.
        meta_info = overseg_datarow[2]
        meta_info["raw_img_path"] = raw_img_path
        meta_info["seg_img_path"] = seg_img_path
        meta_info["gt_img_path"] = gt_img_path
        
        overseg_metadata.append(meta_info)

        # call csn augment helper
        # seg_label is the synthetic sample id; gt_label is the cell's timeframe (same value as img_id).
        csn_augment_helper(img_crop=img_crop, 
            seg_crop=eroded_seg_crop, 
            combined_gt_label_mask=combined_gt_label_mask,
            overseg_raw_seg_crop=raw_seg_crop,
            overseg_raw_seg_img_path=raw_seg_img_path,
            scale_factors=underseg_erosion_scale_factors,
            train_path_tuples=overseg_train_path_tuples,
            augmented_data=augmented_overseg_data,
            img_id=img_id,
            seg_label=syn_id,
            gt_label=sc.timeframe,
            raw_img_path=raw_img_path,
            seg_img_path=seg_img_path,
            gt_img_path=gt_img_path,
            gt_label_img_path=gt_label_img_path,
            augmented_seg_dir=augmented_seg_dir,
            augmented_diff_seg_dir=augmented_diff_seg_dir,
            filename_pattern=filename_pattern,
        )


We now need to handle two cases:
1) The two objects overlap:  
    a) simply dilate the merged mask to create undersegmentation cases  
2) The two objects do not overlap (future work):  
    b) fill in the pixels between the two objects  
    c) (to be determined)

## Combine data.csv files generated in each subfolder

In [None]:
# Merge the data.csv table of every subfolder of out_dir into one train_data.csv,
# recording the originating subfolder in a "subdir" column.
dataframes = []
# sorted(): iterdir() order is arbitrary; sorting makes the row order reproducible.
for subdir in sorted(out_dir.iterdir()):
    if not subdir.is_dir():
        continue
    data_path = subdir / "data.csv"
    if not data_path.is_file():
        # A subfolder without data.csv (e.g. not generated yet) should not abort the merge.
        print("skip due to missing data.csv:", subdir)
        continue
    dataframe = pd.read_csv(data_path)
    dataframe["subdir"] = subdir.name
    dataframes.append(dataframe)
if not dataframes:
    raise FileNotFoundError(f"no data.csv found in any subfolder of {out_dir}")
# ignore_index gives the combined table one unique 0..n-1 index instead of
# repeating every file's own row numbers.
combined_dataframe = pd.concat(dataframes, ignore_index=True)
combined_dataframe.to_csv(out_dir / "train_data.csv", index=False)

In [None]:
# Total number of rows in the combined training table.
len(combined_dataframe)