# 1. Data Preparation
Resave data as a set of tiff files in order to match Cell Tracking Challenge conventions which are expected by EmbedTrack.

In [3]:
import os

import numpy as np
from tifffile import imwrite

from deepcell_tracking.isbi_utils import trk_to_isbi
from deepcell_tracking.trk_io import load_trks

In [4]:
source_data = '/data/test.trks'
data_dir = '/EmbedTrack/data'

Load the test split of the tracking data

In [5]:
data = load_trks(source_data)

In [14]:
def find_zero_padding(X):
    """Remove zero padding to avoid adverse effects on model performance"""
    # Calculate position of padding based on first frame
    # Assume that padding is in blocks on the edges of image
    good_rows = np.where(X[0].any(axis=0))[0]
    good_cols = np.where(X[0].any(axis=1))[0]

    slc = (
        slice(None),
        slice(good_cols[0], good_cols[-1] + 1),
        slice(good_rows[0], good_rows[-1] + 1),
        slice(None)
    )

    return slc

Convert each batch of the test split to the standard ISBI format which is compatible with most of the models that we will tst.

In [18]:
for batch_no in range(len(data['lineages'])):
    # Build subdirectories for data
    raw_dir = os.path.join(data_dir, '{:03}'.format(batch_no + 1))
    gt_dir = os.path.join(data_dir, '{:03}_GT'.format(batch_no + 1))
    seg_dir = os.path.join(gt_dir, 'SEG')
    tra_dir = os.path.join(gt_dir, 'TRA')
    
    # Create directories if needed
    for d in (raw_dir, gt_dir, seg_dir, tra_dir):
        if not os.path.exists(d):
            os.makedirs(d)
                
    # Pull out relevant data for this batch
    X = data['X'][batch_no]
    y = data['y'][batch_no]
    lineages = data['lineages'][batch_no]
    
    # Determine position of zero padding for removal
    slc = find_zero_padding(X)
    X = X[slc]
    y = y[slc]
    
    # Need to translate lineages and adjust images to match restrictive ISBI format
    # Prepare output txt
    text_file = os.path.join(tra_dir, 'man_track.txt')
    df = trk_to_isbi(lineages)
    df.to_csv(text_file, sep=' ', header=False)
    
    # Determine which frames are zero padding
    frames = np.sum(y, axis=(1,2)) # True if image not blank
    good_frames = np.where(frames)[0]
    # We assume here that the empty frames are at the end of the movie (padding rather than skipped)
    movie_len = len(good_frames)
    
    # Save each frame of the movie as an individual tif
    channel = 0 # These images should only have one channel
    for i in range(movie_len):
        name_raw = os.path.join(raw_dir, 't{:03}.tif'.format(i))
        name_tracked_seg = os.path.join(seg_dir, 'man_seg{:03}.tif'.format(i))
        name_tracked_tra = os.path.join(tra_dir, 'man_track{:03}.tif'.format(i))
        
        imwrite(name_raw, X[i, ..., channel])
        imwrite(name_tracked_seg, y[i, ..., channel].astype('uint16'))
        imwrite(name_tracked_tra, y[i, ..., channel].astype('uint16'))

# 2. EmbedTrack Inference

<div class="alert alert-block alert-warning">
<b>Warning:</b> This notebook must be moved into the `embedtrack` folder in order to correctly import `embedtrack` modules.
</div>

In [1]:
import os
import re
import shutil

In [2]:
batch_size = 32

model_dir = '/EmbedTrack/KIT-Loe-GE/models'
models = [
    'Fluo-N2DL-HeLa',
    'Fluo-N2DH-SIM+',
    'Fluo-N2DH-GOWT1'
]

data_dir = '/EmbedTrack/data'

pattern = re.compile('\d{3}')
data_ids = [f for f in os.listdir(data_dir) if pattern.fullmatch(f)]

In [3]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.signal.windows import gaussian
import tifffile
import torch

from embedtrack.infer.inference import (
    extend_grid,
    infer_sequence,
    create_inference_dict,
    calc_padded_img_size,
    init_model,
    foi_correction,
    rename_to_ctc_format,
    device,
)
from embedtrack.infer.infer_ctc_data import fill_empty_frames
from embedtrack.utils.clustering import Cluster
from embedtrack.utils.create_dicts import create_model_dict
from embedtrack.utils.utils import get_img_files


# This is a modified versioon of embedtrack.infer.infer_ctc_data.inference
# which eliminates the requirement that the data name matches the model name

def inference(raw_data_path, model_path, config_file, batch_size=32):
    """
    Segment and track a ctc dataset using a trained EmbedTrack model.
    Args:
        raw_data_path: string
            Path to the raw images
        model_path: string
            Path to the weights of the trained model
        config_file: string
            Path to the configuration of the model
        batch_size: int
            batch size during inference
    """
    raw_data_path = Path(raw_data_path)
    model_path = Path(model_path)

    data_id = raw_data_path.parts[-1]
    data_set = raw_data_path.parts[-2]

    ctc_res_path = raw_data_path.parent / (data_id + "_RES")
    temp_res_path = "./temp"
    if not os.path.exists(temp_res_path):
        os.makedirs(temp_res_path)
    else:
        shutil.rmtree(temp_res_path)

    # These lines are the modification
    # if data_set not in model_path.as_posix():
    #     raise Warning(f"The model {model_path} is not named as the data set {data_set}")

    overlap = 0.25

    with open(config_file) as file:
        train_config = json.load(file)

    model_class = train_config["model_dict"]["name"]
    crop_size = train_config["train_dict"]["crop_size"]

    image_size = tifffile.imread(
        os.path.join(raw_data_path, os.listdir(raw_data_path)[0])
    ).shape

    project_config = dict(
        image_dir=raw_data_path,
        res_dir=temp_res_path,
        model_cktp_path=model_path,
        model_class=model_class,
        grid_y=train_config["grid_dict"]["grid_y"],
        grid_x=train_config["grid_dict"]["grid_x"],
        pixel_y=train_config["grid_dict"]["pixel_y"],
        pixel_x=train_config["grid_dict"]["pixel_x"],
        overlap=overlap,
        crop_size=crop_size,  # multiple of 2
        img_size=image_size,
        padded_img_size=None,
    )
    project_config["padded_img_size"] = calc_padded_img_size(
        project_config["img_size"],
        project_config["crop_size"],
        project_config["overlap"],
    )[0]
    window_function_1d = gaussian(
        project_config["crop_size"], project_config["crop_size"] // 4
    )
    project_config["window_func"] = window_function_1d.reshape(
        -1, 1
    ) * window_function_1d.reshape(1, -1)

    dataset_dict = create_inference_dict(
        batch_size=batch_size,
    )

    # init model
    input_channels = train_config["model_dict"]["kwargs"]["input_channels"]
    n_classes = train_config["model_dict"]["kwargs"]["n_classes"]
    model_dict = create_model_dict(
        input_channels=input_channels,
        n_classes=n_classes,
    )
    model = init_model(model_dict, project_config)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model = model.to(device)
    model.eval()

    # clustering
    cluster = Cluster(
        project_config["grid_y"],
        project_config["grid_x"],
        project_config["pixel_y"],
        project_config["pixel_x"],
    )
    cluster = extend_grid(cluster, image_size)
    tracking_dir = os.path.join(project_config["res_dir"], "tracking")
    infer_sequence(
        model,
        dataset_dict,
        model_dict,
        project_config,
        cluster,
        min_mask_size=train_config["train_dict"]["min_mask_size"] * 0.5,
    )
    foi_correction(tracking_dir, data_set)
    fill_empty_frames(tracking_dir)
    lineage = pd.read_csv(
        os.path.join(tracking_dir, "res_track.txt"), sep=" ", header=None
    )
    max_id = lineage[0].index.max()
    if max_id >= 2 ** 16 - 1:
        raise AssertionError(
            "Max Track id > 2**16 - uint16 transformation needed for ctc"
            " measure will lead to buffer overflow!"
        )
    rename_to_ctc_format(tracking_dir, ctc_res_path)
    shutil.rmtree(temp_res_path)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
for m in models:
    print(f'------ {m} ------')
    res_subdir = os.path.join(data_dir, 'results', m)
    if not os.path.exists(res_subdir):
        os.makedirs(res_subdir)

    model_path = os.path.join(model_dir, m, 'best_iou_model.pth')
    config_file = os.path.join(model_dir, m, 'config.json')

    for data_id in data_ids:
        data_path = os.path.join(data_dir, data_id)
        temp_results_path = os.path.join(data_dir, f'{data_id}_RES')
        final_results_path = os.path.join(res_subdir, f'{data_id}_RES')
        if os.path.exists(final_results_path):
            print(f'Skipping {data_id}, already complete')
            continue
        try:
            inference(data_path, model_path, config_file, batch_size=batch_size)
            # Move results into the results subdirectory
            shutil.move(temp_results_path, final_results_path)
        except TypeError:
            print('Issue with', data_path)


------ Fluo-N2DL-HeLa ------
Skipping 011, already complete
Skipping 002, already complete
Skipping 009, already complete
Skipping 003, already complete
Skipping 012, already complete
Skipping 007, already complete
Skipping 001, already complete
Skipping 006, already complete
Skipping 010, already complete
Skipping 004, already complete
Skipping 008, already complete
Skipping 005, already complete
------ Fluo-N2DH-SIM+ ------
Skipping 011, already complete
Skipping 002, already complete
Skipping 009, already complete
Skipping 003, already complete
Skipping 012, already complete
Skipping 007, already complete
Skipping 001, already complete
Skipping 006, already complete
Skipping 010, already complete
Skipping 004, already complete
Skipping 008, already complete
Skipping 005, already complete
------ Fluo-N2DH-GOWT1 ------
Skipping 011, already complete
Skipping 002, already complete
Skipping 009, already complete
Skipping 003, already complete
`model_dict` dictionary successfully created

# 3. Evaluation

In [1]:
import glob
import os
import re

import numpy as np
import pandas as pd
from tifffile import imread

from deepcell_tracking.metrics import TrackingMetrics

In [3]:
data_dir = 'data'#'/EmbedTrack/data'
results_dir = 'data/results'#'/EmbedTrack/data/results'

pattern = re.compile('\d{3}')
data_ids = [f for f in os.listdir(data_dir) if pattern.fullmatch(f)]

node_match_threshold = 0.6

In [4]:
benchmarks = []

for m in [d for d in os.listdir(results_dir) if os.path.isdir(os.path.join(results_dir, d))]:
    print(f'------ {m} ------')
    res_subdir = os.path.join(results_dir, m)

    for data_id in data_ids:
        gt_dir = os.path.join(data_dir, f'{data_id}_GT/TRA')
        res_dir = os.path.join(res_subdir, f'{data_id}_RES')

        try:
            metrics = TrackingMetrics.from_isbi_dirs(gt_dir, res_dir)
            benchmarks.append({
                'model': m,
                'data_id': data_id,
                **metrics.stats
            })
        except ValueError:
            print('Issue with', data_id)
        
df = pd.DataFrame(benchmarks)
df.to_csv('benchmarks.csv')

------ Fluo-N2DH-SIM+ ------
missed node 14_10 division completely
missed node 17_24 division completely
missed node 18_9 division completely
missed node 21_42 division completely
missed node 25_36 division completely
missed node 37_35 division completely
missed node 38_5 division completely
missed node 41_26 division completely
missed node 43_4 division completely
missed node 46_35 division completely
missed node 49_6 division completely
missed node 50_42 division completely
missed node 54_28 division completely
missed node 57_12 division completely
missed node 73_5 division completely
missed node 21_19 division completely
missed node 5_6 division completely
missed node 10_66 division completely
missed node 15_66 division completely
missed node 16_9 division completely
missed node 22_56 division completely
missed node 25_57 division completely
missed node 26_3 division completely
missed node 28_39 division completely
missed node 30_37 division completely
missed node 35_52 division com