# Prepare mitosis time series data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
# from livecell_tracker import segment
from livecell_tracker import core
from livecell_tracker.core import datasets
from livecell_tracker.core.datasets import LiveCellImageDataset, SingleImageDataset
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic

In [None]:
# sample_json_dir = Path("./EBSS_starvation_24h_xy16_annotation")

# One annotation directory per imaging position; each one is read through
# load_class2samples_from_json_dir in the loop further down this cell.
sample_json_dirs = [
    Path(r"./datasets/test_scs_EBSS_starvation/XY1/annotations"),
    Path(r"./datasets/test_scs_EBSS_starvation/XY16/annotations"),
]

def load_class2samples_from_json_dir(sample_json_dir: Path, class_subfolders = ("mitosis", "apoptosis", "normal")) -> dict:
    """
    Load the annotated samples stored under one annotation directory.

    Every ``*.json`` file found directly inside ``sample_json_dir / <subfolder>`` is
    read with ``SingleCellStatic.load_single_cells_json`` and stored as one sample.

    Args:
        sample_json_dir: Directory that contains one subfolder per class (a Path or a str).
        class_subfolders: Names of the class subfolders to read. Every name becomes a
            key of the result, even when the subfolder is missing or empty.
    Returns:
        A dict mapping each subfolder name to the list of samples loaded from it.
    """
    sample_json_dir = Path(sample_json_dir)
    class2samples = {}
    for subfolder in class_subfolders:
        # glob gives no ordering guarantee; sort so that the positional train/test
        # split applied later in the notebook is reproducible across machines.
        sample_paths = sorted(glob.glob(str(sample_json_dir / subfolder / "*.json")))
        class2samples[subfolder] = [
            SingleCellStatic.load_single_cells_json(sample_path) for sample_path in sample_paths
        ]
    return class2samples


# Merge the samples of every annotation directory into one dict per class, and keep a
# parallel list of provenance records ({"src_dir": ...}) with one entry per sample.
# Both dicts start empty and are only ever extended with fresh lists, so the first
# directory's samples are not appended to themselves and the two dicts stay aligned.
all_class2samples = {}
all_class2sample_extra_info = {}
for sample_json_dir in sample_json_dirs:
    _class2samples = load_class2samples_from_json_dir(sample_json_dir)
    print(_class2samples)
    for class_name, _samples in _class2samples.items():
        # report how many samples loaded from the sample json dir
        print(f"Loaded {len(_samples)} annotated samples from {sample_json_dir / class_name}")

    for class_name, _samples in _class2samples.items():
        all_class2samples.setdefault(class_name, []).extend(_samples)
        all_class2sample_extra_info.setdefault(class_name, []).extend(
            {"src_dir": sample_json_dir} for _ in _samples
        )

# Downstream code indexes the extra-info list with the sample index, so the lengths must match.
assert all(
    len(all_class2samples[class_name]) == len(all_class2sample_extra_info[class_name])
    for class_name in all_class2samples
), "every sample needs exactly one extra-info record"
           

In [None]:
# Inspect the per-class provenance records ({"src_dir": ...}) collected while loading.
all_class2sample_extra_info

In [None]:
# Number of mitosis samples currently held in the merged dictionary.
len(all_class2samples["mitosis"])

Automatically prepare normal samples

Requires tracking to be completed first: the normal samples below are drawn from the tracked trajectories.

In [None]:
# get all scs from class_samples not in normal class
exclude_scs = []
total_non_normal_samples = 0
for class_name, samples in all_class2samples.items():
    if class_name != "normal":
        for sample in samples:
            exclude_scs.extend(sample)
            total_non_normal_samples += 1

exclude_scs = set(exclude_scs)
exclude_scs_ids = {str(sc.id) for sc in exclude_scs}

In [None]:
# from livecell_tracker.core.sct_operator import create_scs_edit_viewer
# sct_operator = create_scs_edit_viewer(exclude_scs, img_dataset = list(exclude_scs)[0].img_dataset)

Load all single cells (scs)

In [None]:
# One single-cell json file per imaging position; all of them are loaded into one flat collection.
all_scs_json_path = [
    "./datasets/test_scs_EBSS_starvation/XY1/single_cells.json",
    "./datasets/test_scs_EBSS_starvation/XY16/single_cells.json",
]
# all_scs_json_path = "./datasets/test_scs_EBSS_starvation/XY16/tmp_corrected_scs.json"
all_scs = SingleCellStatic.load_single_cells_jsons(all_scs_json_path)

In [None]:
import json
from livecell_tracker.core.single_cell import SingleCellTrajectoryCollection
from livecell_tracker.track.sort_tracker_utils import (
    track_SORT_bbox_from_scs
)
# Alternative kept for reference: load a previously saved trajectory collection
# from json instead of running the tracker again.
# with open("./EBSS_starvation_24h_xy16_annotation/single_cell_trajectory_collection.json", "r") as file:
#     json_dict = json.load(file)
# sctc = SingleCellTrajectoryCollection().load_from_json_dict(json_dict)
# Track all loaded single cells with SORT; the resulting collection is what the
# normal-sample generation further down draws trajectories from.
# NOTE(review): all_scs is loaded from two json files (XY1 and XY16) but only the image
# dataset of the first single cell is passed as raw_imgs, and both positions are tracked
# in one call — confirm the tracker keeps cells of different positions apart.
sctc = track_SORT_bbox_from_scs(all_scs, raw_imgs=all_scs[0].img_dataset, min_hits=3, max_age=3)

In [None]:
# set numpy seed
seed = 0
np.random.seed(seed)

# Target: ten automatically generated normal samples per annotated non-normal sample.
objective_sample_num = total_non_normal_samples * 10

# Unpacked into np.random.randint, whose upper bound is exclusive: lengths are drawn from 3..9.
normal_frame_len_range = (3, 10)
counter = 0
normal_samples = []
normal_samples_extra_info = []
# Hard cap on the number of sampling attempts (accepted or rejected).
max_trial_counter = 100000
track_ids = list(sctc.track_id_to_trajectory.keys())
while counter < objective_sample_num and max_trial_counter > 0:
    # Spend one trial per attempt, before any `continue`. Decrementing only after a
    # successful draw would let a run of rejected draws loop forever.
    max_trial_counter -= 1

    # randomly select a sct from sctc
    track_id = np.random.choice(track_ids)
    sct = sctc.get_trajectory(track_id)
    # randomly select a length
    frame_len = np.random.randint(*normal_frame_len_range)
    # generate a sample
    times = sorted(sct.timeframe_to_single_cell.keys())
    if len(times) <= frame_len:
        continue
    # NOTE(review): randint's upper bound is exclusive, so the last possible window
    # (start_idx == len(times) - frame_len) is never drawn, and trajectories with exactly
    # frame_len time points are skipped above. Left unchanged to keep seeded draws identical.
    start_idx = np.random.randint(0, len(times) - frame_len)
    start_time = times[start_idx]
    end_time = times[start_idx + frame_len - 1]

    sub_sct = sct.subsct(start_time, end_time)

    # Reject windows that touch any single cell of an annotated non-normal sample.
    if any(str(sc.id) in exclude_scs_ids for sc in sub_sct.timeframe_to_single_cell.values()):
        print("some sc in exclude scs")
        continue

    normal_samples.append(list(sub_sct.timeframe_to_single_cell.values()))
    normal_samples_extra_info.append({"src_dir": sub_sct.get_all_scs()[0].meta["src_json"]})
    counter += 1

if counter < objective_sample_num:
    print(f"Only generated {counter}/{objective_sample_num} normal samples before the trial budget ran out")

normal_samples[:2]

In [None]:
# Append the generated normal samples and their provenance records to the merged dicts.
# Both lists grow in place, so running this cell twice appends the samples twice.
all_class2samples["normal"] += normal_samples
all_class2sample_extra_info["normal"] += normal_samples_extra_info

In [None]:
# Sanity check: the normal sample list and its extra-info list are indexed in parallel
# later on, so the two lengths shown here should be equal.
len(all_class2samples["normal"]), len(all_class2sample_extra_info["normal"])

## Prepare videos and annotations for MMDetection

In [None]:
# Class names present in the merged sample dictionary (a live dict keys view, not a list).
classes = all_class2samples.keys()
classes

In [None]:
from livecell_tracker.core.utils import gray_img_to_rgb, rgb_img_to_gray
from livecell_tracker.preprocess.utils import normalize_img_to_uint8

In [None]:
from livecell_tracker.track.classify_utils import video_frames_and_masks_from_sample, combine_video_frames_and_masks

In [None]:
from typing import List
import cv2
import numpy as np
import pandas as pd

def gen_mp4_from_frames(video_frames, output_file, fps):
    """
    Encode a sequence of frames into a video file with OpenCV.

    Args:
        video_frames: Non-empty sequence of image arrays. The first two axes of the first
            frame are taken as (height, width). The writer is opened with isColor=True.
            NOTE(review): presumably 3-channel uint8 frames of identical size — confirm
            against the frame generators.
        output_file: Destination path; converted with str() before it is handed to OpenCV.
        fps: Frames per second of the generated video.
    Raises:
        ValueError: If video_frames is empty.
        IOError: If OpenCV cannot open the writer (for example the codec is unavailable).
    """
    if len(video_frames) == 0:
        raise ValueError("video_frames is empty; cannot determine the frame size")
    # OpenCV expects (width, height) while array shapes are (height, width, ...).
    frame_size = video_frames[0].shape[:2][::-1]
    # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fourcc = cv2.VideoWriter_fourcc(*'H265')
    out = cv2.VideoWriter(str(output_file), fourcc, fps, frame_size, isColor=True)
    try:
        # VideoWriter does not raise when it fails to open; without this check the
        # function would return normally and leave no usable video behind.
        if not out.isOpened():
            raise IOError(f"cv2.VideoWriter could not open {output_file} for writing")
        # Write each frame to the output video
        for frame in video_frames:
            out.write(frame)
    finally:
        # Always release the writer so the container is finalised and the handle is freed.
        out.release()

def gen_samples_mp4s(sc_samples: List[List[SingleCellStatic]], samples_info_list, class_label, output_dir, fps = 3, padding_pixels=50, prefix=""):
    """
    Write three mp4 clips (raw frames, masks, combined) for every sample of one class.

    Args:
        sc_samples: The samples; each sample is a list of SingleCellStatic objects.
        samples_info_list: Extra-info records, indexed in parallel with sc_samples.
        class_label: Class name embedded in the generated file names.
        output_dir: Path object of the directory the videos are written to.
        fps: Frames per second of the generated videos.
        padding_pixels: Padding forwarded to video_frames_and_masks_from_sample and
            embedded in the generated file names.
        prefix: Leading part of every generated file name.
    Returns:
        A tuple (res_paths, res_extra_info). res_paths maps "video", "mask" and
        "combined" to the list of written file paths (one per sample); res_extra_info
        holds the matching entry of samples_info_list for each sample.
    """
    # result key -> tag used inside the file name
    name_tags = {"video": "raw", "mask": "mask", "combined": "combined"}
    res_paths = {frame_kind: [] for frame_kind in name_tags}
    res_extra_info = []

    for sample_idx, sc_sample in enumerate(sc_samples):
        targets = {
            frame_kind: output_dir / f'{prefix}_{class_label}_{sample_idx}_{tag}_padding-{padding_pixels}.mp4'
            for frame_kind, tag in name_tags.items()
        }

        raw_frames, mask_frames = video_frames_and_masks_from_sample(sc_sample, padding_pixels=padding_pixels)
        overlay_frames = combine_video_frames_and_masks(raw_frames, mask_frames)
        assert overlay_frames[0].shape[-1] == 3, "The number of channels of the combined frames should be 3."

        frames_by_kind = {"video": raw_frames, "mask": mask_frames, "combined": overlay_frames}
        # Encode all three clips first, then record their paths.
        for frame_kind, frames in frames_by_kind.items():
            gen_mp4_from_frames(frames, targets[frame_kind], fps=fps)
        for frame_kind in frames_by_kind:
            res_paths[frame_kind].append(targets[frame_kind])

        res_extra_info.append(samples_info_list[sample_idx])

    return res_paths, res_extra_info


# Version tag embedded in the output directory name.
ver = 8
# ver = "-test"
# Root directory for the generated videos and manifest files.
data_dir = Path(f'notebook_results/mmaction_train_data_v{ver}')
# The position of a class in this list is the integer label written to the manifests.
class_labels = ['mitosis', 'apoptosis', 'normal']
# NOTE(review): not referenced by the module-level code visible in this notebook
# (gen_samples_df uses its own loop variable of the same name) — possibly a leftover.
class_label = "mitosis"
# Keys of the path dict returned by gen_samples_mp4s that get their own manifest.
frame_types = ["video", "mask", "combined"]
fps = 3

# One set of videos is generated per padding value in this list.
padding_pixels = [0, 20, 40, 50, 100, 200, 400]

def gen_samples_df(class2samples, class2sample_extra_info, data_dir, class_labels, padding_pixels, frame_types, fps, prefix=""):
    """
    Generate the videos for every class and padding value and describe them in a DataFrame.

    Args:
        class2samples: Dict mapping a class label to its list of samples.
        class2sample_extra_info: Dict mapping a class label to a list of dicts, one per
            sample, each providing a "src_dir" entry.
        data_dir: Root output directory; videos are written to ``data_dir / "videos"``.
        class_labels: Class labels to process; a label's position in this sequence is
            stored as ``label_index``.
        padding_pixels: Iterable of padding values; one set of videos is written per value.
        frame_types: Keys of the path dict returned by gen_samples_mp4s to list in the table.
        fps: Frames per second forwarded to gen_samples_mp4s.
        prefix: File-name prefix forwarded to gen_samples_mp4s.
    Returns:
        A DataFrame with the columns path (file name only), label_index, padding_pixels,
        frame_type and src_dir, one row per listed video, with a unique 0..n-1 index.
    """
    df_cols = ["path", "label_index", "padding_pixels", "frame_type", "src_dir"]
    output_dir = Path(data_dir) / "videos"
    output_dir.mkdir(exist_ok=True, parents=True)

    # Collect plain row tuples and build the frame once at the end. Concatenating onto a
    # growing DataFrame copies every previous row each time and leaves a repeated index.
    records = []
    for class_label in class_labels:
        label_index = class_labels.index(class_label)
        video_frames_samples = class2samples[class_label]
        video_frames_samples_info = class2sample_extra_info[class_label]
        for padding_pixel in padding_pixels:
            res_paths, res_extra_info = gen_samples_mp4s(video_frames_samples, video_frames_samples_info, class_label, output_dir, padding_pixels=padding_pixel, fps=fps, prefix=prefix)
            for selected_frame_type in frame_types:
                records.extend(
                    (str(path.name), label_index, padding_pixel, selected_frame_type, res_extra_info[i]["src_dir"])
                    for i, path in enumerate(res_paths[selected_frame_type])
                )
    return pd.DataFrame(records, columns=df_cols)


# split train and test data

# The split is positional: for every class the leading 80% of the samples (in their
# current list order) go to train and the remainder to test. The extra-info lists are
# cut at the same index so they stay parallel to the samples.
_split = 0.8

train_class2samples, test_class2samples = {}, {}
train_class2sample_extra_info, test_class2sample_extra_info = {}, {}
for key, _key_samples in all_class2samples.items():
    split_idx = int(len(_key_samples) * _split)
    train_class2samples[key], test_class2samples[key] = _key_samples[:split_idx], _key_samples[split_idx:]

    _key_extra_info = all_class2sample_extra_info[key]
    train_class2sample_extra_info[key], test_class2sample_extra_info[key] = (
        _key_extra_info[:split_idx],
        _key_extra_info[split_idx:],
    )



In [None]:
# Train / test sample counts for the normal class.
len(train_class2samples["normal"]), len(test_class2samples["normal"])

In [None]:
# Train / test sample counts for the mitosis class.
len(train_class2samples["mitosis"]), len(test_class2samples["mitosis"])

In [None]:
# Shape of the first video frame produced for normal training sample 6 (default padding).
video_frames_and_masks_from_sample(train_class2samples["normal"][6])[0][0].shape
# train_class2samples["normal"][6][1].show_panel()

In [None]:
# Smoke test on one normal training sample: print frame/mask dtypes and shapes, then
# encode the combined frames into ./test_video_output.mp4 (read back by the cv2 check later).
idx_to_check = 6
video_frames, video_frame_masks = video_frames_and_masks_from_sample(train_class2samples["normal"][idx_to_check], padding_pixels=0)
print("video frames dtype:", video_frames[0].dtype)
print("video frames shape:", video_frames[0].shape)
print("video frame masks dtype:", video_frame_masks[0].dtype)
print("video frame masks shape:", video_frame_masks[0].shape)
combined_frames = combine_video_frames_and_masks(video_frames, video_frame_masks)
# Stack into a single array and cast to uint8 before handing the frames to the writer.
combined_frames = np.array(combined_frames).astype(np.uint8)
# combined_frames = np.maximum(combined_frames - 1, 0).astype(np.uint8)
print("combined_frames shape: ", combined_frames[0].shape)
gen_mp4_from_frames(combined_frames, "./test_video_output.mp4", fps=1)

In [None]:
# Smallest pixel value across all combined frames of the smoke-test sample.
np.array(combined_frames).flatten().min()

In [None]:

# # for debug
# train_class2samples = {key: value[:5] for key, value in all_class2samples.items()}
# test_class2samples = {key: value[:5] for key, value in all_class2samples.items()}
# padding_pixels = [20]


train_sample_info_df = gen_samples_df(train_class2samples, train_class2sample_extra_info, data_dir, class_labels, padding_pixels, frame_types, fps, prefix="train")
test_sample_info_df = gen_samples_df(test_class2samples, test_class2sample_extra_info, data_dir, class_labels, padding_pixels, frame_types, fps, prefix="test")

# Full manifests: every column, space separated, no header and no index.
train_sample_info_df.to_csv(data_dir / 'train_data.txt', index=False, header=False, sep=' ')
test_sample_info_df.to_csv(data_dir / 'test_data.txt', index=False, header=False, sep=' ')


def _write_frame_type_manifest(sample_info_df, frame_type, manifest_path):
    """Write the "path label_index" rows of one frame type to manifest_path."""
    is_selected = sample_info_df["frame_type"] == frame_type
    manifest_df = sample_info_df.loc[is_selected, ["path", "label_index"]].reset_index(drop=True)
    manifest_df.to_csv(manifest_path, index=False, header=False, sep=' ')


# One two-column manifest per frame type and split.
for selected_frame_type in frame_types:
    _write_frame_type_manifest(train_sample_info_df, selected_frame_type, data_dir / f'mmaction_train_data_{selected_frame_type}.txt')
    _write_frame_type_manifest(test_sample_info_df, selected_frame_type, data_dir / f'mmaction_test_data_{selected_frame_type}.txt')


# # The following code generated the v1-v7 test data. Its issue is that some test samples also show up in the train data, through different padding values.
# data_df_path = data_dir/'all_data.txt'
# sample_df = gen_samples_df(data_dir, class_labels, padding_pixels, frame_types, fps)
# sample_df.to_csv(data_df_path, index=False, header=False, sep=' ')
# for selected_frame_type in frame_types:
#     selected_frame_type_df = sample_df[sample_df["frame_type"] == selected_frame_type]
#     selected_frame_type_df = selected_frame_type_df.reset_index(drop=True)
#     train_df_path = data_dir/f'train_data_{selected_frame_type}.txt'
#     test_df_path = data_dir/f'test_data_{selected_frame_type}.txt'
#     train_df = selected_frame_type_df.sample(frac=0.8, random_state=0, replace=False)
#     test_df = selected_frame_type_df.drop(train_df.index, inplace=False)

#     # only keep the path and label_index columns
#     train_df = train_df[["path", "label_index"]]
#     test_df = test_df[["path", "label_index"]]

#     train_df.to_csv(train_df_path, index=False, header=False, sep=' ')
#     test_df.to_csv(test_df_path, index=False, header=False, sep=' ')


In [None]:
# Display the whole training split (class name -> list of samples).
train_class2samples

Check the videos

In [None]:

# Collect every mp4 file found directly inside the generated videos directory.
video_paths = list(Path(data_dir/'videos').glob('*.mp4'))

Due to a `decord` package [issue](https://github.com/dmlc/decord/issues/150), to use mmaction2 we must check if the videos can be loaded by `decord` correctly.

In [None]:
import decord
# Decode every generated video frame by frame with decord and report any video whose
# frames do not have exactly 3 channels.
for path in video_paths:
# for path in ["./notebook_results/train_normal_6_raw_padding-0.mp4"]:
# for path in ["./test_video_output.mp4"]:
    reader = decord.VideoReader(str(path))
    reader.seek(0)
    frame_inds = range(0, len(reader))
    for idx in frame_inds:
        reader.seek(idx)
        # Convert once and keep only the current frame; holding every decoded frame
        # in a list made memory grow with each checked video.
        frame = reader.next().asnumpy()

        num_channels = frame.shape[-1]
        if num_channels != 3:
            print("invalid video for decord (https://github.com/dmlc/decord/issues/150): ", path)
            # One bad frame is enough; move on to the next video.
            break
        # fig, axes = plt.subplots(1, num_channels, figsize=(20, 10))
        # for i in range(num_channels):
        #     axes[i].imshow(frame[:, :, i])
        # plt.show()
    del reader

In [None]:
# Show the installed decord version, for reference alongside the video check.
decord.__version__

Check that the videos can also be loaded correctly by cv2.

In [None]:
import cv2

# Read the smoke-test video back with OpenCV and verify that every frame has 3 channels.
cap = cv2.VideoCapture("./test_video_output.mp4")
try:
    # VideoCapture does not raise on a missing or unreadable file; without this check
    # the loop below would end immediately and the cell would pass without testing anything.
    assert cap.isOpened(), "cv2 could not open ./test_video_output.mp4"

    num_frames_read = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Only the channel count is verified here, not the channel order.
        assert frame.shape[-1] == 3, "frame should be in RGB format"
        num_frames_read += 1

    assert num_frames_read > 0, "cv2 decoded no frames from ./test_video_output.mp4"
finally:
    # Release the capture even when one of the assertions fails.
    cap.release()
cv2.destroyAllWindows()

In [None]:
# from sklearn.model_selection import train_test_split

# train_df_path = data_dir/'train_data.csv'
# test_df_path = data_dir/'test_data.csv'

# # split train and test from df
# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# train_df.to_csv(train_df_path, index=False, header=False, sep=' ')
# test_df.to_csv(test_df_path, index=False, header=False, sep=' ')
