In [None]:
# %%
import numpy as np
import json
from livecellx.core import (
    SingleCellTrajectory,
    SingleCellStatic,
    SingleCellTrajectoryCollection,

)
from livecellx.core.single_cell import get_time2scs
from livecellx.core.datasets import LiveCellImageDataset
from livecellx.preprocess.utils import (
    overlay,
    enhance_contrast,
    normalize_img_to_uint8,
)
import matplotlib.pyplot as plt
import os
from pathlib import Path
import pandas as pd
from typing import List

# %% [markdown]
# Loading Mitosis trajectory Single Cells

# %%
# Path (relative to the notebook's working directory) of the trajectory-collection JSON file.
sctc_path = r"../datasets/DIC-Nikon-gt/tifs_CFP_A549_VIM_120hr_NoTreat_NA_YL_Ti2e_2023-03-22/GH-XY03_traj/traj_XY03.json"
# Load the file into a SingleCellTrajectoryCollection.
# NOTE(review): `sctc` is not referenced by the cells visible below — confirm it is still needed.
sctc = SingleCellTrajectoryCollection.load_from_json_file(sctc_path)

In [None]:
# Root folder of the synthetic data set, and the sub-folder holding the single-cell JSON files.
out_dir = Path("./tmp/EBSS_120hrs_OU_syn")
scs_dir = out_dir.joinpath("livecellx_scs")

In [None]:
# JSON file that stores, per time key, the list of "map_from" -> "map_to" id records.
multi_map_path = scs_dir / "time2multi_maps__id.json"
# Read through a context manager so the file handle is closed as soon as parsing
# finishes (a bare `json.load(open(...))` leaves the handle open).
with open(multi_map_path) as multi_map_file:
    time2multi_maps__id = json.load(multi_map_file)

In [None]:
# Load `all_gt_scs` and `all_dilated_gt_scs` from their JSON files under `scs_dir`
all_gt_scs = SingleCellStatic.load_single_cells_json(scs_dir / "all_gt_scs.json")
all_dilated_gt_scs = SingleCellStatic.load_single_cells_json(scs_dir / "all_dilated_gt_scs.json")


# Reconstruct the nested lookup: dilate scale -> time -> list of dilated single cells.
# Both keys are read from each single cell's `meta` dict.
all_dilate_scale_to_gt_scs = {}
for sc in all_dilated_gt_scs:
    scale, time = sc.meta["dilate_scale"], sc.meta["time"]
    # setdefault creates the missing inner dict / list on first sight of a key
    all_dilate_scale_to_gt_scs.setdefault(scale, {}).setdefault(time, []).append(sc)




In [None]:
# Pool the loaded single cells and their dilated versions into one list.
# NOTE: `all_gt_scs` is rebound from itself, so re-running this cell without
# re-running the loading cell appends the dilated cells a second time.
all_gt_scs = all_gt_scs + all_dilated_gt_scs

# Look-up table: single-cell id -> single-cell object
# (if an id occurs more than once, the last occurrence wins).
id2sc = dict((sc.id, sc) for sc in all_gt_scs)

In [None]:
# Maps one single cell ("map_from") to the list of single cells it is mapped to ("map_to").
sc_multimap = {}

for time in time2multi_maps__id:
    for info_dict in time2multi_maps__id[time]:
        # Use descriptive names instead of `id` / `ids` so the builtin `id()` is not
        # shadowed at notebook scope for every cell that runs afterwards.
        from_id, to_ids = info_dict["map_from"], info_dict["map_to"]
        _sc = id2sc[from_id]
        _scs = [id2sc[to_id] for to_id in to_ids]
        sc_multimap[_sc] = _scs

print("# Mapping (U-seg) cases:", len(sc_multimap))

In [None]:
import matplotlib.pyplot as plt

padding = 50
# Visualize (at most) the first 5 multimap cases. Slicing the item list once avoids
# rebuilding it on every iteration and does not fail when fewer than 5 cases exist.
for sc1, sc2s in list(sc_multimap.items())[:5]:
    # One panel for sc1, one per mapped-to cell, and a last one for the raw image crop.
    n_panels = len(sc2s) + 2
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))

    # Take the time from the displayed cell itself; the global `time` is a leftover
    # loop variable from an earlier cell and would be identical for every figure.
    case_time = sc1.timeframe

    axs[0].imshow(sc1.get_contour_mask(padding=padding))
    axs[0].set_title(f"Time {case_time} - sc1")
    for idx, sc2 in enumerate(sc2s):
        # Pass sc1's bbox so each mapped-to cell is requested for the same region as sc1.
        axs[idx + 1].imshow(sc2.get_contour_mask(padding=padding, bbox=sc1.bbox))
        axs[idx + 1].set_title(f"Time {case_time} - sc2_{idx}")
    raw_img = sc1.get_img_crop(padding=padding)
    axs[-1].imshow(raw_img)
    axs[-1].set_title(f"Time {case_time} - raw")


In [None]:
import livecellx
from livecellx.core.single_cell import create_label_mask_from_scs


In [14]:
import tqdm


def gen_underseg_cases_from_multimaps(sc_multimap, out_subdir, filename_pattern="img-%d_scId-%s.tif"):
    """Generate under-segmentation samples for every entry of ``sc_multimap``.

    For each ``(sc, scs)`` pair, the image crop and contour mask of ``sc`` plus a
    label mask built from ``scs`` are handed to
    ``livecellx.segment.ou_utils.csn_augment_helper`` together with the output
    paths. The path tuples collected by the helper are written to
    ``<out_subdir>/data.csv``.

    Parameters
    ----------
    sc_multimap : dict
        Maps a single cell ``sc`` to the list of single cells ``scs`` it is mapped to.
    out_subdir : str or os.PathLike
        Output root. The sub-folders ``raw``, ``seg``, ``gt``, ``gt_label_mask``,
        ``augmented_seg``, ``raw_transformed_img`` and ``augmented_diff_seg`` are
        created inside it if missing.
    filename_pattern : str
        ``%``-style pattern formatted with ``(sc.timeframe, sc.id)``; also forwarded
        to the augment helper.

    Returns
    -------
    None
    """
    # Imported explicitly so the function does not depend on the submodule having
    # been loaded as a side effect of some other `livecellx` import.
    from livecellx.segment.ou_utils import csn_augment_helper

    # Accept plain strings as well as Path objects.
    out_subdir = Path(out_subdir)
    print("Generating underseg cases from multimap, output to:", out_subdir)

    raw_out_dir = out_subdir / "raw"
    seg_out_dir = out_subdir / "seg"
    gt_out_dir = out_subdir / "gt"
    gt_label_out_dir = out_subdir / "gt_label_mask"
    augmented_seg_dir = out_subdir / "augmented_seg"
    raw_transformed_img_dir = out_subdir / "raw_transformed_img"
    augmented_diff_seg_dir = out_subdir / "augmented_diff_seg"

    for sub_dir in (
        raw_out_dir,
        seg_out_dir,
        gt_out_dir,
        augmented_seg_dir,
        gt_label_out_dir,
        raw_transformed_img_dir,
        augmented_diff_seg_dir,
    ):
        os.makedirs(sub_dir, exist_ok=True)

    # Only a single scale factor (0) is passed to the helper.
    # NOTE(review): presumably 0 means "no extra scaling of the mask" — confirm
    # against csn_augment_helper.
    scale_factors = [0]
    # Both lists are passed to the helper on every call; `train_path_tuples` is
    # turned into data.csv below, so the helper is expected to append to it.
    train_path_tuples = []
    augmented_data = []

    multimaps = list(sc_multimap.items())

    for sc, scs in tqdm.tqdm(multimaps):
        img_id = sc.timeframe
        seg_label = sc.id
        img_crop = sc.get_img_crop()
        seg_crop = sc.get_contour_mask()
        # Label mask of all mapped-to cells, built with sc's bbox
        # (presumably so it lines up with the crops above — confirm).
        combined_gt_label_mask = create_label_mask_from_scs(scs, bbox=sc.bbox)

        filename = filename_pattern % (img_id, seg_label)
        raw_img_path = raw_out_dir / filename
        seg_img_path = seg_out_dir / filename
        gt_img_path = gt_out_dir / filename
        gt_label_img_path = gt_label_out_dir / filename

        csn_augment_helper(
            img_crop=img_crop,
            seg_label_crop=seg_crop,
            combined_gt_label_mask=combined_gt_label_mask,
            scale_factors=scale_factors,
            train_path_tuples=train_path_tuples,
            augmented_data=augmented_data,
            img_id=img_id,
            seg_label=seg_label,
            gt_label=None,
            raw_img_path=raw_img_path,
            seg_img_path=seg_img_path,
            gt_img_path=gt_img_path,
            gt_label_img_path=gt_label_img_path,
            augmented_seg_dir=augmented_seg_dir,
            augmented_diff_seg_dir=augmented_diff_seg_dir,
            raw_transformed_img_dir=raw_transformed_img_dir,
            df_save_path=None,
            # Forward the caller's pattern instead of a hard-coded literal so the
            # helper's file names stay consistent with the paths built above.
            filename_pattern=filename_pattern,
        )

    # One row per path tuple collected during the loop.
    pd.DataFrame(
        train_path_tuples,
        columns=["raw", "seg", "gt", "raw_seg", "scale", "aug_diff_mask", "gt_label_mask", "raw_transformed_img"],
    ).to_csv(out_subdir / "data.csv", index=False)




# Reproducible random 80/20 split of the multimap cases into a train and a test part.
# The global NumPy seed and the in-place shuffle determine which case lands where.
np.random.seed(0)
sc_multimap_items = list(sc_multimap.items())
np.random.shuffle(sc_multimap_items)
train_size = int(len(sc_multimap_items) * 0.8)
train_sc_multimap = {sc: scs for sc, scs in sc_multimap_items[:train_size]}
test_sc_multimap = {sc: scs for sc, scs in sc_multimap_items[train_size:]}
print("# Train cases:", len(train_sc_multimap))
print("# Test cases:", len(test_sc_multimap))

# Output locations: one root per split, each with the same data-set sub-folder.
train_out_dir = Path("./notebook_results/a549_ccp_vim/train_underseg_EBSS_120hrs_syn/")
test_out_dir = Path("./notebook_results/a549_ccp_vim/test_underseg_EBSS_120hrs_syn/")
train_out_subdir = train_out_dir.joinpath("EBSS_120hrs_gt_dilated_syn_underseg_interval_50")
test_out_subdir = test_out_dir.joinpath("EBSS_120hrs_gt_dilated_syn_underseg_interval_50")

# Generate the train samples first, then the test samples.
gen_underseg_cases_from_multimaps(train_sc_multimap, train_out_subdir)
gen_underseg_cases_from_multimaps(test_sc_multimap, test_out_subdir)


# Train cases: 56
# Test cases: 14
Generating underseg cases from multimap, output to: notebook_results/a549_ccp_vim/train_underseg_EBSS_120hrs_syn/EBSS_120hrs_gt_dilated_syn_underseg_interval_50


100%|██████████| 56/56 [06:30<00:00,  6.98s/it]


Generating underseg cases from multimap, output to: notebook_results/a549_ccp_vim/test_underseg_EBSS_120hrs_syn/EBSS_120hrs_gt_dilated_syn_underseg_interval_50


100%|██████████| 14/14 [01:08<00:00,  4.89s/it]
