In [None]:
import glob
from pathlib import Path
import numpy as np
import pandas as pd
import umap
import matplotlib.pyplot as plt

from livecellx.core.datasets import LiveCellImageDataset
from livecellx.sample_data import tutorial_three_image_sys

from livecellx.trajectory.feature_extractors import (
    compute_haralick_features,
    compute_skimage_regionprops,
)
from livecellx.preprocess.utils import normalize_img_to_uint8
from livecellx.core.parallel import parallelize
from livecellx.core.io_sc import prep_scs_from_mask_dataset
from livecellx.core.single_cell import create_sc_table
import livecellx.core.single_cell
from livecellx.core.single_cell import SingleCellStatic
from livecellx.core.io_utils import LiveCellEncoder


# dataset_dir_path = Path(
#     "../datasets/celltrackingchallenge/BF-C2DL-HSC/BF-C2DL-HSC/01"
# )

# mask_dataset_path = Path("../datasets/celltrackingchallenge/BF-C2DL-HSC/BF-C2DL-HSC/01_GT/TRA")

In [None]:
def extract_time_from_celltracking_dataset_man_anno_filename(filename):
    """Extracts the time from a filename in the format 'man_seg00002.tif'.

    The characters between the first occurrence of "man_seg" and the first
    occurrence of ".tif" are parsed as a zero-padded integer. An all-zero
    index such as 'man_seg000.tif' yields 0.

    Args:
        filename (str): The filename to extract the time from.

    Returns:
        int: The time extracted from the filename.

    Raises:
        ValueError: If the text between "man_seg" and ".tif" is empty or is
            not a valid integer.
    """
    # The digits sit between the "man_seg" prefix and the ".tif" extension
    time_start = filename.find("man_seg") + len("man_seg")
    time_end = filename.find(".tif")
    raw_digits = filename[time_start:time_end]

    # Remove the zero padding from the time string
    time_str = raw_digits.lstrip("0")

    # An index made only of zeroes (e.g. "000") strips down to an empty string,
    # which int() rejects; it denotes time 0. A genuinely empty index is left
    # untouched so that it still raises ValueError below.
    if raw_digits and not time_str:
        time_str = "0"

    return int(time_str)

def extract_time_from_celltracking_dataset_raw_data_filename(filename):
    """Parse the time index out of a raw-image filename such as 't00002.tif'.

    Args:
        filename (str): Name of the raw image file.

    Returns:
        int: The time encoded in the filename. When nothing remains after the
            zero padding is stripped (e.g. 't000.tif'), 0 is returned.
    """
    # The index lies between the first "t" and the first ".tif"
    digits_begin = filename.find("t") + len("t")
    digits_end = filename.find(".tif")

    # Drop the zero padding; fall back to "0" if that leaves an empty string
    digits = filename[digits_begin:digits_end].lstrip("0") or "0"

    return int(digits)


# Ground-truth segmentation masks of the DIC-C2DH-HeLa sequence 01.
mask_dataset_path = Path(
    "../datasets/celltrackingchallenge/DIC-C2DH-HeLa/DIC-C2DH-HeLa/01_GT/SEG"
)
mask_dataset = LiveCellImageDataset(mask_dataset_path, ext="tif")

# Re-key the mask URLs by the time parsed from each URL.
time2url = {
    extract_time_from_celltracking_dataset_man_anno_filename(mask_url): mask_url
    for mask_url in dict(mask_dataset.time2url).values()
}
# NOTE(review): timepoint 67 is excluded without a recorded reason; this raises
# KeyError if the dataset has no mask for time 67 — confirm it is still wanted.
del time2url[67]
mask_dataset.update_time2url(time2url)

# Directory holding the raw images of the same sequence.
dataset_dir_path = Path(
    "../datasets/celltrackingchallenge/DIC-C2DH-HeLa/DIC-C2DH-HeLa/01"
)


In [None]:

# Times parsed from the URL of every mask currently in the mask dataset.
mask_times = [
    extract_time_from_celltracking_dataset_man_anno_filename(mask_url)
    for mask_url in mask_dataset.time2url.values()
]
print(mask_times)

In [None]:
# Show the list of times parsed from the mask URLs as the cell output.
mask_times

In [None]:

# All raw .tif images in the dataset directory, ordered by path.
tif_pattern = str(Path(dataset_dir_path) / "*.tif")
img_paths = sorted(glob.glob(tif_pattern))

time2url = {}

for img_path in img_paths:
    frame_time = extract_time_from_celltracking_dataset_raw_data_filename(Path(img_path).name)
    # Keep a frame only when its time is in mask_times, i.e. when a
    # ground-truth mask exists for that timepoint.
    if frame_time in mask_times:
        time2url[frame_time] = img_path.replace("\\", "/")  # prevent windows paths

img_dataset = LiveCellImageDataset(time2url=time2url, ext="tif")

In [None]:
# Display the sizes of the image dataset and the mask dataset side by side.
len(img_dataset), len(mask_dataset)

In [None]:
# Reindex the time2url mapping of both datasets.
# NOTE(review): presumably this renumbers the times to consecutive indices so the
# two datasets share keys — confirm against LiveCellImageDataset.
img_dataset.reindex_time2url_sequential()
mask_dataset.reindex_time2url_sequential()

In [None]:
# Output directory; created together with any missing parents, and left
# untouched if it already exists.
out_dir = Path("application_results") / "celltrackingchallenge" / "testing"
out_dir.mkdir(parents=True, exist_ok=True)

Compute the features  
If you have already computed the features, skip this step and read them in the next section.

In [None]:
# Build the single-cell objects from the mask dataset and the image dataset.
# `prep_scs_from_mask_dataset` is already imported in the notebook's import cell,
# so it is not imported again here.
scs = prep_scs_from_mask_dataset(mask_dataset, img_dataset)
print("Number of single cells:", len(scs))

In [None]:
from livecellx.core.single_cell import create_sctc_from_scs
# NOTE(review): create_sc_seg_napari_ui is not used in this cell — confirm whether it is still needed.
from livecellx.core.sc_seg_operator import create_sc_seg_napari_ui
from livecellx.core.sct_operator import create_sctc_edit_viewer_by_interval

# Build a collection object from the list of single cells.
sdata = create_sctc_from_scs(scs)
# Open the editing viewer on that collection together with the image dataset.
# NOTE(review): span_interval=1000 is undocumented here; presumably large enough to
# cover every timepoint in a single interval — confirm.
sct_operator = create_sctc_edit_viewer_by_interval(sdata, img_dataset, span_interval=1000)

In [None]:
# from livecellx.track.movie import generate_single_trajectory_movie

# for track_id, traj in traj_collection:
#     generate_single_trajectory_movie(traj, save_path= out_dir / f"track_{track_id}.gif")

In [None]:
from cellpose import models, utils
from cellpose.io import imread
# Create a Cellpose model with GPU use requested and model type "TN1".
# NOTE(review): verify that the installed cellpose version accepts "TN1" as a
# model_type for models.Cellpose.
model = models.Cellpose(gpu=True, model_type="TN1")


In [None]:

# Train the model reached through `model.sz.cp` for 500 epochs with batch size 4.
# NOTE(review): `images`, `masks_png` and `model_path` are not defined in the
# visible cells; they must be defined before this cell runs on a fresh kernel.
model.sz.cp.train(
    train_data=images,
    train_labels=masks_png,
    batch_size=4,
    channels=[0, 0],
    n_epochs=500,
    save_path=model_path,
)