In [None]:
# %%
import numpy as np
import json
from livecellx.core import (
    SingleCellTrajectory,
    SingleCellStatic,
    SingleCellTrajectoryCollection,

)
from livecellx.core.single_cell import get_time2scs
from livecellx.core.datasets import LiveCellImageDataset
from livecellx.preprocess.utils import (
    overlay,
    enhance_contrast,
    normalize_img_to_uint8,
)
import matplotlib.pyplot as plt
import os
from pathlib import Path
import pandas as pd

# %% [markdown]
# Load the mitosis single-cell trajectories

# %%
# Trajectory collection for position XY03, stored as JSON (presumably the
# ground-truth annotations, given the "-gt" dataset folder — confirm).
sctc_path = r"../datasets/DIC-Nikon-gt/tifs_CFP_A549_VIM_120hr_NoTreat_NA_YL_Ti2e_2023-03-22/GH-XY03_traj/traj_XY03.json"
sctc = SingleCellTrajectoryCollection.load_from_json_file(sctc_path)

In [None]:
# Collect every single cell of every trajectory, then build a time -> cells
# lookup (indexed below as scs_by_time[time]).
scs = sctc.get_all_scs()
scs_by_time = get_time2scs(scs)

In [None]:
# Sanity check: total number of single cells loaded.
len(scs)

In [None]:
from typing import List


def create_label_mask_from_scs(scs: "List[SingleCellStatic]", labels=None):
    """Paint the masks of several single cells into one integer label image.

    Args:
        scs: Single cells whose ``get_mask()`` arrays all share one shape; the
            first cell's mask shape sets the output shape.
        labels: Optional label value per cell, aligned with ``scs``. Extra
            labels beyond ``len(scs)`` are ignored. Defaults to ``1..len(scs)``.

    Returns:
        ``np.int32`` array where 0 is background and each cell's pixels carry
        its label. Where masks overlap, the later cell in ``scs`` wins.

    Raises:
        ValueError: if ``scs`` is empty or fewer labels than cells are given.
    """
    if len(scs) == 0:
        raise ValueError("create_label_mask_from_scs needs at least one single cell")
    if labels is None:
        labels = list(range(1, len(scs) + 1))
    elif len(labels) < len(scs):
        raise ValueError(f"Got {len(labels)} labels for {len(scs)} single cells")

    label_mask = None
    for sc, label in zip(scs, labels):
        # Fetch each mask once. Casting to bool makes an integer 0/1 mask act
        # as a pixel selector instead of fancy (row) indices.
        mask = np.asarray(sc.get_mask(), dtype=bool)
        if label_mask is None:
            label_mask = np.zeros(mask.shape, dtype=np.int32)
        label_mask[mask] = label
    return label_mask

In [None]:
import tqdm
def compute_scs_iou(scs_from: List[SingleCellStatic], scs_to: List[SingleCellStatic], key="iou"):
    """Cache the IoU of every (scs_from, scs_to) pair on the scs_from cells.

    Each value is stored as ``source.tmp[key][target]``. Pairs that already have
    an entry under ``key`` are not recomputed. Returns None.
    """
    for source_sc in tqdm.tqdm(scs_from, total=len(scs_from)):
        cache = source_sc.tmp.setdefault(key, {})
        for target_sc in scs_to:
            if target_sc not in cache:
                cache[target_sc] = source_sc.compute_iou(target_sc)

def compute_scs_iomin(scs_from: List[SingleCellStatic], scs_to: List[SingleCellStatic], key="iomin"):
    """Cache the IoMin of every (scs_from, scs_to) pair on the scs_from cells.

    Each value is stored as ``source.tmp[key][target]``. Pairs that already have
    an entry under ``key`` are not recomputed. Returns None.
    """
    for source_sc in tqdm.tqdm(scs_from, total=len(scs_from)):
        cache = source_sc.tmp.setdefault(key, {})
        for target_sc in scs_to:
            if target_sc not in cache:
                cache[target_sc] = source_sc.compute_iomin(target_sc)


def find_maps(scs_from: List[SingleCellStatic], scs_to: List[SingleCellStatic], metric_threshold=0.3, metric_key=None, min_map_num=None, metric="iomin"):
    """Map each cell in ``scs_from`` to the cells in ``scs_to`` it overlaps.

    Args:
        scs_from: Cells to map from (the keys of the result).
        scs_to: Candidate cells to map to.
        metric_threshold: A pair is mapped when its metric is strictly greater
            than this value.
        metric_key: ``sc.tmp`` key under which metric values are cached.
            Defaults to the metric name, so IoU and IoMin never share a cache.
        min_map_num: If given, drop ``scs_from`` cells with fewer than this
            many matches.
        metric: ``"iou"`` or ``"iomin"``.

    Returns:
        dict mapping each kept ``scs_from`` cell to the list of matching
        ``scs_to`` cells, in ``scs_to`` order.

    Raises:
        ValueError: for an unknown ``metric``.
    """
    metric_fns = {"iou": compute_scs_iou, "iomin": compute_scs_iomin}
    if metric not in metric_fns:
        raise ValueError(f"Unknown metric: {metric}")
    if metric_key is None:
        # The compute_* helpers skip pairs already cached under the key, so a
        # key shared between metrics would silently return the other metric.
        metric_key = metric
    metric_fns[metric](scs_from, scs_to, metric_key)

    scs_map = {
        sc1: [sc2 for sc2 in scs_to if sc1.tmp[metric_key][sc2] > metric_threshold]
        for sc1 in scs_from
    }
    # Filter by the number of matches per source cell.
    if min_map_num is not None:
        scs_map = {k: v for k, v in scs_map.items() if len(v) >= min_map_num}
    return scs_map


# Smoke test: map frame-0 cells to frame-1 cells (the result is discarded),
# then inspect which metric caches were populated on the first frame-0 cell.
# compute_scs_iou(scs_by_time[0], scs_by_time[1])
find_maps(scs_by_time[0], scs_by_time[1])
scs_by_time[0][0].tmp.keys()

In [None]:
# Reuse the image dataset attached to the loaded cells when building new
# single cells from label masks below.
img_dataset = scs[0].img_dataset

In [None]:
# Root directory for the generated check figures and masks.
out_dir = Path("./tmp/EBSS_120hrs_OU_syn")

In [None]:
# Number of time points (0..total_time-1) to process. Limited to 2 for quick
# debugging; switch to the commented line to process every time point.
total_time=2 # For debug
# total_time = len(scs_by_time)

In [None]:
import skimage
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from skimage import measure
from skimage import filters
from skimage import morphology
import skimage.segmentation
from scipy import ndimage

from livecellx.core.io_sc import process_scs_from_single_label_mask
from livecellx.preprocess.utils import dilate_or_erode_label_mask
import tqdm


def _crappy_watershed_labels(raw_img, footprint_size=50, min_distance=40, min_size=4000):
    """Quick, unrefined ("crappy") watershed segmentation of one frame.

    Foreground is every pixel above the Otsu threshold. The foreground distance
    transform is then split by watershed, seeded at its local maxima.

    Args:
        raw_img: Frame to segment (assumed 2-D; peaks are indexed as (row, col)).
        footprint_size: Side length of the square footprint used to find peaks.
        min_distance: Minimum pixel distance between peaks.
        min_size: Objects smaller than this many pixels are removed.

    Returns:
        (labels, filtered_labels): the watershed label image and the same image
        after ``remove_small_objects``.
    """
    foreground = raw_img > filters.threshold_otsu(raw_img)
    distance = ndimage.distance_transform_edt(foreground)
    peaks = peak_local_max(
        distance, footprint=np.ones((footprint_size, footprint_size)), labels=foreground, min_distance=min_distance
    )
    # peaks is an (N, 2) array of (row, col) coordinates; paint them as
    # markers 1..N. Use int32 rather than raw_img's dtype, which may be too
    # small (e.g. uint8) to hold every label, or a float type.
    markers_img = np.zeros(raw_img.shape, dtype=np.int32)
    markers_img[peaks[:, 0], peaks[:, 1]] = np.arange(1, len(peaks) + 1)
    labels = watershed(-distance, markers_img, mask=foreground)
    filtered_labels = morphology.remove_small_objects(labels, min_size=min_size)
    return labels, filtered_labels


check_fig_out = out_dir / "crappy_seg_check"
crapp_mask_out = out_dir / "crappy_seg_mask"
check_fig_out.mkdir(exist_ok=True, parents=True)
crapp_mask_out.mkdir(exist_ok=True, parents=True)

crappy_scs_by_time = {}

underseg_dilate_scales = np.linspace(0, 2, 3)
dilate_scale_to_crappy_scs = {scale: {} for scale in underseg_dilate_scales}
dilate_scale_to_gt_scs = {scale: {} for scale in underseg_dilate_scales}

# Zero-pad time indices in file names to the digits needed for the largest time.
file_time_digit = len(str(len(scs_by_time)))

for time in tqdm.tqdm(range(total_time)):
    cur_scs = scs_by_time[time]
    raw_img = cur_scs[0].get_img()
    time_tag = f"{time:0{file_time_digit}}"

    crappy_labels_mask, filtered_crappy_labels_mask = _crappy_watershed_labels(raw_img)
    np.save(crapp_mask_out / f"crap_mask_{time_tag}.npy", filtered_crappy_labels_mask)

    gt_mask = create_label_mask_from_scs(cur_scs)

    # Check figure: raw, GT, crappy watershed, filtered watershed.
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    axs[0].imshow(raw_img, cmap="gray")
    axs[0].set_title("Raw Image")
    axs[1].imshow(gt_mask, cmap="tab20")
    axs[1].set_title("Ground Truth")
    axs[2].imshow(crappy_labels_mask > 0)
    axs[2].set_title("Crappy Watershed")
    axs[3].imshow(filtered_crappy_labels_mask > 0)
    axs[3].set_title("Filtered Watershed")
    # Hide the axes before saving; after close() it no longer affects the file.
    for ax in axs:
        ax.axis("off")
    fig.savefig(check_fig_out / f"check_{time_tag}.png")
    plt.close(fig)

    crappy_scs = process_scs_from_single_label_mask(filtered_crappy_labels_mask, img_dataset, time)
    crappy_scs_by_time[time] = crappy_scs

    # Dilate the filtered masks and build single cells from each dilated mask.
    n_panels = len(underseg_dilate_scales) + 2  # one per scale + GT + raw
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    for idx, scale in enumerate(underseg_dilate_scales):
        dilated_crappy_labels_mask = dilate_or_erode_label_mask(filtered_crappy_labels_mask, scale)
        dilate_scale_to_crappy_scs[scale][time] = process_scs_from_single_label_mask(
            dilated_crappy_labels_mask, img_dataset, time
        )
        axes[idx].imshow(dilated_crappy_labels_mask > 0)
        axes[idx].set_title(f"Dilated {scale:.2f}")
        axes[idx].axis("off")
    axes[-2].imshow(gt_mask, cmap="tab20")
    axes[-2].set_title("Ground Truth")
    axes[-1].imshow(raw_img, cmap="gray")
    axes[-1].set_title("Raw Image")
    fig.savefig(check_fig_out / f"dilated_{time_tag}.png")
    plt.close(fig)

    # Same process for the GT masks. The GT label mask does not depend on the
    # scale, so the gt_mask computed above is reused.
    for scale in underseg_dilate_scales:
        dilated_gt_mask = dilate_or_erode_label_mask(gt_mask, scale)
        # Binarized before extraction, so touching dilated cells are no longer
        # distinguished by label. NOTE(review): presumably intentional, to
        # synthesize under-segmentation — confirm.
        dilated_gt_mask_bin = dilated_gt_mask > 0
        dilate_scale_to_gt_scs[scale][time] = process_scs_from_single_label_mask(
            dilated_gt_mask_bin, img_dataset, time
        )




In [None]:
metric_key = "iomin"
padding = 50  # pixels of context around each cell crop
for time in range(0, total_time, 1):
    # Cells evaluated against the loaded cells of this frame. Alternatives:
    #   dilate_scale_to_crappy_scs[2][time]  (dilated watershed cells)
    #   crappy_scs_by_time[time]             (undilated watershed cells)
    selected_crappy_scs = dilate_scale_to_gt_scs[2][time]
    # Map each selected cell to the loaded cells it overlaps (iomin > 0.1).
    cur_maps = find_maps(
        selected_crappy_scs,
        scs_by_time[time],
        metric_threshold=0.1,
        metric_key=metric_key,
        metric="iomin",
        min_map_num=None,
    )
    print(cur_maps)
    zero_maps = {k: v for k, v in cur_maps.items() if len(v) == 0}
    multi_maps = {k: v for k, v in cur_maps.items() if len(v) > 1}
    zero_map_num = len(zero_maps)
    multi_map_num = len(multi_maps)
    # Both counts are over the keys of cur_maps (the selected cells), so the
    # rates are normalized by that population, not by len(scs_by_time[time]).
    n_selected = len(cur_maps)
    zero_map_rate = zero_map_num / n_selected if n_selected else 0.0
    multi_map_rate = multi_map_num / n_selected if n_selected else 0.0
    print(f"Time: {time}, Zero map num: {zero_map_num}, Multi map num: {multi_map_num}")
    print(f"Time: {time}, Zero map rate: {zero_map_rate}, Multi map rate: {multi_map_rate}")

    # Visualize multi maps: one selected cell overlapping several cells.
    print("multimaps: ", multi_maps)
    for sc in multi_maps:
        print(sc.tmp.keys())
        print(metric_key, sc.tmp[metric_key])
        for sc2 in sc.tmp[metric_key]:
            print(sc2.id, sc.tmp[metric_key][sc2])
    for sc1, sc2s in multi_maps.items():
        n_panels = len(sc2s) + 2  # sc1, each matched sc2, raw crop
        fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
        axs[0].imshow(sc1.get_contour_mask(padding=padding))
        axs[0].set_title(f"Time {time} - sc1")
        for idx, sc2 in enumerate(sc2s):
            axs[idx + 1].imshow(sc2.get_contour_mask(padding=padding, bbox=sc1.bbox))
            axs[idx + 1].set_title(f"Time {time} - sc2_{idx}")

        # Show original img
        raw_img = sc1.get_img_crop(padding=padding)
        axs[-1].imshow(raw_img)
        plt.show()
        # plt.savefig(check_fig_out / f"multi_map_{time:0{file_time_digit}}_{sc1.get_id()}.png")
        # plt.close()

    # Visualize zero maps: selected cells that overlap nothing.
    if zero_maps:
        n_panels = len(zero_maps) + 1  # one per unmatched cell + raw crop
        fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
        for idx, sc in enumerate(zero_maps):
            axs[idx].imshow(sc.get_contour_mask(padding=padding))
            axs[idx].set_title(f"Time {time} - sc-{idx}")
        # Raw crop around the last unmatched cell only.
        raw_img = sc.get_img_crop(padding=padding)
        axs[-1].imshow(raw_img)
    plt.show()



In [None]:
# Inspect the cached metric values (sc.tmp) of the first mapped cell from the
# last processed time point.
list(cur_maps.keys())[0].tmp

In [None]:
# Display each selected cell of the last processed time point via its
# show_panel() view, closing each figure to avoid accumulating open figures.
for sc in selected_crappy_scs:
    sc.show_panel()
    plt.show()
    plt.close()