# 1. Inference

In [9]:
import os

import numpy as np
import tensorflow as tf

from deepcell.applications import NuclearSegmentation, CellTracking
from deepcell_tracking.trk_io import load_trks, save_trk

In [11]:
source_data = '/publication-tracking/data/test.trks'

data_dir = '/publication-tracking/benchmarking/DeepCell/data'
gt_seg_dir = os.path.join(data_dir, 'SEG_GT')
pred_seg_dir = os.path.join(data_dir, 'SEG_PRED')
gt_dir = os.path.join(data_dir, 'GT')

for d in [data_dir, gt_seg_dir, pred_seg_dir, gt_dir]:
    if not os.path.exists(d):
        os.makedirs(d)
        
model_urls = {
    'NuclearSegmentation': 'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearSegmentation-7.tar.gz',
    'NuclearTrackingNE': 'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingNE-7.tar.gz',
    'NuclearTrackingInf': 'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingInf-7.tar.gz'
}

In [7]:
# Load test data
data = load_trks(source_data)

In [12]:
# Download and load each model
models = {}
for m, url in model_urls.items():
    archive_path = tf.keras.utils.get_file(f'{m}.tgz', url, extract=True, cache_subdir='models')
    model_path = os.path.splitext(archive_path)[0]
    model = tf.keras.models.load_model(model_path)
    models[m] = model

Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingNE-7.tar.gz
Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingInf-7.tar.gz


In [13]:
# Load segmentation and tracking applications
app_seg = NuclearSegmentation(models['NuclearSegmentation'])
app_trk = CellTracking(models['NuclearTrackingInf'], models['NuclearTrackingNE'])

In [14]:
def find_zero_padding(X):
    """Remove zero padding to avoid adverse effects on model performance"""
    # Calculate position of padding based on first frame
    # Assume that padding is in blocks on the edges of image
    good_rows = np.where(X[0].any(axis=0))[0]
    good_cols = np.where(X[0].any(axis=1))[0]

    slc = (
        slice(None),
        slice(good_cols[0], good_cols[-1] + 1),
        slice(good_rows[0], good_rows[-1] + 1),
        slice(None)
    )

    return slc

In [20]:
for batch_no in range(len(data['lineages'])):
    # Pull out relevant data for this batch
    X = data['X'][batch_no]
    y = data['y'][batch_no]
    lineages = data['lineages'][batch_no]
    name = '{:03}.trk'.format(batch_no)

    # Determine position of zero padding for removal
    slc = find_zero_padding(X)
    X = X[slc]
    y = y[slc]

    # Determine which frames are zero padding
    frames = np.sum(y, axis=(1,2)) # True if image not blank
    good_frames = np.where(frames)[0]
    X = X[:len(good_frames)]
    y = y[:len(good_frames)]

    # Save GT data
    save_trk(
        filename=os.path.join(gt_dir, name),
        lineage=lineages, 
        raw=X,
        tracked=y
    )

    # Generate tracks on GT segmentations
    track_gt = app_trk.track(X, y)
    save_trk(
        filename=os.path.join(gt_seg_dir, name),
        lineage=track_gt['tracks'], 
        raw=track_gt['X'],
        tracked=track_gt['y_tracked']
    )

    # Generate nuclear segmentation predictions
    y_pred = app_seg.predict(y)
    # Generate tracks on predicted segmentation
    track_pred = app_trk.track(X, y_pred)
    save_trk(
        filename=os.path.join(pred_seg_dir, name),
        lineage=track_pred['tracks'], 
        raw=track_pred['X'],
        tracked=track_pred['y_tracked']
    )


  markers = h_maxima(image=maxima,


# 2. Evaluation

In [21]:
import glob
import os

import numpy as np
import pandas as pd

from deepcell_tracking.metrics import TrackingMetrics

In [22]:
data_dir = '/publication-tracking/benchmarking/DeepCell/data'
gt_seg_dir = os.path.join(data_dir, 'SEG_GT')
pred_seg_dir = os.path.join(data_dir, 'SEG_PRED')
gt_dir = os.path.join(data_dir, 'GT')

data_ids = os.listdir(gt_dir)

node_match_threshold = 0.6

In [25]:
benchmarks = []

for results_dir, s in zip([gt_seg_dir, pred_seg_dir], ['GT', 'Deepcell']):
    for data_id in data_ids:
        m = TrackingMetrics.from_trk_files(
            os.path.join(gt_dir, data_id),
            os.path.join(results_dir, data_id),
            threshold=node_match_threshold
        )
        benchmarks.append({
            'model': f'Deepcell - {s}',
            'data_id': os.path.splitext(data_id)[0],
            **m.stats
        })

df = pd.DataFrame(benchmarks)
df.to_csv('benchmarks.csv')

missed node 26_25 division completely
missed node 57_10 division completely
missed node 60_0 division completely
missed node 5_6 division completely
missed node 121_42 division completely
missed node 144_28 division completely
missed node 1_29 division completely
missed node 29_29 division completely
12_31 out degree = 2, daughters mismatch, gt and res degree equal.
18_16 out degree = 2, daughters mismatch, gt and res degree equal.
missed node 26_25 division completely
missed node 23_66 division completely
missed node 48_6 division completely
missed node 57_10 division completely
missed node 60_0 division completely
corrected division 23_66 as a frameshift division not an error
corrected division 48_6 as a frameshift division not an error
corrected division 57_10 as a frameshift division not an error
missed node 1_48 division completely
8_18 out degree = 2, daughters mismatch, gt and res degree equal.
missed node 10_19 division completely
15_17 out degree = 1, daughters mismatch.
17_6 