# Prepare mitosis time series data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
# from livecellx import segment
from livecellx import core
from livecellx.core import datasets
from livecellx.core.datasets import LiveCellImageDataset, SingleImageDataset
from skimage import measure
from livecellx.core import SingleCellTrajectory, SingleCellStatic

In [None]:
from livecellx.track.classify_utils import load_class2samples_from_json_dir, load_all_json_dirs
# sample_json_dir = Path("./EBSS_starvation_24h_xy16_annotation")

# Annotation directories, grouped by annotation round.
sample_json_dirs_v0 = [
    Path(r"./datasets/test_scs_EBSS_starvation/XY1/annotations"),
    Path(r"./datasets/test_scs_EBSS_starvation/XY16/annotations"),
]

# NOTE(review): the entries added below are absolute, machine-specific Windows paths;
# they only resolve on the machine they were written for.
round1_json_dirs = sample_json_dirs_v0 + [
    Path(r"D:\LiveCellTracker-dev\datasets\mitosis-annotations-2023\shiman_XY01\XY01"),
    Path(r"D:\LiveCellTracker-dev\datasets\mitosis-annotations-2023\shiman_XY09\XY09"),
    Path(r"D:\LiveCellTracker-dev\datasets\mitosis-annotations-2023\shiman_XY10\XY10"),
    Path(r"D:\LiveCellTracker-dev\datasets\mitosis-annotations-2023\Yajushi\tifs_CFP_A549-VIM_lessThan24hr_NoTreat_NA_YL_Ti2e_2022-10-19\XY1\annotations"),
]

round2_json_dirs = [
    Path(r"../datasets/mitosis-annotations-2023/shiman_CXA_high_density/C0.5^4/"),
    Path(r"../datasets/mitosis-annotations-2023/shiman_CXA_high_density/C0.75^4/"),
    Path(r"../datasets/mitosis-annotations-2023/shiman_CXA_high_density/C10^3/"),
    Path(r"../datasets/mitosis-annotations-2023/shiman_CXA_high_density/C10^4/"),
] + [
    Path(f"../datasets/mitosis-annotations-2023/Gaohan_tifs_CFP_A549-VIM_lessThan24hr_NoTreat_NA_YL_Ti2e_2022-10-19/XY{pos}/annotations") for pos in range(4, 14)
]

# round1_json_dirs already begins with sample_json_dirs_v0, so prepending v0 again
# would load those two directories twice and duplicate their samples (which could
# then end up on both sides of the train/test split). dict.fromkeys removes any
# repeated directory while keeping the original order.
sample_json_dirs = list(dict.fromkeys(round1_json_dirs + round2_json_dirs))
all_class2samples, all_class2sample_extra_info = load_all_json_dirs(sample_json_dirs)
           

In [None]:
# all_class2samples, all_class2sample_extra_info

In [None]:
# Number of loaded samples per class, in the order (mitosis, apoptosis, normal).
tuple(len(all_class2samples[class_name]) for class_name in ("mitosis", "apoptosis", "normal"))

In [None]:
# Count the samples, across every class, that contain no single cells at all.
num_zero_len_samples = sum(
    1
    for class_samples in all_class2samples.values()
    for sample in class_samples
    if len(sample) == 0
)
print("num_zero_len_samples: ", num_zero_len_samples)

Automatically prepare normal samples

This step requires tracking to have been run first, because normal samples are cut from tracked trajectories.

In [None]:
# Gather every sample that belongs to a class other than "normal".
non_normal_samples = [
    sample
    for class_name, samples in all_class2samples.items()
    if class_name != "normal"
    for sample in samples
]
total_non_normal_samples = len(non_normal_samples)

# The single cells of those samples (and their ids as strings) must not be
# reused when normal samples are generated automatically.
exclude_scs = {sc for sample in non_normal_samples for sc in sample}
exclude_scs_ids = {str(sc.id) for sc in exclude_scs}

In [None]:
# from livecellx.core.sct_operator import create_scs_edit_viewer
# sct_operator = create_scs_edit_viewer(exclude_scs, img_dataset = list(exclude_scs)[0].img_dataset)

Load all single cells (mitotic and normal alike) and track them, so that normal samples can be generated automatically from the resulting trajectories.

In [None]:
import json
from livecellx.core.single_cell import SingleCellTrajectoryCollection
from livecellx.track.sort_tracker_utils import (
    track_SORT_bbox_from_scs
)

all_scs_json_path = ["./datasets/test_scs_EBSS_starvation/XY1/single_cells.json", "./datasets/test_scs_EBSS_starvation/XY16/single_cells.json"]
# all_scs_json_path = "./datasets/test_scs_EBSS_starvation/XY16/tmp_corrected_scs.json"

# Track each position separately with SORT and merge the trajectories into one
# collection, shifting track ids so they never collide across positions.
sctc = SingleCellTrajectoryCollection()
for json_path in all_scs_json_path:
    print("json path:", json_path)
    _scs = SingleCellStatic.load_single_cells_json(json_path)
    tmp_sctc = track_SORT_bbox_from_scs(_scs, raw_imgs=_scs[0].img_dataset, min_hits=3, max_age=3)

    # Offset new ids past the largest id already stored in the merged collection.
    existing_tids = set(sctc.get_all_tids())
    max_tid = max(existing_tids) if existing_tids else 0
    tid_offset = max_tid + 1

    for tid, traj in tmp_sctc:
        # Remember which file the trajectory (and each of its cells) came from.
        traj.meta["src_dir"] = json_path
        traj.track_id = tid + tid_offset
        sctc.add_trajectory(traj)
        for sc in traj.get_all_scs():
            sc.meta["src_dir"] = json_path
    del tmp_sctc

all_scs = SingleCellStatic.load_single_cells_jsons(all_scs_json_path)

In [None]:

# with open("./EBSS_starvation_24h_xy16_annotation/single_cell_trajectory_collection.json", "r") as file:
#     json_dict = json.load(file)
# sctc = SingleCellTrajectoryCollection().load_from_json_dict(json_dict)


In [None]:
# Seed NumPy's global RNG so the sampled normal clips are reproducible.
seed = 0
np.random.seed(seed)

# Aim for ten normal samples per annotated non-normal sample.
objective_sample_num = total_non_normal_samples * 10

# Clip length bounds; np.random.randint treats the upper bound as exclusive.
normal_frame_len_range = (3, 10)
counter = 0
normal_samples = []
normal_samples_extra_info = []
max_trial_counter = 100000
num_rejected_excluded = 0

# The collection is not modified while sampling, so list the track ids once.
candidate_track_ids = list(sctc.track_id_to_trajectory.keys())

while counter < objective_sample_num and max_trial_counter > 0:
    # Charge every attempt (not only successful ones) against the trial budget;
    # otherwise the rejections below bypass the budget and the loop may never end.
    max_trial_counter -= 1

    # Randomly pick a trajectory and a clip length.
    track_id = np.random.choice(candidate_track_ids)
    sct = sctc.get_trajectory(track_id)
    frame_len = np.random.randint(*normal_frame_len_range)

    times = sorted(sct.timeframe_to_single_cell.keys())
    if len(times) < frame_len:
        continue

    # "+ 1" because randint's upper bound is exclusive: without it a window
    # ending on the trajectory's last time frame could never be selected.
    start_idx = np.random.randint(0, len(times) - frame_len + 1)
    start_time = times[start_idx]
    end_time = times[start_idx + frame_len - 1]

    sub_sct = sct.subsct(start_time, end_time)
    window_scs = list(sub_sct.timeframe_to_single_cell.values())

    # Reject clips that overlap any annotated non-normal single cell.
    if any(str(sc.id) in exclude_scs_ids for sc in window_scs):
        num_rejected_excluded += 1
        continue

    normal_samples.append(window_scs)
    normal_samples_extra_info.append({"src_dir": sub_sct.get_all_scs()[0].meta["src_dir"]})
    counter += 1

# Summarize once instead of printing a line for every rejected clip.
print("clips rejected because some sc is in the exclude scs list:", num_rejected_excluded)
if counter < objective_sample_num:
    print(f"trial budget exhausted: generated {counter} of {objective_sample_num} requested normal samples")

normal_samples[:2]

In [None]:
# Append the generated normal clips (and their metadata) to the annotated ones, in place.
# Re-running this cell appends them again.
all_class2samples["normal"] += normal_samples
all_class2sample_extra_info["normal"] += normal_samples_extra_info

In [None]:
# The normal samples and their extra-info records should have the same length.
tuple(len(per_class["normal"]) for per_class in (all_class2samples, all_class2sample_extra_info))

Add start and end time to all_class2sample_extra_info

In [None]:
# Record each sample's covered time span and the id of its first cell in the
# extra-info record stored at the same position. Empty samples are left untouched.
for class_name, samples in all_class2samples.items():
    print(class_name, len(samples))
    class_extra_infos = all_class2sample_extra_info[class_name]
    for sample_idx, sample in enumerate(samples):
        if len(sample) == 0:
            continue
        timeframes = [sc.timeframe for sc in sample]
        sample_extra_info = class_extra_infos[sample_idx]
        sample_extra_info["start_time"] = min(timeframes)
        sample_extra_info["end_time"] = max(timeframes)
        sample_extra_info["first_sc_id"] = sample[0].id

## Prepare videos and annotations for MMAction2

In [None]:
# Names of all loaded classes (a live keys view of all_class2samples).
classes = all_class2samples.keys()
classes

In [None]:
from livecellx.core.utils import gray_img_to_rgb, rgb_img_to_gray
from livecellx.preprocess.utils import normalize_img_to_uint8

In [None]:
from livecellx.track.classify_utils import video_frames_and_masks_from_sample, combine_video_frames_and_masks

In [None]:
from typing import List
import cv2
import numpy as np
import pandas as pd

from livecellx.core.sc_video_utils import gen_mp4_from_frames, gen_class2sample_samples, gen_samples_mp4s

# ver = "10-st" # single trajectory ver
# ver = "test" # single trajectory ver
# ver = "11-st-run0"
# MAKE_SINGLE_CELL_TRAJ_SAMPLES = True
# DROP_MITOSIS_DIV = False

# ver = "10-drop-div"
# DROP_MITOSIS_DIV = True
# ver = "-test"

# ver = "11-drop-div"
# MAKE_SINGLE_CELL_TRAJ_SAMPLES = False
# DROP_MITOSIS_DIV = True

# ver = "12-st"
# MAKE_SINGLE_CELL_TRAJ_SAMPLES = True
# DROP_MITOSIS_DIV = False

# ver = "12-drop-div"
# MAKE_SINGLE_CELL_TRAJ_SAMPLES = False
# DROP_MITOSIS_DIV = True

# ver = "12-all"
# MAKE_SINGLE_CELL_TRAJ_SAMPLES = False
# DROP_MITOSIS_DIV = False

# Active export configuration. `ver` is embedded in the output directory name;
# the two flags gate the optional sample post-processing cells further down.
ver = "test-all"
MAKE_SINGLE_CELL_TRAJ_SAMPLES = False
DROP_MITOSIS_DIV = False

In [None]:

# Output location and export settings.
data_dir = Path(f'notebook_results/mmaction_train_data_v{ver}')
class_labels = ['mitosis', 'apoptosis', 'normal']
class_label = "mitosis"
frame_types = ["video", "mask", "combined"]
fps = 3

# The smallest padding is 1 rather than 0 to avoid an error in the decord package
# when the generated videos are read back.
padding_pixels = [1, 20, 40, 50, 100, 200, 400]

# Fraction of each class that goes to the training split.
_split = 0.8

train_class2samples, test_class2samples = {}, {}
train_class2sample_extra_info, test_class2sample_extra_info = {}, {}

# Shuffle every class independently, then cut it at the split fraction. The same
# index arrays are applied to the samples and to their extra-info records so the
# two stay aligned by position.
for key, class_samples in all_class2samples.items():
    num_class_samples = len(class_samples)
    randomized_indices = np.random.permutation(num_class_samples).astype(int)
    split_idx = int(num_class_samples * _split)
    _train_indices, _test_indices = randomized_indices[:split_idx], randomized_indices[split_idx:]

    samples_arr = np.array(class_samples, dtype=object)
    extra_info_arr = np.array(all_class2sample_extra_info[key], dtype=object)

    train_class2samples[key] = samples_arr[_train_indices]
    test_class2samples[key] = samples_arr[_test_indices]
    train_class2sample_extra_info[key] = extra_info_arr[_train_indices]
    test_class2sample_extra_info[key] = extra_info_arr[_test_indices]



In [None]:
# (train, test) sizes of the normal class.
tuple(len(split["normal"]) for split in (train_class2samples, test_class2samples))

In [None]:
# (train, test) sizes of the mitosis class.
tuple(len(split["mitosis"]) for split in (train_class2samples, test_class2samples))

In [None]:
# Shape of the first item in the first element returned for one normal training sample.
frames_and_masks = video_frames_and_masks_from_sample(train_class2samples["normal"][6])
first_video_frame = frames_and_masks[0][0]
first_video_frame.shape
# train_class2samples["normal"][6][1].show_panel()

In [None]:
import importlib
import livecellx
# Development aid: re-execute classify_utils so edits to it are picked up without
# restarting the kernel.
# NOTE: names bound earlier with `from livecellx.track.classify_utils import ...`
# keep pointing at the pre-reload objects; only attribute access through the
# module object sees the reloaded code.
importlib.reload(livecellx.track.classify_utils)

In [None]:
# Smoke test: build frames and masks for one normal sample, combine them, and
# write the result as a short video for visual inspection.
idx_to_check = 6
video_frames, video_frame_masks = video_frames_and_masks_from_sample(train_class2samples["normal"][idx_to_check], padding_pixels=0)
for label, frames in (("video frames", video_frames), ("video frame masks", video_frame_masks)):
    print(f"{label} dtype:", frames[0].dtype)
    print(f"{label} shape:", frames[0].shape)

combined_frames = livecellx.track.classify_utils.combine_video_frames_and_masks(video_frames, video_frame_masks, edt_transform=True)
# Cast to uint8 before encoding the frames as an mp4.
combined_frames = np.array(combined_frames).astype(np.uint8)
# combined_frames = np.maximum(combined_frames - 1, 0).astype(np.uint8)
print("combined_frames shape: ", combined_frames[0].shape)
gen_mp4_from_frames(combined_frames, "./test_video_output.mp4", fps=1)

Visually check the generated frames' values

In [None]:
# channel = 2
# plt.imshow(combined_frames[0][..., channel])
# combined_frames[1][..., channel].max(), combined_frames[1][..., 0].shape

In [None]:
# Smallest value found anywhere in the combined frames.
np.asarray(combined_frames).min()

Make single cell trajectories only (ONE cell per time frame)

In [None]:
from typing import Dict
from livecellx.track.data_prep_utils import check_one_sc_at_time
from livecellx.track.data_prep_utils import make_one_cell_per_timeframe_for_class2samples, make_one_cell_per_timeframe_helper, make_one_cell_per_timeframe_samples



# Pick the first mitosis training sample for inspection.
sample = train_class2samples["mitosis"][0]

In [None]:
# Time frame of every single cell in the selected sample.
[sc.timeframe for sc in sample]

In [None]:
# Number of items the helper returns for this sample
# (presumably one per one-cell-per-timeframe sub-sample; confirm against the helper).
len(make_one_cell_per_timeframe_samples(sample))

In [None]:
# Optional step: when enabled, replace both splits (samples and extra info together)
# with the output of make_one_cell_per_timeframe_for_class2samples.
if MAKE_SINGLE_CELL_TRAJ_SAMPLES:
    train_class2samples, train_class2sample_extra_info = make_one_cell_per_timeframe_for_class2samples(train_class2samples, train_class2sample_extra_info)
    test_class2samples, test_class2sample_extra_info = make_one_cell_per_timeframe_for_class2samples(test_class2samples, test_class2sample_extra_info)

Drop the cell division part for easier inference during testing

In [None]:
from livecellx.track.data_prep_utils import drop_multiple_cell_frames_in_samples

# Optional step: when enabled, replace the samples of both splits with the output
# of drop_multiple_cell_frames_in_samples.
# NOTE(review): only the sample dicts are reassigned here; the extra-info dicts are
# left as they are. If the helper can remove whole samples, the two would no longer
# line up by position; confirm against the helper.
if DROP_MITOSIS_DIV:
    train_class2samples = drop_multiple_cell_frames_in_samples(train_class2samples)
    test_class2samples = drop_multiple_cell_frames_in_samples(test_class2samples)


In [None]:
# Samples and their extra-info records are matched by position, so every class
# must have equally long lists. Check both splits: the test split goes through the
# same export step as the train split, so a mismatch there is just as harmful.
for split_name, split_samples, split_extra_info in (
    ("train", train_class2samples, train_class2sample_extra_info),
    ("test", test_class2samples, test_class2sample_extra_info),
):
    for key, val in split_samples.items():
        assert len(val) == len(split_extra_info[key]), (
            f"split: {split_name}, key: {key}, len(val): {len(val)}, "
            f"len({split_name}_class2sample_extra_info[key]): {len(split_extra_info[key])}"
        )

In [None]:
# Inspect the extra-info records of the mitosis training samples.
train_class2sample_extra_info["mitosis"]

In [None]:
import importlib
import livecellx
import livecellx.core.sc_video_utils
# Development aid: re-execute sc_video_utils so edits to it are picked up without
# restarting the kernel. Names imported earlier with `from ... import ...` still
# refer to the pre-reload objects, which is why the calls below go through the
# module attribute instead.
importlib.reload(livecellx.core.sc_video_utils)

# # for debug
# test_sample_num = 3
# padding_pixels = [1, 20]
# train_class2samples = {key: value[:test_sample_num] for key, value in all_class2samples.items()}
# test_class2samples = {key: value[:test_sample_num] for key, value in all_class2samples.items()}
# train_class2sample_extra_info = {key: value[:test_sample_num] for key, value in all_class2sample_extra_info.items()}
# test_class2sample_extra_info = {key: value[:test_sample_num] for key, value in all_class2sample_extra_info.items()}

# padding_pixels = [20]

def _gen_split_sample_info(prefix, class2samples, class2sample_extra_info):
    """Run gen_class2sample_samples for one split with the notebook-wide settings.

    `prefix` is passed through as the `prefix` argument ("train" or "test").
    The function is resolved through the module attribute at call time, so a
    reloaded sc_video_utils is used.
    """
    return livecellx.core.sc_video_utils.gen_class2sample_samples(
        class2samples,
        class2sample_extra_info,
        data_dir,
        class_labels,
        padding_pixels=padding_pixels,
        frame_types=frame_types,
        fps=fps,
        prefix=prefix,
    )


train_sample_info_df = _gen_split_sample_info("train", train_class2samples, train_class2sample_extra_info)
test_sample_info_df = _gen_split_sample_info("test", test_class2samples, test_class2sample_extra_info)


In [None]:
# Preview the first two rows of the training sample table.
train_sample_info_df.head(2)

In [None]:

# Sample tables keyed by split name; iteration order (train, then test) also fixes
# the order in which files are written and recorded.
split_info_dfs = {"train": train_sample_info_df, "test": test_sample_info_df}

# Full per-split tables, with a header row.
for split_name, info_df in split_info_dfs.items():
    info_df.to_csv(
        data_dir / f"{split_name}_data.txt",
        index=False,
        header=True,
        sep=" ",
    )

# Header-less "<path> <label_index>" lists, one file per frame type and split.
# Their paths are collected so later cells can post-process the same files.
mmaction_df_paths = []
for selected_frame_type in frame_types:
    for split_name, info_df in split_info_dfs.items():
        list_path = data_dir / f"mmaction_{split_name}_data_{selected_frame_type}.txt"
        is_selected_type = info_df["frame_type"] == selected_frame_type
        (
            info_df.loc[is_selected_type, ["path", "label_index"]]
            .reset_index(drop=True)
            .to_csv(list_path, index=False, header=False, sep=" ")
        )
        mmaction_df_paths.append(list_path)


# # The following code generates the v1-v7 test data. The issue is that some of the test data shows up in the train data, through different padding values.
# data_df_path = data_dir/'all_data.txt'
# sample_df = gen_samples_df(data_dir, class_labels, padding_pixels, frame_types, fps)
# sample_df.to_csv(data_df_path, index=False, header=False, sep=' ')
# for selected_frame_type in frame_types:
#     selected_frame_type_df = sample_df[sample_df["frame_type"] == selected_frame_type]
#     selected_frame_type_df = selected_frame_type_df.reset_index(drop=True)
#     train_df_path = data_dir/f'train_data_{selected_frame_type}.txt'
#     test_df_path = data_dir/f'test_data_{selected_frame_type}.txt'
#     train_df = selected_frame_type_df.sample(frac=0.8, random_state=0, replace=False)
#     test_df = selected_frame_type_df.drop(train_df.index, inplace=False)

#     # only keep the path and label_index columns
#     train_df = train_df[["path", "label_index"]]
#     test_df = test_df[["path", "label_index"]]

#     train_df.to_csv(train_df_path, index=False, header=False, sep=' ')
#     test_df.to_csv(test_df_path, index=False, header=False, sep=' ')


In [None]:
# Display the full training split (class name -> samples); the output can be large.
train_class2samples

Check the videos

In [None]:

# Every mp4 file directly inside the "videos" folder of the output directory.
videos_dir = data_dir / 'videos'
video_paths = list(videos_dir.glob('*.mp4'))

Due to a `decord` package [issue](https://github.com/dmlc/decord/issues/150), to use mmaction2 we must check if the videos can be loaded by `decord` correctly.

In [None]:
import decord
# Show the installed decord version, for reference when debugging decoding problems.
decord.__version__

In [None]:
import decord

# Decode every frame of every generated video with decord and record the videos
# in which some frame does not come back with exactly 3 channels.
# Frames are inspected one at a time and not kept: nothing uses them afterwards,
# and holding all frames of a video in memory is unnecessary.
invalid_decord_paths = []
for path in video_paths:
# for path in ["./notebook_results/train_normal_6_raw_padding-0.mp4"]:
# for path in ["./test_video_output.mp4"]:
    reader = decord.VideoReader(str(path))
    reader.seek(0)
    for idx in range(len(reader)):
        # Explicit seek + next for each frame index.
        reader.seek(idx)
        frame = reader.next().asnumpy()

        num_channels = frame.shape[-1]
        if num_channels != 3:
            print("invalid video for decord (https://github.com/dmlc/decord/issues/150): ", path)
            invalid_decord_paths.append(path)
            break
        # fig, axes = plt.subplots(1, num_channels, figsize=(20, 10))
        # for i in range(num_channels):
        #     axes[i].imshow(frame[:, :, i])
        # plt.show()
    # Drop the reader before opening the next video.
    del reader

Remove "invalid" videos (those that decord cannot read correctly) from the MMAction2 annotation lists

In [None]:
# extract file names from invalid decord paths
invalid_decord_filenames = set([os.path.basename(path) for path in invalid_decord_paths])

for df_path in mmaction_df_paths:
    _df = pd.read_csv(df_path, sep=" ", header=None)
    # remove all the rows with column "path" in invalid_decord_filenames
    filtered_df = _df[~_df[0].isin(invalid_decord_filenames)]

    df_filename = os.path.basename(df_path)
    # summarize the number of samples for the file
    print(f"df_path: {df_filename}, #filtered: {_df.shape[0] - filtered_df.shape[0]}, original df shape: {_df.shape}, filtered df shape: {filtered_df.shape}")

    # save to the disk
    filtered_df.to_csv(df_path, index=False, header=False, sep=" ")


Check that the videos can be loaded correctly by cv2

In [None]:
import cv2

# Read the test video back with OpenCV and verify that every frame has 3 channels.
test_video_path = "./test_video_output.mp4"
cap = cv2.VideoCapture(test_video_path)
try:
    # Without this check an unreadable file would yield zero frames and the cell
    # would pass without having verified anything.
    assert cap.isOpened(), f"cv2 could not open {test_video_path}"

    num_frames_read = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        assert frame.shape[-1] == 3, "frame should have 3 color channels"
        num_frames_read += 1

    assert num_frames_read > 0, f"cv2 read no frames from {test_video_path}"
finally:
    # Release the capture even when an assertion fails. No window is opened in
    # this cell, so there is nothing for cv2.destroyAllWindows() to close.
    cap.release()

In [None]:
# from sklearn.model_selection import train_test_split

# train_df_path = data_dir/'train_data.csv'
# test_df_path = data_dir/'test_data.csv'

# # split train and test from df
# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# train_df.to_csv(train_df_path, index=False, header=False, sep=' ')
# test_df.to_csv(test_df_path, index=False, header=False, sep=' ')
