# Create a Napari UI for editing a SingleCellTrajectoryCollection (SCTC)

### Load sample data and create a sample trajectory collection

In [None]:
from pathlib import Path
from livecellx.sample_data import tutorial_three_image_sys

from livecellx.core.io_sc import prep_scs_from_mask_dataset

from livecellx.core.single_cell import (
    SingleCellStatic,
    SingleCellTrajectory,
    SingleCellTrajectoryCollection,
)


# scs_path = r"D:\LiveCellTracker-dev\notebooks\application_results\5days_notreat_XY03_max_age=3_min_hits=1\5days_notreat_XY03_max_age=3_min_hits=1.json"
# Per-position single-cell JSON produced by an earlier segmentation run.
# NOTE(review): hard-coded Windows path; on POSIX systems Path() will not split the
# backslashes, so point this at your own data location.
pos_scs_dir = Path(
    r"D:\LiveCellTracker-dev\notebooks\application_results\Gaohan_5days_notreat\pos_scs"
)
scs_path = Path(pos_scs_dir, "XY03_scs.json").as_posix()
scs = SingleCellStatic.load_single_cells_json(scs_path)
# sctc = SingleCellTrajectoryCollection.load_from_json_file(scs_path)

# single_cells = prep_scs_from_mask_dataset(mask_dataset, dic_dataset)

In [None]:
# Number of single cells loaded from the JSON file (shown as the cell output).
len(scs)

In [None]:
from typing import List
from livecellx.track.sort_tracker_utils import (
    gen_SORT_detections_input_from_contours,
    update_traj_collection_by_SORT_tracker_detection,
    track_SORT_bbox_from_contours,
    track_SORT_bbox_from_scs,
)


# Reuse the image/mask datasets attached to the loaded single cells.
first_sc = scs[0]
dic_dataset = first_sc.img_dataset
mask_dataset = first_sc.mask_dataset

# Link the single cells across frames with the SORT tracker helper.
# NOTE(review): mask_dataset is looked up above, but tracking is called with
# mask_dataset=None; pass mask_dataset=mask_dataset (as in the commented variant
# below) if the tracker should use it.
tracker_kwargs = dict(mask_dataset=None, max_age=5, min_hits=1)
traj_collection = track_SORT_bbox_from_scs(scs, dic_dataset, **tracker_kwargs)
# traj_collection = track_SORT_bbox_from_scs(single_cells, dic_dataset, mask_dataset=mask_dataset, max_age=0, min_hits=1)

### Call `livecellx.core.sct_operator.create_sctc_edit_viewer_by_interval` to create the interface
If keyboard shortcuts stop responding after you click the slider at the bottom, click the canvas (the middle of the window) and try again.

In [None]:
# Remove empty single-cell trajectories from the collection (presumably, going by the
# method name — confirm); any return value is shown as the cell output.
traj_collection.remove_empty_sct()

In [None]:
# Size of the trajectory collection after the cleanup above (shown as the cell output).
len(traj_collection)

In [None]:
import tqdm 
# Raise the dataset's cache limit, then request every frame once so the images are
# (presumably) cached before the napari editor starts scrolling through them.
# NOTE(review): 1500 is assumed to be >= the number of frames — confirm.
dic_dataset.max_cache_size = 1500
frame_times = dic_dataset.times
for time in tqdm.tqdm(frame_times):
    dic_dataset.get_img_by_time(time)
    

In [None]:
# Repeat the empty-trajectory cleanup, then show the resulting collection size
# (the final expression is the cell output).
traj_collection.remove_empty_sct()
len(traj_collection)

In [None]:
from livecellx.core.sct_operator import (
    create_scs_edit_viewer,
    SctOperator,
    create_scs_edit_viewer_by_interval,
    _get_viewer_sct_operator,
    create_sctc_edit_viewer_by_interval,
)
import livecellx
import importlib

# Reload the livecellx modules so edits to the library source take effect without
# restarting the kernel.
# NOTE(review): per importlib.reload semantics, objects created before the reload
# (e.g. traj_collection) keep their pre-reload classes, and names bound by the
# `from ... import` above still point at the pre-reload functions. The viewer below
# is created through the module attribute, which resolves to the reloaded function.
importlib.reload(livecellx.core.single_cell)
importlib.reload(livecellx.core.sct_operator)
importlib.reload(livecellx.core.sc_seg_operator)
importlib.reload(livecellx.core.napari_visualizer)

# Open the napari editor for the trajectory collection.
sct_operator = livecellx.core.sct_operator.create_sctc_edit_viewer_by_interval(
    traj_collection, img_dataset=dic_dataset, span_interval=200, contour_sample_num=15
)
# Keep the original (misspelled) name as an alias: later cells use `sct_opeartor`.
sct_opeartor = sct_operator

# Alternative: re-create the editor inside the existing napari window, e.g.
# sct_operator = livecellx.core.sct_operator.create_sctc_edit_viewer_by_interval(traj_collection, img_dataset=dic_dataset, span_interval=3, contour_sample_num=100, viewer=sct_operator.viewer)

In [None]:
# Override editor settings stored in the operator's `meta` mapping.
# NOTE(review): assumes the editor reads these keys on later interactions — confirm.
operator_meta = sct_opeartor.meta
operator_meta["_contour_sample_num"] = 20
operator_meta["_span_interval"] = 1000

In [None]:
# Persist the edited trajectory collection, with dataset_json_dir pointing at a
# sibling `sctc_datasets` directory.
sctc_json_path = "./application_results/Gaohan_5days_notreat/XY03_sctc.json"
sctc_dataset_dir = "./application_results/Gaohan_5days_notreat/sctc_datasets"
sct_opeartor.traj_collection.write_json(sctc_json_path, dataset_json_dir=sctc_dataset_dir)

In [None]:
# import livecellx
# importlib.reload(livecellx.core.single_cell)
# _test_load_sctc = livecellx.core.single_cell.SingleCellTrajectoryCollection.load_from_json_file(
#     "./application_results/Gaohan_5days_notreat/XY03_sctc.json",
#     parallel=False
# )

In [None]:
import pytorch_lightning as pl
# Display the installed PyTorch Lightning version (shown as the cell output).
pl.__version__

In [None]:
from livecellx.core.sc_seg_operator import ScSegOperator

# Local CSN model checkpoint; adjust the path to your machine.
ckpt_path = (
    r"D:\LiveCellTracker-dev\notebooks\notebook_results\csn_models\v11-01_epoch=90_test_loss=0.0240-best.ckpt"
)
# Presumably installs the model as ScSegOperator.DEFAULT_CSN_MODEL (used by the next cell).
ScSegOperator.load_default_csn_model(
    path=ckpt_path,
    cuda=True,
)

In [None]:
# Run CSN-based segmentation correction for the editor's first single-cell operator.
sc_seg_operator = sct_opeartor.sc_operators[0]
csn_input_kwargs = dict(
    padding_pixels=50,
    dtype=float,
    remove_bg=False,
    one_object=True,
    scale=0,
)
sc_seg_operator.correct_segment(
    model=ScSegOperator.DEFAULT_CSN_MODEL,
    create_ou_input_kwargs=csn_input_kwargs,
)
# NOTE(review): csn_correct_seg_callback may run the correction again with its own
# settings — confirm whether both calls are needed.
sc_seg_operator.csn_correct_seg_callback()

Test the loading speed of a Napari shapes layer; skip this section if it is not relevant to your research.

In [None]:
from livecellx.core.napari_visualizer import NapariVisualizer
from livecellx.core.single_cell import filter_sctc_by_time_span
import numpy as np

# Collect napari polygons for every single cell of the edited collection within the
# time span (0, 10), then build a shapes layer from them (to gauge loading speed).
trajectories = sct_opeartor.traj_collection
trajectories = filter_sctc_by_time_span(trajectories, (0, 10))
all_shapes = []
track_ids = []
all_scs = []
all_status = []
bbox = None
contour_sample_num = 50
for track_id, traj in tqdm.tqdm(trajectories):
    # Do not bind these cells to `scs`: that would overwrite the single cells loaded
    # at the top of the notebook with only the last trajectory's cells.
    traj_shapes, traj_scs = traj.get_scs_napari_shapes(
        bbox=bbox, contour_sample_num=contour_sample_num, return_scs=True
    )
    all_shapes.extend(traj_shapes)
    # One entry per shape keeps track_ids / all_status index-aligned with all_shapes.
    track_ids.extend([int(track_id)] * len(traj_shapes))
    all_scs.extend(traj_scs)
    all_status.extend([""] * len(traj_shapes))

import napari

viewer = napari.Viewer()
# NOTE(review): face_colormap only takes effect when face_color names a layer
# property (e.g. properties={"track_id": track_ids}, face_color="track_id").
shape_layer = viewer.add_shapes(
    all_shapes,
    face_colormap="viridis",
    shape_type="polygon",
    name="trajectories",
)

In [None]:
# Number of polygons passed to the napari shapes layer (shown as the cell output).
len(all_shapes)

### Programmatically create CSN inputs and predict with CSN models

In [None]:
from livecellx.segment.ou_utils import create_ou_input_from_sc
from livecellx.model_zoo.segmentation.eval_csn import viz_sample_v3
import torch
from torchvision import transforms


# Build a CSN input for one single cell by hand and visualize the model prediction.
sc = sc_seg_operator.sc
# Use one padding for both the model-input crop and the ground-truth crop so they
# cover the same region (this cell previously used 100 for the input, 50 for the mask).
# NOTE(review): assumes `padding_pixels` of create_ou_input_from_sc and `padding` of
# get_contour_mask crop the same way — confirm.
padding_pixels = 100
ou_input = create_ou_input_from_sc(
    sc,
    padding_pixels=padding_pixels,
    dtype=float,
    remove_bg=False,
    one_object=True,
    scale=0,
)
# ou_input = create_ou_input_from_sc(self.sc, **create_ou_input_kwargs)
original_shape = ou_input.shape
csn_input_size = (412, 412)
input_transforms = transforms.Compose(
    [
        transforms.Resize(size=csn_input_size),
    ]
)
# Assumes ou_input is a 2-D (H, W) array: add a leading axis -> (1, H, W), resize,
# then replicate into 3 channels -> (1, 3, 412, 412).
ou_input = input_transforms(torch.tensor(ou_input).unsqueeze(0))
ou_input = torch.stack([ou_input, ou_input, ou_input], dim=1)
ou_input = ou_input.float()

gt_mask = torch.tensor(sc.get_contour_mask(crop=True, padding=padding_pixels)).float()
# Resize the mask to the input's spatial size; nearest-neighbour avoids blending labels.
mask_resize = transforms.Resize(
    size=csn_input_size, interpolation=transforms.InterpolationMode.NEAREST
)
gt_mask = mask_resize(gt_mask.unsqueeze(0)).squeeze(0)
sample = {
    "input": ou_input.squeeze(0),  # (3, 412, 412)
    # The mask has no leading batch axis, so stack along a NEW axis 0; dim=1 (as used
    # for the input above) would give (H, 3, W) instead of (3, H, W).
    "gt_mask": torch.stack([gt_mask, gt_mask, gt_mask], dim=0),
}
viz_sample_v3(sample, ScSegOperator.DEFAULT_CSN_MODEL, sc.get_contour_mask())

## Pre-check single cells and trajectories by IoU

In [None]:
# List the track ids in traj_collection (shown as the cell output).
traj_collection.get_all_track_ids()

In [None]:
# Group every single cell of the collection by timeframe: {timeframe: [sc, ...]}.
scs_by_time = {}
all_scs = traj_collection.get_all_scs()
for sc in all_scs:
    scs_by_time.setdefault(sc.timeframe, []).append(sc)

In [None]:
# Within each trajectory, link every cell to the next cell in time and record their IoU.
all_trajs = traj_collection.get_all_trajectories()
for traj in tqdm.tqdm(all_trajs):
    traj_times = sorted(traj.timeframe_to_single_cell.keys())
    # Walk consecutive timeframe pairs; a single outer progress bar is enough (the
    # previous per-trajectory tqdm created one bar for every trajectory).
    for cur_time, next_time in zip(traj_times, traj_times[1:]):
        sc = traj.timeframe_to_single_cell[cur_time]
        next_sc = traj.timeframe_to_single_cell[next_time]
        sc.uns["next_sc"] = next_sc
        sc.uns["nxt_sc_iou"] = sc.compute_iou(next_sc)
        next_sc.uns["prev_sc"] = sc


In [None]:
# Number of trajectories processed by the linking step (shown as the cell output).
len(all_trajs)

In [None]:
# Timeframes that contain at least one single cell, in ascending order.
times = sorted(scs_by_time)

def _iou_compute_wrapper(cur_scs, next_scs, cur_time, next_time):
    """Compute the IoU of every cell in ``cur_scs`` against every cell in ``next_scs``.

    Each cell in ``cur_scs`` gets ``uns["iou_map"]``: a dict mapping the ``id`` of
    each cell in ``next_scs`` to the value of ``compute_iou`` with that cell. Kept as
    a standalone function so it can be passed to ``parallelize``.

    Returns:
        ``(cur_scs, cur_time, next_time)`` so the caller can match results to frames.
    """
    for cur_sc in tqdm.tqdm(cur_scs):
        iou_by_next_id = {}
        cur_sc.uns["iou_map"] = iou_by_next_id
        for candidate in next_scs:
            iou_by_next_id[candidate.id] = cur_sc.compute_iou(candidate)
    return cur_scs, cur_time, next_time


# for time in tqdm.tqdm(range(len(times) - 1)):
#     cur_scs = scs_by_time[times[time]]
#     next_scs = scs_by_time[times[time + 1]]
#     for sc in tqdm.tqdm(cur_scs):
#         sc.uns["iou_map"] = {}
#         for next_sc in next_scs:
#             iou = sc.compute_iou(next_sc)
#             sc.uns["iou_map"][next_sc.id] = iou

# Pair each timeframe with the next one and compute all cross-frame IoUs in parallel.
# NOTE(review): cells already carry uns["next_sc"] / uns["prev_sc"] links from the
# linking cell, so sending one frame to a worker may pickle whole trajectories.
iou_parallel_inputs = [
    (scs_by_time[cur_time], scs_by_time[next_time], cur_time, next_time)
    for cur_time, next_time in zip(times, times[1:])
]

from livecellx.core.parallel import parallelize
outputs = parallelize(_iou_compute_wrapper, iou_parallel_inputs, cores=16)

# If parallelize runs the wrapper in worker processes, the cells it returns are
# copies: iou_map would exist only on them, while traj_collection / all_scs still
# hold the originals. Copy each map back onto the original cell, matched by id
# (iou_map already keys cells by id). When the worker mutated the originals in
# place, this is a harmless self-assignment.
sc_by_id = {sc.id: sc for frame_scs in scs_by_time.values() for sc in frame_scs}
for cur_scs_out, _cur_time, _next_time in tqdm.tqdm(outputs):
    for sc_out in cur_scs_out:
        sc_by_id[sc_out.id].uns["iou_map"] = sc_out.uns["iou_map"]


In [None]:
# Save all single cells, including the IoU annotations, to JSON.
# Take the cells from scs_by_time rather than `all_scs`: the parallel IoU step may
# return copies, and scs_by_time is where this notebook keeps the cells that carry
# uns["iou_map"] (cells of the last frame have no next frame, hence no iou_map).
# NOTE(review): uns also holds SingleCellStatic references (next_sc / prev_sc);
# confirm write_single_cells_json can serialize them.
iou_json_path = "./application_results/Gaohan_5days_notreat/XY03_scs_iou.json"
Path(iou_json_path).parent.mkdir(parents=True, exist_ok=True)
scs_to_save = [sc for frame in sorted(scs_by_time) for sc in scs_by_time[frame]]
SingleCellStatic.write_single_cells_json(scs_to_save, iou_json_path)