<div class="alert alert-block alert-info"> <b>NOTE</b> Please select the kernel `venv_embedtrack` for this notebook. </div>

# 1. Data Preparation
Resave data as a set of tiff files in order to match Cell Tracking Challenge conventions which are expected by EmbedTrack.

In [1]:
import os
import sys

import numpy as np

from deepcell.datasets import DynamicNuclearNetTracking

sys.path.append('..')
import utils

In [2]:
data_dir = 'data'
models = [
    'Fluo-N2DL-HeLa',
    'Fluo-N2DH-SIM+',
    'Fluo-N2DH-GOWT1'
]

Load the test split of the tracking data

In [3]:
# Load test data
dnn = DynamicNuclearNetTracking(version='1.1')
X, y, lineages = dnn.load_data(split='test')
data = {
    'X': X,
    'y': y,
    'lineages': lineages
}

Convert each batch of the test split to the standard ISBI format which is compatible with most of the models that we will test.

In [4]:
for batch_no in range(len(data['lineages'])):
    # Pull out relevant data for this batch
    X = data['X'][batch_no]
    y = data['y'][batch_no]
    lineage = data['lineages'][batch_no]

    # Correct discontiguous tracks, which are not allowed by CTC
    y, lineage = utils.convert_to_contiguous(y, lineage)

    # Determine position of zero padding for removal
    slc = utils.find_zero_padding(X)
    X = X[slc]
    y = y[slc]

    # Determine which frames are zero padding
    frames = np.sum(y, axis=(1,2)) # True if image not blank
    good_frames = np.where(frames)[0]
    X = X[:len(good_frames), ..., 0]
    y = y[:len(good_frames), ..., 0]

    # Save copes of raw and gt data in a folder for each model
    for m in models:
        m_dir = os.path.join(data_dir, m)
        if not os.path.exists(m_dir):
            os.makedirs(m_dir)

        utils.save_ctc_raw(m_dir, batch_no + 1, X)
        utils.save_ctc_gt(m_dir, batch_no + 1, y, lineage)

# 2. EmbedTrack Inference

<div class="alert alert-block alert-warning">
<b>Warning:</b> This notebook must be moved into the `embedtrack` folder in order to correctly import `embedtrack` modules.
</div>

In [5]:
import os
import re
import shutil

In [6]:
batch_size = 32

model_dir = '/notebooks/benchmarking/EmbedTrack/KIT-Loe-GE/models'
models = [
    'Fluo-N2DL-HeLa',
    'Fluo-N2DH-SIM+',
    'Fluo-N2DH-GOWT1'
]

data_dir = '/notebooks/benchmarking/EmbedTrack/data'

pattern = re.compile('\d{3}')
data_ids = [f for f in os.listdir(os.path.join(data_dir, models[0])) if pattern.fullmatch(f)]

In [7]:
"""
This is a modified versioon of embedtrack.infer.infer_ctc_data.inference
which eliminates the requirement that the data name matches the model name
"""


import json
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.signal.windows import gaussian
import tifffile
import torch

from embedtrack.infer.inference import (
    extend_grid,
    infer_sequence,
    create_inference_dict,
    calc_padded_img_size,
    init_model,
    foi_correction,
    rename_to_ctc_format,
    device,
)
from embedtrack.infer.infer_ctc_data import fill_empty_frames
from embedtrack.utils.clustering import Cluster
from embedtrack.utils.create_dicts import create_model_dict
from embedtrack.utils.utils import get_img_files


def inference(raw_data_path, model_path, config_file, batch_size=32):
    """
    Segment and track a ctc dataset using a trained EmbedTrack model.
    Args:
        raw_data_path: string
            Path to the raw images
        model_path: string
            Path to the weights of the trained model
        config_file: string
            Path to the configuration of the model
        batch_size: int
            batch size during inference
    """
    raw_data_path = Path(raw_data_path)
    model_path = Path(model_path)

    data_id = raw_data_path.parts[-1]
    data_set = raw_data_path.parts[-2]

    ctc_res_path = raw_data_path.parent / (data_id + "_RES")
    temp_res_path = "./temp"
    if not os.path.exists(temp_res_path):
        os.makedirs(temp_res_path)
    else:
        shutil.rmtree(temp_res_path)

    # These lines are the modification
    # if data_set not in model_path.as_posix():
    #     raise Warning(f"The model {model_path} is not named as the data set {data_set}")

    overlap = 0.25

    with open(config_file) as file:
        train_config = json.load(file)

    model_class = train_config["model_dict"]["name"]
    crop_size = train_config["train_dict"]["crop_size"]

    image_size = tifffile.imread(
        os.path.join(raw_data_path, os.listdir(raw_data_path)[0])
    ).shape

    project_config = dict(
        image_dir=raw_data_path,
        res_dir=temp_res_path,
        model_cktp_path=model_path,
        model_class=model_class,
        grid_y=train_config["grid_dict"]["grid_y"],
        grid_x=train_config["grid_dict"]["grid_x"],
        pixel_y=train_config["grid_dict"]["pixel_y"],
        pixel_x=train_config["grid_dict"]["pixel_x"],
        overlap=overlap,
        crop_size=crop_size,  # multiple of 2
        img_size=image_size,
        padded_img_size=None,
    )
    project_config["padded_img_size"] = calc_padded_img_size(
        project_config["img_size"],
        project_config["crop_size"],
        project_config["overlap"],
    )[0]
    window_function_1d = gaussian(
        project_config["crop_size"], project_config["crop_size"] // 4
    )
    project_config["window_func"] = window_function_1d.reshape(
        -1, 1
    ) * window_function_1d.reshape(1, -1)

    dataset_dict = create_inference_dict(
        batch_size=batch_size,
    )

    # init model
    input_channels = train_config["model_dict"]["kwargs"]["input_channels"]
    n_classes = train_config["model_dict"]["kwargs"]["n_classes"]
    model_dict = create_model_dict(
        input_channels=input_channels,
        n_classes=n_classes,
    )
    model = init_model(model_dict, project_config)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model = model.to(device)
    model.eval()

    # clustering
    cluster = Cluster(
        project_config["grid_y"],
        project_config["grid_x"],
        project_config["pixel_y"],
        project_config["pixel_x"],
    )
    cluster = extend_grid(cluster, image_size)
    tracking_dir = os.path.join(project_config["res_dir"], "tracking")
    infer_sequence(
        model,
        dataset_dict,
        model_dict,
        project_config,
        cluster,
        min_mask_size=train_config["train_dict"]["min_mask_size"] * 0.5,
    )
    foi_correction(tracking_dir, data_set)
    fill_empty_frames(tracking_dir)
    lineage = pd.read_csv(
        os.path.join(tracking_dir, "res_track.txt"), sep=" ", header=None
    )
    max_id = lineage[0].index.max()
    if max_id >= 2 ** 16 - 1:
        raise AssertionError(
            "Max Track id > 2**16 - uint16 transformation needed for ctc"
            " measure will lead to buffer overflow!"
        )
    rename_to_ctc_format(tracking_dir, ctc_res_path)
    shutil.rmtree(temp_res_path)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
for m in models:
    print(f'------ {m} ------')
    res_subdir = os.path.join(data_dir, m)
    if not os.path.exists(res_subdir):
        os.makedirs(res_subdir)

    model_path = os.path.join(model_dir, m, 'best_iou_model.pth')
    config_file = os.path.join(model_dir, m, 'config.json')

    for data_id in data_ids:
        data_path = os.path.join(res_subdir, data_id)
        final_results_path = data_path + '_RES'
        if os.path.exists(final_results_path):
            print(f'Skipping {data_id}, already complete')
            continue
        try:
            inference(data_path, model_path, config_file, batch_size=batch_size)
        except TypeError:
            print('Issue with', data_path)


------ Fluo-N2DL-HeLa ------
`model_dict` dictionary successfully created with: 
 -- num of classes equal to [4, 1, 2], 
 -- input channels equal to 1, 
 -- name equal to 2d
Creating branched erfnet with [4, 1, 2] classes
Save tracking mask t044.tif
Save tracking mask t043.tif
Save tracking mask t042.tif
Save tracking mask t041.tif
Save tracking mask t040.tif
Save tracking mask t039.tif
Save tracking mask t038.tif
Save tracking mask t037.tif
Save tracking mask t036.tif
Save tracking mask t035.tif
Save tracking mask t034.tif
Save tracking mask t033.tif
Save tracking mask t032.tif
Save tracking mask t031.tif
Save tracking mask t030.tif
Save tracking mask t029.tif
Save tracking mask t028.tif
Save tracking mask t027.tif
Save tracking mask t026.tif
Save tracking mask t025.tif
Save tracking mask t024.tif
Save tracking mask t023.tif
Save tracking mask t022.tif
Save tracking mask t021.tif
Save tracking mask t020.tif
Save tracking mask t019.tif
Save tracking mask t018.tif
Save tracking mask t01

# 3. Evaluation

In [9]:
import glob
import os
import re
import subprocess

import numpy as np
import pandas as pd
from tifffile import imread

from deepcell_tracking.metrics import TrackingMetrics

In [10]:
data_dir = '/notebooks/benchmarking/EmbedTrack/data'
models = [
    'Fluo-N2DL-HeLa',
    'Fluo-N2DH-SIM+',
    'Fluo-N2DH-GOWT1'
]

pattern = re.compile('\d{3}')

node_match_threshold = 0.6

ctc_software = '/notebooks/benchmarking/CTC_Evaluation_Software'
operating_system = 'Linux' # or 'Mac' or 'Win'
num_digits = '3'

In [11]:
benchmarks = []

for m in models:
    print(f'------ {m} ------')
    res_subdir = os.path.join(data_dir, m)
    data_ids = [f for f in os.listdir(res_subdir) if pattern.fullmatch(f)]

    for data_id in data_ids:
        results = {
            'model': f'EmbedTrack - {m}',
            'data_id': data_id
        }
        gt_dir = os.path.join(res_subdir, f'{data_id}_GT/TRA')
        res_dir = os.path.join(res_subdir, f'{data_id}_RES')

        # Deepcell division benchmarking
        try:
            metrics = TrackingMetrics.from_isbi_dirs(gt_dir, res_dir)
            results.update(metrics.stats)
        except ValueError:
            print('Issue with deepcell benchmarking of', data_id)

       # CTC metrics
        for metric, path in [('DET', 'DETMeasure'), ('SEG', 'SEGMeasure'), ('TRA', 'TRAMeasure')]:
            p = subprocess.run([os.path.join(ctc_software, operating_system, path), res_subdir, data_id, num_digits],
                               stdout=subprocess.PIPE)
            outstring = p.stdout

            try:
                val = float(outstring.decode('utf-8').split()[-1])
                results[metric] = val
            except:
                print('Benchmarking failure', path, m, data_id)
                print(outstring.decode('utf-8'))

        benchmarks.append(results)

df = pd.DataFrame(benchmarks)
df.to_csv('/notebooks/benchmarking/EmbedTrack/benchmarks.csv')

------ Fluo-N2DL-HeLa ------
missed node 17_24 division completely
missed node 18_9 division completely
missed node 21_42 division completely
missed node 37_35 division completely
missed node 41_26 division completely
missed node 43_4 division completely
missed node 54_28 division completely
missed node 73_5 division completely
missed node 5_6 division completely
missed node 10_66 division completely
missed node 15_66 division completely
missed node 30_37 division completely
55_27 out degree = 2, daughters mismatch, gt and res degree equal.
missed node 94_36 division completely
missed node 104_34 division completely
missed node 118_0 division completely
missed node 128_43 division completely
missed node 140_10 division completely
missed node 147_33 division completely
missed node 148_59 division completely
missed node 160_64 division completely
179_68 out degree = 2, daughters mismatch, gt and res degree equal.
missed node 185_49 division completely
missed node 1_29 division completely

In [12]:
df

Unnamed: 0,model,data_id,correct_division,mismatch_division,false_positive_division,false_negative_division,total_divisions,aa_tp,aa_total,te_tp,te_total,DET,SEG,TRA
0,Fluo-N2DL-HeLa,11,7.0,0.0,39.0,8.0,15.0,2822.0,3833.0,2923.0,3975.0,0.723899,0.502861,0.721602
1,Fluo-N2DL-HeLa,2,1.0,0.0,0.0,0.0,1.0,963.0,1075.0,988.0,1109.0,0.853021,0.704522,0.853646
2,Fluo-N2DL-HeLa,9,42.0,2.0,136.0,13.0,57.0,11958.0,14806.0,12276.0,15174.0,0.839851,0.588158,0.838117
3,Fluo-N2DL-HeLa,3,5.0,1.0,2.0,2.0,8.0,1837.0,2055.0,1896.0,2122.0,0.84689,0.702353,0.846557
4,Fluo-N2DL-HeLa,12,4.0,2.0,173.0,10.0,16.0,7638.0,9796.0,7863.0,10097.0,0.841111,0.605923,0.836927
5,Fluo-N2DL-HeLa,7,0.0,1.0,0.0,0.0,1.0,192.0,199.0,197.0,206.0,0.966019,0.860526,0.965911
6,Fluo-N2DL-HeLa,1,2.0,0.0,2.0,1.0,3.0,935.0,997.0,966.0,1030.0,0.688641,0.610207,0.68957
7,Fluo-N2DL-HeLa,6,0.0,0.0,6.0,0.0,0.0,264.0,383.0,274.0,397.0,0.806045,0.642576,0.802949
8,Fluo-N2DL-HeLa,10,30.0,4.0,142.0,20.0,54.0,7499.0,9090.0,7725.0,9357.0,0.852581,0.59033,0.84912
9,Fluo-N2DL-HeLa,4,1.0,0.0,8.0,0.0,1.0,369.0,521.0,386.0,540.0,0.625926,0.494621,0.627375
