In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
from livecell_tracker import segment
from livecell_tracker import core
from livecell_tracker.core import datasets
from livecell_tracker.core.datasets import LiveCellImageDataset
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic
import detectron2
from detectron2.utils.logger import setup_logger

# Initialise detectron2's logger (detectron2.utils.logger.setup_logger) before any
# detectron2 model is built — presumably so its log messages are printed in the notebook.
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import cv2

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from livecell_tracker.segment.detectron_utils import gen_cfg

from livecell_tracker.segment.detectron_utils import (
    segment_detectron_wrapper,
    segment_images_by_detectron,
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
)
from livecell_tracker.segment.detectron_utils import (
    convert_detectron_instance_pred_masks_to_binary_masks,
    convert_detectron_instances_to_label_masks,
    segment_images_by_detectron,
    segment_single_img_by_detectron_wrapper,
)

# --- Configuration: model, input data and output locations -------------------
# Directory of the trained detectron2 run; the weights file lives inside it, so the
# weights path is derived from it instead of repeating the literal.
model_dir = Path(r"./notebook_results/train_log/detectron_train_output__ver0.0.2/")
pretrained_model_path = str(model_dir / "model_final.pth")

# Directory of 8-bit PNG frames to segment and track.
dataset_dir_path = (
    "../cxa-data/june_2022_data_8bit_png/restart_day0_Group 1_wellA1_RI_MIP_stitched"
)

# All artefacts produced by this notebook go under out_dir.
out_dir = Path(r"./day0_output")
seg_out_dir = out_dir / "segmentation"
track_out_dir = out_dir / "tracking"
segmentation_result_json_path = seg_out_dir / "segmentation_results.json"
trajectory_collection_path = track_out_dir / "trajectory_collection.json"

# Fail fast with a readable message instead of a later detectron2 error.
assert os.path.exists(model_dir), f"model directory not found: {model_dir}"
assert os.path.exists(
    pretrained_model_path
), f"model weights not found: {pretrained_model_path}"
os.makedirs(seg_out_dir, exist_ok=True)
os.makedirs(track_out_dir, exist_ok=True)


In [None]:
# Image dataset over the PNG frames in `dataset_dir_path`; input to the segmentation step.
imgs = LiveCellImageDataset(dataset_dir_path, ext="png", max_img_num=None)


## Segmentation

In [None]:
# detectron2 config built from the trained weights, with output_dir set to the model directory.
DETECTRON_CFG = gen_cfg(model_path=pretrained_model_path, output_dir=str(model_dir))
# NOTE(review): DETECTRON_PREDICTOR is not referenced later in this notebook —
# segmentation below is driven by DETECTRON_CFG directly.
DETECTRON_PREDICTOR = DefaultPredictor(DETECTRON_CFG)


In [None]:
# Segment every frame with detectron2 (seg_out_dir is passed as the helper's output directory).
segmentation_results = segment_images_by_detectron(imgs, seg_out_dir, cfg=DETECTRON_CFG)

# Persist to the configured path — the same one the tracking section reloads from —
# so tracking can be re-run without re-segmenting.
with open(segmentation_result_json_path, "w") as f:
    json.dump(segmentation_results, f)


## Tracking

In [None]:
# Raw frames used by the tracker alongside the segmentation contours.
raw_imgs = LiveCellImageDataset(dataset_dir_path, ext="png")

# Reload segmentation from disk so this section does not depend on the in-memory
# result of the segmentation cell; `with` closes the file handle.
with open(segmentation_result_json_path, "r") as f:
    segmentation_results = json.load(f)

# Show only the size; displaying the whole results object floods the notebook.
len(segmentation_results)


In [None]:
from livecell_tracker.track.sort_tracker_utils import (
    gen_SORT_detections_input_from_contours,
    update_traj_collection_by_SORT_tracker_detection,
    track_SORT_bbox_from_contours,
)

In [None]:
# Tracker parameters, forwarded unchanged to the SORT helper (presumably the standard
# SORT max_age / min_hits semantics — confirm in sort_tracker_utils).
MAX_AGE = 5
MIN_HITS = 3

trajectory_collection = track_SORT_bbox_from_contours(
    segmentation_results,
    raw_imgs,
    max_age=MAX_AGE,
    min_hits=MIN_HITS,
)


In [None]:
# Distribution of trajectory lengths produced by the tracker.
trajectory_collection.histogram_traj_length()
plt.gca().set(
    title="Length distribution of trajectories",
    xlabel="Trajectory length",
    ylabel="Count",
)
plt.show()


In [None]:
# Persist the tracked trajectories; the analysis section reloads them from this same path.
trajectory_collection.write_json(trajectory_collection_path)

## Trajectory analysis

In [None]:
from livecell_tracker.core.single_cell import SingleCellTrajectoryCollection

# Reload trajectories from disk so the analysis does not depend on the tracking cell
# having run in this kernel; `with` closes the file handle.
with open(trajectory_collection_path, "r") as f:
    traj_collection_json = json.load(f)
trajectory_collection = SingleCellTrajectoryCollection().load_from_json_dict(
    traj_collection_json
)


In [None]:
import matplotlib.pyplot as plt

from livecell_tracker.core.single_cell import SingleCellTrajectoryCollection

# NOTE(review): track_id is not referenced anywhere else in this notebook; the cells
# below pass trajectory ids to get_trajectory() directly.
track_id = 5

def show_trajectory_on_grid(
    trajectory: SingleCellTrajectory,
    nr=4,
    nc=4,
    start_timeframe=20,
    interval=5,
    padding=20,
):
    """Plot cropped single-cell images of one trajectory on an ``nr`` x ``nc`` grid.

    Panel ``i`` (row-major order) shows timeframe ``start_timeframe + interval * i``,
    with the cell's contour overlaid in red. Panels whose timeframe is absent from the
    trajectory, or lies after its last timeframe, are left blank.

    Args:
        trajectory: trajectory to visualise.
        nr: number of grid rows.
        nc: number of grid columns.
        start_timeframe: timeframe of the first panel. If it is earlier than the
            trajectory's first timeframe, it is replaced by that first timeframe.
        interval: timeframe step between consecutive panels.
        padding: padding passed to the single-cell crop / crop-contour helpers.
    """
    # squeeze=False always yields a 2D array of Axes, so indexing works for any nr/nc
    # (the old `np.array([axes])` trick failed for nc == 1 and for a 1x1 grid).
    fig, axes = plt.subplots(nr, nc, figsize=(nc * 4, nr * 4), squeeze=False)
    traj_start, traj_end = trajectory.get_timeframe_span_range()
    if start_timeframe < traj_start:
        print(
            "start timeframe smaller than the first timeframe of the trajectory, replace start_timeframe with the first timeframe..."
        )
        start_timeframe = traj_start
    # axes.flat walks the grid in row-major order, i.e. panel index r * nc + c.
    for panel_idx, ax in enumerate(axes.flat):
        # Hide axes on every panel, including unused trailing ones.
        ax.axis("off")
        timeframe = start_timeframe + interval * panel_idx
        if timeframe > traj_end or timeframe not in trajectory.timeframe_set:
            continue
        sc = trajectory.get_single_cell(timeframe)
        ax.imshow(sc.get_img_crop(padding=padding))
        contour_coords = sc.get_img_crop_contour_coords(padding=padding)
        # Column 1 is plotted as x and column 0 as y, i.e. coords are treated as (row, col).
        ax.scatter(contour_coords[:, 1], contour_coords[:, 0], s=1, c="r")
        ax.set_title(f"timeframe: {timeframe}")
    fig.tight_layout(pad=0.5, h_pad=0.4, w_pad=0.4)


In [None]:
# Crops of trajectory 10 on the default 4x4 grid, with extra padding around each cell.
show_trajectory_on_grid(trajectory=trajectory_collection.get_trajectory(10), padding=30)

In [None]:
# Preview trajectories as one row of 10 crops each.
MIN_TRAJ_SPAN_LENGTH = 0  # skip trajectories whose span length is below this
MAX_TRAJS_TO_SHOW = 10000  # stop after plotting this many trajectories

counter = 0
for traj in trajectory_collection:
    # Check the cap before printing, so no length is printed for a trajectory that
    # is not plotted (the original printed one extra line past the cap).
    if counter >= MAX_TRAJS_TO_SHOW:
        break
    span_length = traj.get_timeframe_span_length()
    if span_length < MIN_TRAJ_SPAN_LENGTH:
        continue
    print("traj length:", span_length)
    show_trajectory_on_grid(traj, nr=1, nc=10, start_timeframe=0, interval=5)
    plt.show()
    counter += 1


In [None]:
# Same length histogram as in the tracking section, for the collection reloaded from JSON;
# labelled so the figure stands on its own.
trajectory_collection.histogram_traj_length()
plt.title("Length distribution of trajectories")
plt.ylabel("Count")
plt.xlabel("Trajectory length")
plt.show()


In [None]:
# Print each trajectory's timeframe span range on a single comma-separated line.
# (Kept as an explicit loop: the last `traj` binding is read by a later cell.)
for traj in trajectory_collection:
    span_range = traj.get_timeframe_span_range()
    print(span_range, end=",")


In [None]:
# Crops of trajectory 4 on the default 4x4 grid, with extra padding around each cell.
show_trajectory_on_grid(trajectory=trajectory_collection.get_trajectory(4), padding=30)

In [None]:
# NOTE(review): `traj` is not assigned in this cell. It is whatever an earlier loop left
# bound (on a top-to-bottom run, presumably the last trajectory of the span-printing
# loop), not necessarily trajectory 4 shown above. This depends on execution order;
# consider naming the trajectory explicitly.
traj.get_timeframe_span_range()

In [None]:
from livecell_tracker.trajectory.contour_utils import get_cellTool_contour_points, viz_contours
import matplotlib
import matplotlib.cm
import matplotlib.colors

# Overlay the resampled contours of one trajectory, coloured by their order in the sequence.
CONTOUR_TRAJ_ID = 4
contour_num_points = 500

traj = trajectory_collection.get_trajectory(CONTOUR_TRAJ_ID)
cell_contours = get_cellTool_contour_points(traj, contour_num_points=contour_num_points)

# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap still works.
cmap = plt.get_cmap("viridis")
# norm(idx) == idx / len(cell_contours), so line colours and the colorbar share one mapping.
norm = matplotlib.colors.Normalize(vmin=0, vmax=len(cell_contours))

fig, ax = plt.subplots()
for idx, contour in enumerate(cell_contours):
    # TODO: idx should be time
    ax.plot(contour.points[:, 0], contour.points[:, 1], c=cmap(norm(idx)))
ax.axis("off")
# The ScalarMappable is not attached to any Axes, so pass ax explicitly;
# newer Matplotlib raises when colorbar() cannot tell which Axes to steal space from.
fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
plt.show()