# Prepare mitosis time series data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
# from livecell_tracker import segment
from livecell_tracker import core
from livecell_tracker.core import datasets
from livecell_tracker.core.datasets import LiveCellImageDataset, SingleImageDataset
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic

In [None]:
# sample_json_dir = Path("./EBSS_starvation_24h_xy16_annotation")

sample_json_dirs = [Path(r"./datasets/test_scs_EBSS_starvation/XY1/annotations"), Path(r"./datasets/test_scs_EBSS_starvation/XY16/annotations")]

def load_class2samples_from_json_dir(sample_json_dir: Path, class_subfolders = ["mitosis", "apoptosis", "normal"]) -> dict:
    # sample_paths = glob.glob(str(sample_json_dir / "*.json"))
    class2samples = {}
    for subfolder in class_subfolders:
        class2samples[subfolder] = []
        sample_paths = glob.glob(str(sample_json_dir / subfolder / "*.json"))
        for sample_path in sample_paths:
            sample = SingleCellStatic.load_single_cells_json(sample_path)
            class2samples[subfolder].append(sample)
    return class2samples


all_class2samples = None
all_class2sample_extra_info = {}
for sample_json_dir in sample_json_dirs:
    _class2samples = load_class2samples_from_json_dir(sample_json_dir)
    print(_class2samples)
    for class_name in _class2samples:
        # report how many samples loaded from the sample json dir
        print(f"Loaded {len(_class2samples[class_name])} annotated samples from {sample_json_dir / class_name}")

    if all_class2samples is None:
        all_class2samples = _class2samples
    for class_name in _class2samples:

        all_class2samples[class_name] += _class2samples[class_name]
        _extra_info =  [{"src_dir": sample_json_dir} for _ in range(len(_class2samples[class_name]))]
        if class_name not in all_class2sample_extra_info:
            all_class2sample_extra_info[class_name] = _extra_info
        else:
            all_class2sample_extra_info[class_name] += _extra_info
           

In [None]:
all_class2sample_extra_info

In [None]:
len(all_class2samples["mitosis"])

Automatically prepare normal samples

require tracking done

In [None]:
# get all scs from class_samples not in normal class
exclude_scs = []
total_non_normal_samples = 0
for class_name, samples in all_class2samples.items():
    if class_name != "normal":
        for sample in samples:
            exclude_scs.extend(sample)
            total_non_normal_samples += 1

exclude_scs = set(exclude_scs)
exclude_scs_ids = {str(sc.id) for sc in exclude_scs}

In [None]:
# from livecell_tracker.core.sct_operator import create_scs_edit_viewer
# sct_operator = create_scs_edit_viewer(exclude_scs, img_dataset = list(exclude_scs)[0].img_dataset)

load all scs

In [None]:
import json
from livecell_tracker.core.single_cell import SingleCellTrajectoryCollection
from livecell_tracker.track.sort_tracker_utils import (
    track_SORT_bbox_from_scs
)

all_scs_json_path = ["./datasets/test_scs_EBSS_starvation/XY1/single_cells.json", "./datasets/test_scs_EBSS_starvation/XY16/single_cells.json"]
# all_scs_json_path = "./datasets/test_scs_EBSS_starvation/XY16/tmp_corrected_scs.json"
sctc = SingleCellTrajectoryCollection()
for json_path in all_scs_json_path:
    _scs = SingleCellStatic.load_single_cells_json(json_path)
    tmp_sctc = track_SORT_bbox_from_scs(_scs, raw_imgs=_scs[0].img_dataset, min_hits=3, max_age=3)
    tids = set(sctc.get_all_tids())
    if len(tids) != 0:
        max_tid = max(tids)
    else:
        max_tid = 0
    for tid, traj in tmp_sctc:
        traj.meta["src_dir"] = json_path
        traj.track_id = tid + max_tid + 1
        sctc.add_trajectory(traj)
        traj_scs = traj.get_all_scs()
        for sc in traj_scs:
            sc.meta["src_dir"] = json_path
    del tmp_sctc

all_scs = SingleCellStatic.load_single_cells_jsons(all_scs_json_path)

In [None]:

# with open("./EBSS_starvation_24h_xy16_annotation/single_cell_trajectory_collection.json", "r") as file:
#     json_dict = json.load(file)
# sctc = SingleCellTrajectoryCollection().load_from_json_dict(json_dict)


In [None]:
# set numpy seed
seed = 0
np.random.seed(seed)

objective_sample_num = total_non_normal_samples * 10

normal_frame_len_range = (3, 10)
counter = 0
normal_samples = []
normal_samples_extra_info = []
max_trial_counter = 100000
while counter < objective_sample_num and max_trial_counter > 0:
    # randomly select a sct from sctc
    # generate a list of scs
    track_id = np.random.choice(list(sctc.track_id_to_trajectory.keys()))  
    sct = sctc.get_trajectory(track_id)
    # randomly select a length
    frame_len = np.random.randint(*normal_frame_len_range)
    # generate a sample
    times = list(sct.timeframe_to_single_cell.keys())
    times = sorted(times)
    if len(times) <= frame_len:
        continue
    start_idx = np.random.randint(0, len(times) - frame_len)
    start_time = times[start_idx]
    end_time = times[start_idx + frame_len - 1]

    sub_sct = sct.subsct(start_time, end_time)

    is_some_sc_in_exclude_scs = False
    for time, sc in sub_sct.timeframe_to_single_cell.items():
        # print("sc.id:", sc.id, type(sc.id))
        if str(sc.id) in exclude_scs_ids:
            is_some_sc_in_exclude_scs = True
            break
    if is_some_sc_in_exclude_scs:
        print("some sc in exclude scs")
        continue
    
    new_sample = []
    for time, sc in sub_sct.timeframe_to_single_cell.items():
        new_sample.append(sc)
    normal_samples.append(new_sample)
    normal_samples_extra_info.append({"src_dir": sub_sct.get_all_scs()[0].meta["src_dir"]})
    counter += 1
    max_trial_counter -= 1

normal_samples[:2]

In [None]:
all_class2samples["normal"].extend(normal_samples)
all_class2sample_extra_info["normal"].extend(normal_samples_extra_info)

In [None]:
len(all_class2samples["normal"]), len(all_class2sample_extra_info["normal"])

## Prepare videos and annotations for MMDetection

In [None]:
classes = all_class2samples.keys()
classes

In [None]:
from livecell_tracker.core.utils import gray_img_to_rgb, rgb_img_to_gray
from livecell_tracker.preprocess.utils import normalize_img_to_uint8

In [None]:
from livecell_tracker.track.classify_utils import video_frames_and_masks_from_sample, combine_video_frames_and_masks

In [45]:
from typing import List
import cv2
import numpy as np
import pandas as pd

from livecell_tracker.core.sc_video_utils import gen_mp4_from_frames, gen_samples_df, gen_samples_mp4s

ver = "10-drop-div"
# ver = "-test"
DROP_MITOSIS_DIV = True

data_dir = Path(f'notebook_results/mmaction_train_data_v{ver}')
class_labels = ['mitosis', 'apoptosis', 'normal']
class_label = "mitosis"
frame_types = ["video", "mask", "combined"]
fps = 3

padding_pixels = [0, 20, 40, 50, 100, 200, 400]



# split train and test data

# get #samples from all_class2samples
_split = 0.8

train_class2samples = {}
test_class2samples = {}
train_class2sample_extra_info = {}
test_class2sample_extra_info = {}

# randomize train and test data


for key in all_class2samples.keys():
    randomized_indices = np.random.permutation(len(all_class2samples[key])).astype(int)
    split_idx = int(len(all_class2samples[key]) * _split)
    _train_indices = randomized_indices[:split_idx]
    _test_indices = randomized_indices[split_idx:]
    train_class2samples[key] = np.array(all_class2samples[key], dtype=object)[_train_indices]
    test_class2samples[key] = np.array(all_class2samples[key], dtype=object)[_test_indices]

    train_class2sample_extra_info[key] = np.array(all_class2sample_extra_info[key], dtype=object)[_train_indices]
    test_class2sample_extra_info[key] = np.array(all_class2sample_extra_info[key], dtype=object)[_test_indices]



In [None]:
len(train_class2samples["normal"]), len(test_class2samples["normal"])

In [None]:
len(train_class2samples["mitosis"]), len(test_class2samples["mitosis"])

In [None]:
video_frames_and_masks_from_sample(train_class2samples["normal"][6])[0][0].shape
# train_class2samples["normal"][6][1].show_panel()

In [None]:
import importlib
import livecell_tracker
importlib.reload(livecell_tracker.track.classify_utils)

In [None]:
idx_to_check = 6
video_frames, video_frame_masks = video_frames_and_masks_from_sample(train_class2samples["normal"][idx_to_check], padding_pixels=0)
print("video frames dtype:", video_frames[0].dtype)
print("video frames shape:", video_frames[0].shape)
print("video frame masks dtype:", video_frame_masks[0].dtype)
print("video frame masks shape:", video_frame_masks[0].shape)
combined_frames = livecell_tracker.track.classify_utils.combine_video_frames_and_masks(video_frames, video_frame_masks, edt_transform=True)
combined_frames = np.array(combined_frames).astype(np.uint8)
# combined_frames = np.maximum(combined_frames - 1, 0).astype(np.uint8)
print("combined_frames shape: ", combined_frames[0].shape)
gen_mp4_from_frames(combined_frames, "./test_video_output.mp4", fps=1)

Visually check the generated frames' values

In [None]:
# channel = 2
# plt.imshow(combined_frames[0][..., channel])
# combined_frames[1][..., channel].max(), combined_frames[1][..., 0].shape

In [None]:
np.array(combined_frames).flatten().min()

Drop the cell divison part for easier inference durign testing

In [42]:
drop_div_keys = ["mitosis"]

def drop_sample_div(sample: List[SingleCellStatic]):
    """remove scs in samples where at the same timepoint, there are >=2 scs"""
    sc_by_time = {}
    for sc in sample:
        if sc.timeframe not in sc_by_time:
            sc_by_time[sc.timeframe] = []
        sc_by_time[sc.timeframe].append(sc)

    new_sample = []
    for time, scs in sc_by_time.items():
        if len(scs) == 1:
            new_sample.append(scs[0])
    return new_sample

def check_one_sc_at_time(sample: List[SingleCellStatic]):
    """check if there is only one sc at each timepoint"""
    times = set()
    for sc in sample:
        if sc.timeframe in times:
            return False
        times.add(sc.timeframe)
    return True

if DROP_MITOSIS_DIV:
    def drop_div_keys(class2samples, tar_keys=["mitosis"]):
        class2samples = class2samples.copy()
        for key in tar_keys:
            tmp_samples = []
            key_samples = class2samples[key]
            for sample in key_samples:
                tmp_samples.append(drop_sample_div(sample))
            class2samples[key] = tmp_samples
            assert all([check_one_sc_at_time(sample) for sample in class2samples[key]]), "there is more than one sc at the same timepoint"
        return class2samples
    train_class2samples = drop_div_keys(train_class2samples)
    test_class2samples = drop_div_keys(test_class2samples)



In [43]:

# # for debug
# train_class2samples = {key: value[:5] for key, value in all_class2samples.items()}
# test_class2samples = {key: value[:5] for key, value in all_class2samples.items()}
# padding_pixels = [20]


train_sample_info_df = gen_samples_df(train_class2samples, train_class2sample_extra_info, data_dir, class_labels, padding_pixels, frame_types, fps, prefix="train")
test_sample_info_df = gen_samples_df(test_class2samples, test_class2sample_extra_info, data_dir, class_labels, padding_pixels, frame_types, fps, prefix="test")

train_sample_info_df.to_csv(data_dir/f'train_data.txt', index=False, header=False, sep=' ', )
test_sample_info_df.to_csv(data_dir/f'test_data.txt', index=False, header=False, sep=' ', )

for selected_frame_type in frame_types:
    train_df_path = data_dir/f'mmaction_train_data_{selected_frame_type}.txt'
    train_selected_frame_type_df = train_sample_info_df[train_sample_info_df["frame_type"] == selected_frame_type]
    train_selected_frame_type_df = train_selected_frame_type_df.reset_index(drop=True)
    train_selected_frame_type_df = train_selected_frame_type_df[["path", "label_index"]]
    train_selected_frame_type_df.to_csv(train_df_path, index=False, header=False, sep=' ')
    
    test_df_path = data_dir/f'mmaction_test_data_{selected_frame_type}.txt'
    test_selected_frame_type_df = test_sample_info_df[test_sample_info_df["frame_type"] == selected_frame_type]
    test_selected_frame_type_df = test_selected_frame_type_df[["path", "label_index"]]
    test_selected_frame_type_df = test_selected_frame_type_df.reset_index(drop=True)
    test_selected_frame_type_df.to_csv(test_df_path, index=False, header=False, sep=' ')


# # the follwing code generates v1-v7 test data. The issue is that some of test data shows up in train data, through different padding values.
# data_df_path = data_dir/'all_data.txt'
# sample_df = gen_samples_df(data_dir, class_labels, padding_pixels, frame_types, fps)
# sample_df.to_csv(data_df_path, index=False, header=False, sep=' ')
# for selected_frame_type in frame_types:
#     selected_frame_type_df = sample_df[sample_df["frame_type"] == selected_frame_type]
#     selected_frame_type_df = selected_frame_type_df.reset_index(drop=True)
#     train_df_path = data_dir/f'train_data_{selected_frame_type}.txt'
#     test_df_path = data_dir/f'test_data_{selected_frame_type}.txt'
#     train_df = selected_frame_type_df.sample(frac=0.8, random_state=0, replace=False)
#     test_df = selected_frame_type_df.drop(train_df.index, inplace=False)

#     # only keep the path and label_index columns
#     train_df = train_df[["path", "label_index"]]
#     test_df = test_df[["path", "label_index"]]

#     train_df.to_csv(train_df_path, index=False, header=False, sep=' ')
#     test_df.to_csv(test_df_path, index=False, header=False, sep=' ')


  0%|          | 0/22 [00:00<?, ?it/s]

100%|██████████| 22/22 [00:41<00:00,  1.86s/it]
100%|██████████| 22/22 [00:19<00:00,  1.10it/s]
100%|██████████| 22/22 [00:21<00:00,  1.02it/s]
100%|██████████| 22/22 [00:23<00:00,  1.06s/it]
100%|██████████| 22/22 [00:22<00:00,  1.01s/it]
100%|██████████| 22/22 [00:26<00:00,  1.22s/it]
100%|██████████| 22/22 [00:39<00:00,  1.80s/it]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
100%|██████████| 224/224 [01:13<00:00,  3.06it/s]
100%|██████████| 224/224 [00:36<00:00,  6.17it/s]
100%|██████████| 224/224 [00:38<00:00,  5.83it/s]
100%|██████████| 224/224 [00:39<00:00,  5.65it/s]
100%|██████████| 224/224 [00:49<00:00,  4.53it/s]
100%|██████████| 224/224 [01:22<00:00,  2.72it/s]
100%|██████████| 224/224 [03:11<00:00,  1.17it/s]
100%|██████████| 6/6 [00:18<00:00,  3.14s/it]
100%|██████████| 6/6 [00:09<00:00,  1.52s/it]
100%|██████████| 6/6 [00:09<00:00,  1.59s/it]
100%|██████████| 6/6 [00:09<00:00,  1.62s/i

In [None]:
train_class2samples

Check the videos

In [47]:

video_paths = list(Path(data_dir/'videos').glob('*.mp4'))

Due to a `decord` package [issue](https://github.com/dmlc/decord/issues/150), to use mmaction2 we must check if the videos can be loaded by `decord` correctly.

In [49]:
import decord
for path in video_paths:
# for path in ["./notebook_results/train_normal_6_raw_padding-0.mp4"]:
# for path in ["./test_video_output.mp4"]:
    reader = decord.VideoReader(str(path))
    reader.seek(0)
    imgs = list()
    frame_inds = range(0, len(reader))
    for idx in frame_inds:
        reader.seek(idx)
        frame = reader.next()
        imgs.append(frame.asnumpy())
        frame = frame.asnumpy()

        num_channels = frame.shape[-1]
        if num_channels != 3:
            print("invalid video for decord (https://github.com/dmlc/decord/issues/150): ", path)
            break
        # fig, axes = plt.subplots(1, num_channels, figsize=(20, 10))
        # for i in range(num_channels):
        #     axes[i].imshow(frame[:, :, i])
        # plt.show()
    del reader

invalid video for decord (https://github.com/dmlc/decord/issues/150):  notebook_results\mmaction_train_data_v10-drop-div\videos\train_normal_212_combined_padding-0.mp4
invalid video for decord (https://github.com/dmlc/decord/issues/150):  notebook_results\mmaction_train_data_v10-drop-div\videos\train_normal_212_mask_padding-0.mp4
invalid video for decord (https://github.com/dmlc/decord/issues/150):  notebook_results\mmaction_train_data_v10-drop-div\videos\train_normal_212_raw_padding-0.mp4
invalid video for decord (https://github.com/dmlc/decord/issues/150):  notebook_results\mmaction_train_data_v10-drop-div\videos\train_normal_43_combined_padding-0.mp4
invalid video for decord (https://github.com/dmlc/decord/issues/150):  notebook_results\mmaction_train_data_v10-drop-div\videos\train_normal_43_mask_padding-0.mp4
invalid video for decord (https://github.com/dmlc/decord/issues/150):  notebook_results\mmaction_train_data_v10-drop-div\videos\train_normal_43_raw_padding-0.mp4


In [None]:
decord.__version__

check if videos can be loaded by cv2 correctly

In [None]:
import cv2

cap = cv2.VideoCapture("./test_video_output.mp4")

while True:
    ret, frame = cap.read()
    if not ret:
        break
    assert frame.shape[-1] == 3, "frame should be in RGB format"

cap.release()
cv2.destroyAllWindows()

In [None]:
# from sklearn.model_selection import train_test_split

# train_df_path = data_dir/'train_data.csv'
# test_df_path = data_dir/'test_data.csv'

# # split train and test from df
# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# train_df.to_csv(train_df_path, index=False, header=False, sep=' ')
# test_df.to_csv(test_df_path, index=False, header=False, sep=' ')
