In [None]:
# %%
import numpy as np
import json
from livecellx.core import (
    SingleCellTrajectory,
    SingleCellStatic,
    SingleCellTrajectoryCollection,

)
from livecellx.core.single_cell import get_time2scs
from livecellx.core.datasets import LiveCellImageDataset
from livecellx.preprocess.utils import (
    overlay,
    enhance_contrast,
    normalize_img_to_uint8,
)
import matplotlib.pyplot as plt
import os
from pathlib import Path
import pandas as pd
from typing import List

# %% [markdown]
# Loading Mitosis trajectory Single Cells

# %%
# Relative path to the JSON file holding the serialized trajectory collection.
sctc_path = r"../datasets/DIC-Nikon-gt/tifs_CFP_A549_VIM_120hr_NoTreat_NA_YL_Ti2e_2023-03-22/GH-XY03_traj/traj_XY03.json"
# Deserialize it; `sctc` is filtered and compared against re-tracked results further down.
sctc = SingleCellTrajectoryCollection.load_from_json_file(sctc_path)

In [None]:
# Collect every single cell held by the collection (indexable: `scs[0]` is used below).
scs = sctc.get_all_scs()
# Group those cells by time; the result is used as a mapping keyed by time elsewhere in this notebook.
scs_by_time = get_time2scs(scs)

In [None]:
# Total number of single cells gathered from the loaded collection.
len(scs)

In [None]:
from livecellx.core.single_cell import create_label_mask_from_scs

In [None]:
# Reuse the image dataset attached to the first loaded cell.
img_dataset = scs[0].img_dataset
# Directories that the following cells read their JSON inputs from.
out_dir = Path("./tmp/EBSS_120hrs_OU_syn")
scs_dir = out_dir/"livecellx_scs"

In [None]:
# Load the JSON payload stored in time2multi_maps__id.json.
multi_map_path = scs_dir / "time2multi_maps__id.json"
# Use a context manager so the file handle is closed as soon as parsing finishes
# instead of being left open until garbage collection.
with open(multi_map_path) as multi_map_file:
    # NOTE: JSON object keys always deserialize as str, even if they were written as ints.
    time2multi_maps__id = json.load(multi_map_file)

In [None]:
# Read the ground-truth cells and their dilated counterparts from disk.
all_gt_scs = SingleCellStatic.load_single_cells_json(scs_dir/"all_gt_scs.json")
all_dilated_gt_scs = SingleCellStatic.load_single_cells_json(scs_dir/"all_dilated_gt_scs.json")


# Reconstruct the nested lookup: dilate scale -> time -> list of crappy (dilated) scs.
# Both keys come from each cell's `meta` dict.

all_dilate_scale_to_gt_scs = {}
for sc in all_dilated_gt_scs:
    scale, time = sc.meta["dilate_scale"], sc.meta["time"]
    # setdefault creates the inner dict / list on first sight of a key.
    all_dilate_scale_to_gt_scs.setdefault(scale, {}).setdefault(time, []).append(sc)

# Ground-truth cells grouped by time, for comparison against the dilated ones.
all_gt_scs_by_time = get_time2scs(all_gt_scs)

Track by replacing ground-truth (gt) masks with crappy (dilated) masks

In [None]:
# Show which dilate scales are available as top-level keys of the nested lookup.
all_dilate_scale_to_gt_scs.keys()

In [None]:
# Drop trajectories whose len() is below `threshold`.
threshold = 3
# Iterating the collection already yields (track id, trajectory) pairs, so the
# trajectory is used directly instead of being looked up again by id.
filtered_tids = [tid for tid, traj in sctc if len(traj) >= threshold]
# NOTE: this rebinds `sctc` to the filtered subset; later cells only see the kept trajectories.
sctc = sctc.subset(filtered_tids)

sctc.histogram_traj_length()


In [None]:
# Timeframe-span length of every trajectory left in `sctc`, gathered into one array.
all_traj_lengths = np.array([_traj.get_timeframe_span_length() for _traj in sctc.track_id_to_trajectory.values()])
# Histogram of those span lengths using 20 bins.
plt.hist(all_traj_lengths, bins=20)

In [None]:
# Display the overall time span of the (already filtered) collection.
sctc.get_time_span()

In [None]:
def time2sct_counts(sctc: "SingleCellTrajectoryCollection"):
    """Count how many trajectories are alive at each time point of a collection.

    A trajectory counts as alive at time ``t`` when ``t`` lies within the
    ``(start, end)`` pair returned by its ``get_timeframe_span()``, with both
    ends inclusive.

    Parameters
    ----------
    sctc : SingleCellTrajectoryCollection
        Must provide ``get_time_span()`` returning integer ``(start, end)`` and
        iterate as ``(track_id, trajectory)`` pairs.

    Returns
    -------
    dict
        ``{time: number of trajectories whose span contains time}`` for every
        integer time from ``start`` through ``end``.

    Notes
    -----
    The collection span is treated as inclusive at both ends, matching how each
    trajectory's span is treated here. The earlier version iterated
    ``range(start, end)`` and so never reported the final time point.
    NOTE(review): this assumes ``get_time_span()`` returns an inclusive end; confirm.
    """
    first_time, last_time = sctc.get_time_span()

    # Spans do not depend on the time being counted, so query each trajectory once.
    traj_spans = [traj.get_timeframe_span() for _tid, traj in sctc]

    return {
        time: sum(1 for span in traj_spans if span[0] <= time <= span[1])
        for time in range(first_time, last_time + 1)  # +1 keeps the last time point
    }
            



In [None]:
# Second, stricter length filter; the result is the reference collection for the re-tracking comparison.
# NOTE(review): the threshold is passed positionally; presumably it is the minimum length
# (later calls pass min_length=... by keyword). Confirm against filter_trajectories_by_length.
threshold = 50
filtered_sctc = sctc.filter_trajectories_by_length(threshold)

In [None]:
from livecellx.track.sort_tracker_utils import track_SORT_bbox_from_scs

# Dilate scale whose cells replace the reference cells at the affected times.
selected_scale = 2
crappy_scs_by_time = all_dilate_scale_to_gt_scs[selected_scale]
print("time to be replaced:", crappy_scs_by_time.keys())
all_scs_by_time = get_time2scs(filtered_sctc.get_all_scs())

# Start from a shallow copy of the reference time -> cells mapping and overwrite
# the entries for every time that has crappy cells.
# NOTE(review): this only replaces entries if sc.meta["time"] uses the same key type
# as get_time2scs; otherwise new keys are added instead. Confirm.
replaced_scs_by_time = all_scs_by_time.copy()
replaced_scs_by_time.update(crappy_scs_by_time)

# Flatten to a single list for the tracker. The comprehension keeps its loop
# variables local, so the notebook-level `scs` list is no longer overwritten.
crappy_scs = [sc for scs_at_time in replaced_scs_by_time.values() for sc in scs_at_time]
crappy_sctc = track_SORT_bbox_from_scs(crappy_scs, raw_imgs=crappy_scs[0].img_dataset, max_age=3, min_hits=3)

In [None]:
import seaborn as sns

# Compare span-length distributions of re-tracked ("crappy") vs. reference trajectories,
# restricted to trajectories kept by a min_length=threshold filter.
threshold = 100
alpha = 0.65
crappy_sctc_lengths = np.array(
    [
        traj.get_timeframe_span_length()
        for traj in crappy_sctc.filter_trajectories_by_length(min_length=threshold).track_id_to_trajectory.values()
    ]
)
filtered_sctc_lengths = np.array(
    [
        traj.get_timeframe_span_length()
        for traj in filtered_sctc.filter_trajectories_by_length(min_length=threshold).track_id_to_trajectory.values()
    ]
)

# Set up the aesthetic environment
sns.set(style="whitegrid")  # White grid background for better readability
plt.rc('text', usetex=False)  # LaTeX text rendering is explicitly disabled

# Create figure and axes objects
fig, ax = plt.subplots(figsize=(4, 5), dpi=300)  # High resolution for publication quality

# Plot histograms on shared bin edges. Letting each call pick its own 20 bins
# gives the two series different bin widths, so overlaid bar heights are not comparable.
bins = 20
bin_edges = np.histogram_bin_edges(np.concatenate([crappy_sctc_lengths, filtered_sctc_lengths]), bins=bins)
ax.hist(crappy_sctc_lengths, bins=bin_edges, alpha=alpha, label="Crappy", color='red', edgecolor='black')
ax.hist(filtered_sctc_lengths, bins=bin_edges, alpha=alpha, label="Filtered", color='blue', edgecolor='black')

# Customize the plot with labels and title
ax.set_xlabel('Trajectory Length', fontsize=12, fontweight='bold')
ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
ax.set_title('Comparison of Trajectory Lengths', fontsize=14, fontweight='bold')

# Add legend
ax.legend(loc='upper right', fontsize=10, title='Legend', title_fontsize=10, shadow=True, fancybox=True)

In [None]:
# Same comparison as the long-trajectory histogram, but for trajectories kept by a
# min_length=min_len / max_length=max_len filter.
# Relies on `sns` and `alpha` being defined by an earlier cell.
max_len = 100
min_len = 3
crappy_sctc_lengths = np.array(
    [
        traj.get_timeframe_span_length()
        for traj in crappy_sctc.filter_trajectories_by_length(min_length=min_len, max_length=max_len).track_id_to_trajectory.values()
    ]
)
filtered_sctc_lengths = np.array(
    [
        traj.get_timeframe_span_length()
        for traj in filtered_sctc.filter_trajectories_by_length(min_length=min_len, max_length=max_len).track_id_to_trajectory.values()
    ]
)

# Set up the aesthetic environment
sns.set(style="whitegrid")  # White grid background for better readability
plt.rc('text', usetex=False)  # LaTeX text rendering is explicitly disabled

# Create figure and axes objects
fig, ax = plt.subplots(figsize=(4, 5), dpi=300)  # High resolution for publication quality

# Plot histograms on shared bin edges. Letting each call pick its own 20 bins
# gives the two series different bin widths, so overlaid bar heights are not comparable.
bins = 20
bin_edges = np.histogram_bin_edges(np.concatenate([crappy_sctc_lengths, filtered_sctc_lengths]), bins=bins)
ax.hist(crappy_sctc_lengths, bins=bin_edges, alpha=alpha, label="Crappy", color='red', edgecolor='black')
ax.hist(filtered_sctc_lengths, bins=bin_edges, alpha=alpha, label="Filtered", color='blue', edgecolor='black')

# Customize the plot with labels and title
ax.set_xlabel('Trajectory Length', fontsize=12, fontweight='bold')
ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
ax.set_title('Comparison of Trajectory Lengths', fontsize=14, fontweight='bold')

# Add legend
ax.legend(loc='upper right', fontsize=10, title='Legend', title_fontsize=10, shadow=True, fancybox=True)

In [None]:
# Per-time trajectory counts for the re-tracked ("crappy") collection and for the reference collection.
crappy_time2counts = time2sct_counts(crappy_sctc)
gt_time2counts = time2sct_counts(filtered_sctc)

In [None]:
# Line plot of how many trajectories exist at each time point,
# re-tracked ("crappy") result vs. reference ("gt").

fig, axes = plt.subplots(1, 1, figsize=(4, 5), dpi=300)

# Draw on the axes that was just created (it is also pyplot's current axes).
crappy_times = list(crappy_time2counts.keys())
crappy_counts = list(crappy_time2counts.values())
axes.plot(crappy_times, crappy_counts, label="crappy")

gt_times = list(gt_time2counts.keys())
gt_counts = list(gt_time2counts.values())
axes.plot(gt_times, gt_counts, label="gt")

axes.set_xlabel("Time", fontsize=14)
axes.set_ylabel("Number of Trajectories", fontsize=14)

# Tick label font size for both axes
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

axes.legend()
plt.show()

In [None]:
# Sanity check: for each time present at dilate scale 0, print the time, the number
# of dilated cells, and the number of ground-truth cells at that same time.
for time, dilated_scs_at_time in all_dilate_scale_to_gt_scs[0].items():
    gt_scs_at_time = all_gt_scs_by_time[time]
    print(time, len(dilated_scs_at_time), len(gt_scs_at_time))