In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
# from livecell_tracker import segment
from livecell_tracker import core
from livecell_tracker.core import datasets
from livecell_tracker.core.datasets import LiveCellImageDataset, SingleImageDataset
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic
# import detectron2
# from detectron2.utils.logger import setup_logger

# setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import cv2

# import some common detectron2 utilities
# from detectron2 import model_zoo
# from detectron2.engine import DefaultPredictor
# from detectron2.config import get_cfg
# from detectron2.utils.visualizer import Visualizer
# from detectron2.data import MetadataCatalog, DatasetCatalog
# from livecell_tracker.segment.detectron_utils import gen_cfg

# from livecell_tracker.segment.detectron_utils import (
#     segment_detectron_wrapper,
#     segment_images_by_detectron,
#     convert_detectron_instance_pred_masks_to_binary_masks,
#     convert_detectron_instances_to_label_masks,
# )
# from livecell_tracker.segment.detectron_utils import (
#     convert_detectron_instance_pred_masks_to_binary_masks,
#     convert_detectron_instances_to_label_masks,
#     segment_images_by_detectron,
#     segment_single_img_by_detectron_wrapper,
# )


## Loading single cells from existing mask files

`LiveCellImageDataset` and `SingleImageDataset` from `livecell_tracker.core.datasets` make it easy to load images without reading everything directly into memory.  
In `mask_dataset_path`, please make sure that the alphabetically sorted file names correspond to the temporal order of the frames.
Note that the provided sorting mechanism simply sorts the list of URLs (file names) by string value, so without proper leading zeros the resulting time order may be incorrect: e.g. the string `T10` (10th file) sorts before the string `T2`. If your files follow a custom naming pattern, pass `LiveCellImageDataset` a `time2url` dictionary that maps each time point to its file location for reading time-lapse data.
`SingleImageDataset` wraps a single in-memory image as a dataset with one time point, which is handy when you want to process an individual image with the same dataset-based tools.

In [None]:
# Input data for this tutorial; both folders live under the same STAV-A549 test-data root.
# Paths are relative to the current working directory (normally the notebook's folder).
dataset_dir_path = Path("../datasets/test_data_STAV-A549") / "DIC_data"
mask_dataset_path = Path("../datasets/test_data_STAV-A549") / "mask_data"

In [None]:
# Build the mask dataset from the *.png files in mask_dataset_path and display its
# time -> file mapping so the ordering can be checked.
mask_dataset = LiveCellImageDataset(mask_dataset_path, ext="png")
mask_dataset.time2url

In [None]:
# Build the DIC image dataset from the *.tif files in dataset_dir_path.
dic_dataset = LiveCellImageDataset(dataset_dir_path, ext="tif")

Check that the `time2url` mapping (time point → file location) is correct:

In [None]:
# Time -> file mapping of the DIC dataset; should match the mask dataset's ordering.
dic_dataset.time2url

### Convert label masks to single objects

In [None]:
from skimage.measure import regionprops  # NOTE(review): not used in this cell
from livecell_tracker.segment.utils import prep_scs_from_mask_dataset
# Turn the labeled regions in the mask images into single-cell objects linked to the DIC images
# (presumably one object per labeled region per timeframe — each carries a `timeframe`, see below).
single_cells = prep_scs_from_mask_dataset(mask_dataset, dic_dataset)

In [None]:
# Sanity check: every single cell must have a mask dataset attached.
# Collect the offending indices so a failure says which cells are broken.
cells_without_mask = [idx for idx, cell in enumerate(single_cells) if not cell.mask_dataset]
assert not cells_without_mask, f"single cells without a mask dataset (first indices): {cells_without_mask[:10]}"

In [None]:
# for testing
# single_cells = single_cells[:10]

In [None]:
# Total number of single cells across all timeframes.
len(single_cells)

In [None]:
# Group the single cells by timeframe: {timeframe: [cell, ...]}, keeping encounter order.
single_cells_by_time = {}
for cell in single_cells:
    single_cells_by_time.setdefault(cell.timeframe, []).append(cell)

In [None]:
# Number of single cells found at each timeframe.
for time, cells_at_time in single_cells_by_time.items():
    print(time, len(cells_at_time))

### Visualize one single cell

In [None]:
# Draw the first single cell with four show_* views side by side (one axis each).
sc = single_cells[0]

fig, axes = plt.subplots(1, 4, figsize=(10, 5))
sc.show(ax=axes[0])
sc.show_mask(ax=axes[1])
sc.show_contour_img(ax=axes[2])
sc.show_contour_mask(ax=axes[3])

In [None]:
# Multi-panel view of the same cell in a single call.
sc.show_panel(figsize=(15, 5))

In [None]:
# Two more cells, used below for feature extraction and for the overlap comparison.
sc1, sc2 = single_cells[1], single_cells[2]

In [None]:
from livecell_tracker.trajectory.feature_extractors import compute_skimage_regionprops, compute_haralick_features

# Compute skimage regionprops features for sc1 and attach them to the cell under the "skimage" key.
skimage_features = compute_skimage_regionprops(sc1)
sc1.add_feature("skimage", skimage_features)

In [None]:
# haralick_features = compute_haralick_features(sc1)
# sc1.add_feature("haralick", haralick_features)

In [None]:
# Features attached to sc1 so far, returned as a pandas Series (per the method name).
sc1.get_feature_pd_series()

Calculate the overlap between two single cells (IoU and overlap percentage):

In [None]:
# (IoU, overlap percentage) of sc1 with respect to sc2.
sc1.compute_iou(sc2), sc1.compute_overlap_percent(sc2)

## Tracking based on single cells

In [None]:
from typing import List
from livecell_tracker.track.sort_tracker_utils import (
    gen_SORT_detections_input_from_contours,
    update_traj_collection_by_SORT_tracker_detection,
    track_SORT_bbox_from_contours,
    track_SORT_bbox_from_scs
)

# SORT tracker settings, named so they can be tuned in one place.
# NOTE(review): in the standard SORT algorithm, max_age is how many frames a track may stay
# unmatched before it is dropped, and min_hits is how many matches are required before a
# track is reported — confirm against livecell_tracker's implementation.
SORT_MAX_AGE = 1
SORT_MIN_HITS = 1

# Link single cells across timeframes into trajectories with bounding-box SORT tracking.
traj_collection = track_SORT_bbox_from_scs(
    single_cells,
    dic_dataset,
    mask_dataset=mask_dataset,
    max_age=SORT_MAX_AGE,
    min_hits=SORT_MIN_HITS,
)

Generate per-trajectory movies (disabled by default; uncomment the cell below to write one GIF per track).

In [None]:
# from livecell_tracker.track.movie import generate_single_trajectory_movie

# for track_id, traj in traj_collection:
#     generate_single_trajectory_movie(traj, save_path=f"./notebook_results/general_tutorial/track_movies/track_{track_id}.gif")

In [None]:
# Histogram of trajectory lengths in the tracked collection.
traj_collection.histogram_traj_length()

In [None]:
# for track_id, traj in traj_collection:
#     print("track_id=", track_id)
#     traj.timeframe_to_single_cell[list(traj.timeframe_to_single_cell.keys())[0]].show_panel(figsize=(20, 5))
#     plt.show()
    

In [None]:
%gui qt
from livecell_tracker.core.napari_visualizer import NapariVisualizer
import napari
from skimage import data


In [None]:
from livecell_tracker.core.single_cell import SingleCellStatic, SingleCellTrajectory, SingleCellTrajectoryCollection
import numpy as np
from napari.viewer import Viewer
from livecell_tracker.core.visualizer import Visualizer


# Open the DIC images (via dic_dataset.to_dask()) in a napari viewer and overlay the tracked
# trajectories as a shapes layer (contour_sample_num presumably sets points per contour — confirm).
viewer = napari.view_image(dic_dataset.to_dask(), name='dic_image', cache=True)
shape_layer = NapariVisualizer.viz_trajectories(traj_collection, viewer, contour_sample_num=20)

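### SctOperator: interactive trajectory editing

Select shapes in the napari viewer, pick a mode in the "SCT Operator" dock widget, then connect two trajectories, disconnect (split) one trajectory, or add mother/daughter relations between trajectories.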

In [None]:
import copy
from functools import partial
from magicgui import magicgui
from magicgui.widgets import Container, PushButton, Widget, create_widget

class SctOperator():
    """Interactive editor for single-cell trajectories shown as a napari Shapes layer.

    The operator listens to ``current_properties`` events of ``shape_layer`` and records
    each newly selected shape in ``select_info``. Depending on ``mode`` the recorded
    selection is then used to connect two trajectories, split (disconnect) one
    trajectory, or add mother/daughter relations.

    Args:
        traj_collection: trajectory collection indexed by track id; must support
            ``[track_id]``, ``pop_trajectory`` and ``add_trajectory``.
        shape_layer: shapes layer whose properties contain ``"sc"``, ``"track_id"``
            and ``"status"`` entries (read/written by ``select_shape``).
        viewer: napari viewer owning ``shape_layer``.
        operator: stored as ``self.operator``; not used by the methods below.
        magicgui_container: widget container; indices 2-4 are expected to hold the
            connect / disconnect / mother-daughter action widgets.
    """
    CONNECT_MODE = 0
    DISCONNECT_MODE = 1
    ADD_MOTHER_DAUGHER_MODE = 2
    # Correctly spelled alias; the misspelled name above is kept for existing callers.
    ADD_MOTHER_DAUGHTER_MODE = ADD_MOTHER_DAUGHER_MODE
    # Container index of the action widget shown for each mode (indices 0-1 are expected
    # to be the mode switch and clear-selection widgets, which stay visible).
    _MODE_WIDGET_INDEX = {
        CONNECT_MODE: 2,
        DISCONNECT_MODE: 3,
        ADD_MOTHER_DAUGHER_MODE: 4,
    }

    def __init__(self, traj_collection, shape_layer, viewer, operator="connect", magicgui_container=None):
        self.select_info = []  # list of (sct, sc, selected_shape_index), in selection order
        self.operator = operator
        # Layer whose selection events are already routed to select_shape (avoids double wiring).
        self._connected_shape_layer = None
        self.setup_shape_layer(shape_layer)
        self.traj_collection = traj_collection
        self.viewer = viewer
        self.magicgui_container = magicgui_container
        self.mode = SctOperator.CONNECT_MODE

    def select_shape(self, event, shape_layer=None):
        """Record the currently selected shape and highlight it according to ``mode``.

        Args:
            event: the napari event (only printed).
            shape_layer: layer to read the selection from; defaults to ``self.shape_layer``.

        Returns:
            ``(sct, sc, shape_index)`` for a newly recorded selection, or ``None`` when no
            shape, several shapes, or an already-recorded shape is selected.

        Raises:
            ValueError: if ``mode`` is not one of the known modes.
        """
        if shape_layer is None:
            shape_layer = self.shape_layer
        print("current shape layer shape properties: ", event)
        current_properties = shape_layer.current_properties
        assert len(current_properties["sc"]) == 1 and len(current_properties["track_id"]) == 1
        if len(shape_layer.selected_data) > 1:
            print("Please select only one shape at a time for connecting trajectories")
            return
        if len(shape_layer.selected_data) == 0:
            print("No shape selected, please select a shape to connect trajectories")
            return
        selected_shape_index = list(shape_layer.selected_data)[0]

        if any(info[2] == selected_shape_index for info in self.select_info):
            print("shape already selected, please select another shape")
            return

        cur_sc = current_properties["sc"][0]
        cur_track_id = current_properties["track_id"][0]
        # Look the trajectory up in the collection this operator edits, not a notebook global.
        cur_sct = self.traj_collection[cur_track_id]

        print("setting face color of selected shape...")
        selection_face_color, selection_status_text = self._selection_style()

        face_colors = list(shape_layer.face_color)
        face_colors[selected_shape_index] = selection_face_color
        shape_layer.face_color = face_colors

        properties = shape_layer.properties.copy()
        properties["status"][selected_shape_index] = selection_status_text
        shape_layer.properties = properties

        self.select_info.append((cur_sct, cur_sc, selected_shape_index))
        print("<selection complete>")
        return cur_sct, cur_sc, selected_shape_index

    def _selection_style(self):
        """Return ``(face_color, status_text)`` for the next selection in the current mode."""
        if self.mode == self.CONNECT_MODE:
            return (1, 0, 0, 1), "connect"
        if self.mode == self.DISCONNECT_MODE:
            return (0, 1, 0, 1), "disconnect"
        if self.mode == self.ADD_MOTHER_DAUGHER_MODE:
            print("len of select_info", len(self.select_info))
            # The first shape picked is the mother; every later one is a daughter.
            if len(self.select_info) == 0:
                return (1, 0, 0, 1), "mother"
            return (0, 0, 1, 1), "daughter"
        raise ValueError("Invalid mode!")

    def connect_two_scts(self):
        """Merge the two selected trajectories when their timeframe spans do not overlap.

        Raises:
            AssertionError: if exactly two shapes are not selected.
            NotImplementedError: if the two trajectories' timeframe spans overlap.
        """
        assert len(self.select_info) == 2, "Please select two shapes to connect."
        sct1, _, _ = self.select_info[0]
        sct2, _, _ = self.select_info[1]
        if sct1 == sct2:
            print("Skipping connecting two shapes from the same trajectory...")
            # Reset, otherwise the user stays stuck with two recorded shapes.
            self.clear_selection()
            return
        print("connecting two shapes from different trajectories...")
        sct1_span = sct1.get_timeframe_span()
        sct2_span = sct2.get_timeframe_span()

        # Spans are (start, end); merging is only supported when one ends before the other starts.
        if not (sct1_span[1] < sct2_span[0] or sct2_span[1] < sct1_span[0]):
            raise NotImplementedError("Two trajectories are overlapping, notImplemented for now...")
        res_traj = sct1.copy()
        res_traj.add_nonoverlapping_sct(sct2)
        self.traj_collection.pop_trajectory(sct1.track_id)
        self.traj_collection.pop_trajectory(sct2.track_id)
        self.traj_collection.add_trajectory(res_traj)
        self._redraw_trajectories()
        print("connect operator complete!")

    def clear_selection(self):
        """Forget all recorded shapes and restore the layer's original colors and properties."""
        print("clearing selection...")
        self.select_info = []
        self.shape_layer.face_color = list(self.original_face_colors)
        # Hand the layer a copy: select_shape mutates the property arrays in place, which could
        # otherwise corrupt the saved originals after the first clear.
        self.shape_layer.properties = copy.deepcopy(self.original_properties)
        print("clear complete!")

    def setup_shape_layer(self, shape_layer):
        """Make ``shape_layer`` the active layer and snapshot its colors/properties.

        The selection handler is connected only once per layer, so calling this again
        for the same layer does not make ``select_shape`` fire twice.
        """
        self.shape_layer = shape_layer
        if getattr(self, "_connected_shape_layer", None) is not shape_layer:
            shape_layer.events.current_properties.connect(self.select_shape)
            self._connected_shape_layer = shape_layer
        # w/o deepcopy, the original_face_colors will be changed when shape_layer.face_color is changed...
        self.original_face_colors = copy.deepcopy(list(shape_layer.face_color))
        self.original_properties = copy.deepcopy(shape_layer.properties)

    def _redraw_trajectories(self):
        """Replace the shapes layer with a fresh rendering of ``traj_collection`` and reset selection."""
        self.viewer.layers.remove(self.shape_layer)
        self.shape_layer = NapariVisualizer.viz_trajectories(self.traj_collection, self.viewer, contour_sample_num=20)
        self.setup_shape_layer(self.shape_layer)
        self.clear_selection()

    def disconnect_sct(self):
        """Split the selected shape's trajectory at that shape's timeframe.

        Raises:
            AssertionError: if exactly one shape is not selected.
        """
        assert len(self.select_info) == 1, "Please select one shape to disconnect."
        sct, sc, _ = self.select_info[0]
        print("disconnecting shape...")
        old_traj = self.traj_collection.pop_trajectory(sct.track_id)
        new_sct1, new_sct2 = old_traj.split(sc.timeframe)
        self.traj_collection.add_trajectory(new_sct1)
        self.traj_collection.add_trajectory(new_sct2)
        self._redraw_trajectories()
        print("disconnect operator complete!")

    def add_mother_daughter_relation(self):
        """Link the first selected trajectory (mother) to every later selected one (daughters).

        All daughters are validated before any relation is added, so an invalid selection
        leaves the trajectories untouched.

        Raises:
            AssertionError: if fewer than two shapes are selected, or a daughter shares the
                mother's trajectory.
        """
        assert len(self.select_info) >= 2, "Please select at least 2 shapes (mother first, then daughters) to add mother daughter relation."
        mother_sct = self.select_info[0][0]
        daughter_scts = [info[0] for info in self.select_info[1:]]
        for daughter_sct in daughter_scts:
            assert mother_sct != daughter_sct, "mother and daughter cannot be from the same trajectory!"
        for daughter_sct in daughter_scts:
            mother_sct.add_daughter(daughter_sct)
            daughter_sct.add_mother(mother_sct)
        self.clear_selection()
        print("<add mother daughter relation complete>")

    def hide_function_widgets(self):
        """Hide every action widget (container index 2 onward)."""
        for i in range(2, len(self.magicgui_container)):
            self.magicgui_container[i].hide()

    def show_selected_mode_widget(self):
        """Show the action widget that belongs to the current mode.

        Raises:
            ValueError: if ``mode`` is not one of the known modes.
        """
        if self.mode not in self._MODE_WIDGET_INDEX:
            raise ValueError("Invalid mode!")
        self.magicgui_container[self._MODE_WIDGET_INDEX[self.mode]].show()

# Open a viewer for editing and attach the trajectory editor to its shapes layer.
viewer = napari.view_image(dic_dataset.to_dask(), name='dic_image', cache=True)
shape_layer = NapariVisualizer.viz_trajectories(traj_collection, viewer, contour_sample_num=20)
# SctOperator.__init__ already calls setup_shape_layer(shape_layer); calling it again here
# used to connect select_shape to the layer's selection events a second time.
sct_operator = SctOperator(traj_collection, shape_layer, viewer)

# Action buttons: each forwards to the corresponding SctOperator method.
@magicgui(call_button='connect')
def connect_widget():
    print("connect callback fired!")
    sct_operator.connect_two_scts()

@magicgui(call_button='clear selection')
def clear_selection_widget():
    print("clear selection callback fired!")
    sct_operator.clear_selection()

@magicgui(call_button='disconnect')
def disconnect_widget():
    print("disconnect callback fired!")
    sct_operator.disconnect_sct()

@magicgui(call_button='add mother/daughter relation')
def add_mother_daughter_relation_widget():
    print("add mother/daughter relation callback fired!")
    sct_operator.add_mother_daughter_relation()


# Mode selector: switches the operator mode, shows only that mode's action button,
# and drops any selection made under the previous mode.
@magicgui(
    call_button="set mode",
    mode={"choices": ['connect', 'disconnect', 'add mother/daughter relation']}
)
def switch_mode_widget(mode):
    print("switch mode callback fired!")
    if mode == "connect":
        sct_operator.mode = sct_operator.CONNECT_MODE
    elif mode == "disconnect":
        sct_operator.mode = sct_operator.DISCONNECT_MODE
    elif mode == "add mother/daughter relation":
        sct_operator.mode = sct_operator.ADD_MOTHER_DAUGHER_MODE
    sct_operator.hide_function_widgets()
    sct_operator.show_selected_mode_widget()
    sct_operator.clear_selection()


# Order matters: SctOperator.hide_function_widgets/show_selected_mode_widget address widgets
# by index (2 = connect, 3 = disconnect, 4 = mother/daughter).
container = Container(widgets=[switch_mode_widget, clear_selection_widget, connect_widget, disconnect_widget, add_mother_daughter_relation_widget], labels=False)

sct_operator.magicgui_container = container
sct_operator.hide_function_widgets()
sct_operator.show_selected_mode_widget()
viewer.window.add_dock_widget(container, name="SCT Operator")

In [None]:
# Shapes currently recorded by the operator, as (trajectory, single cell, shape index) tuples.
sct_operator.select_info

In [None]:
# Trajectory collection held (and edited in place) by the operator.
# SctOperator stores it as `traj_collection`; `trajectory_collection` raised AttributeError.
sct_operator.traj_collection

In [None]:
# Manually add the mother/daughter relation for the current selection. Guarded so that
# "Restart & Run All" does not stop on the method's assertion when nothing is selected yet.
if len(sct_operator.select_info) >= 2:
    sct_operator.add_mother_daughter_relation()
else:
    print("Select a mother shape and at least one daughter shape in the viewer first.")

In [None]:
# Manually run selection handling for whatever shape is selected in the viewer.
# select_shape only prints the event, so a placeholder object is enough here.
event = object()
sct_operator.select_shape(event)

In [None]:
# Text of the operator's *current* shapes layer. connect/disconnect replace that layer, so the
# global `shape_layer` created above can refer to a layer already removed from the viewer.
sct_operator.shape_layer.text

In [None]:
# Re-inspect the recorded selection: (trajectory, single cell, shape index) tuples.
sct_operator.select_info

In [None]:
# Debug: hide the widget at index 2 of the operator's container (connect_widget in the Container
# built above). NOTE(review): it stays hidden until "set mode" calls show_selected_mode_widget again.
sct_operator.magicgui_container[2].hide()

In [None]:
# First widget of the dock container (switch_mode_widget, per the Container's widget order).
container[0]

In [None]:
# Recorded selection after the interactions above.
sct_operator.select_info

In [None]:
# from qtpy.QtWidgets import QPushButton
# connect_btn = QPushButton('Connect two trajectories')
# connect_btn.resize(100, 100)
# connect_btn.clicked.connect(sct_operator.connect_two_scts)
# widget = viewer.window.add_dock_widget(connect_btn, name="edit trajectory")

# clear_btn = QPushButton('Clear selection')
# clear_btn.resize(100, 100)
# clear_btn.clicked.connect(sct_operator.clear_selection)
# viewer.window.add_dock_widget(clear_btn, name="edit trajectory")

In [None]:
# First shape of the operator's current layer; the global `shape_layer` may be stale after
# connect/disconnect, which replace the operator's layer.
shape = sct_operator.shape_layer.data[0]

In [None]:
# viewer.layers.selection.events.active.connect(lambda x: print(dir(x)))