In [1]:
import os
import errno
import numpy as np

import tensorflow as tf
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks

import deepcell
from deepcell import model_zoo
from deepcell import losses
from deepcell.utils.data_utils import get_data
from deepcell.utils.train_utils import rate_scheduler
from deepcell.utils.tracking_utils import load_trks
from deepcell.utils.tracking_utils import save_trks
from deepcell.utils.retinanet_anchor_utils import get_anchor_parameters
from deepcell.callbacks import RedirectModel, Evaluate
from deepcell.image_generators import RetinaMovieDataGenerator, RetinaNetGenerator
# from deepcell.model_zoo import shapemask_box

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

# Helper Functions
These helper functions come from the `shape_mask` branch.

In [2]:
def train_model(model,
                model_dir=None,
                model_name=None,
                train_dict=None,
                test_dict=None,
                batch_size=1,
                num_classes=1,
                fpb=1,
                backbone_levels=None,
                pyramid_levels=None,
                anchor_params=None,
                n_epoch=16,
                optimizer=None,
                lr_sched=rate_scheduler(lr=1e-5, decay=0.99),
                prediction_model=None):
    """Compile ``model`` with the RetinaNet losses and fit it on generators.

    The function is self-contained: step counts are derived from the data
    dictionaries it receives rather than from notebook-level variables.

    Args:
        model: Keras model to compile and train (trained in place).
        model_dir: Directory the best checkpoint is written to.
        model_name: Checkpoint base name; ``'.h5'`` is appended.
        train_dict: Training data passed to the generator's ``flow``; its
            ``'X'`` entry is also used to size an epoch.
        test_dict: Validation data, used the same way as ``train_dict``.
        batch_size: Batch size for both generators and the step counts.
        num_classes: Unused here; kept so existing callers keep working.
        fpb: Frames per batch. ``1`` selects ``RetinaNetGenerator``; any
            other value selects ``RetinaMovieDataGenerator`` and is
            forwarded as ``frames_per_batch``.
        backbone_levels: Unused here; kept so existing callers keep working.
        pyramid_levels: Forwarded to the generators' ``flow``.
        anchor_params: Forwarded to the generators' ``flow``.
        n_epoch: Number of training epochs.
        optimizer: Optimizer passed to ``model.compile``. When ``None`` a
            fresh ``Adam(lr=1e-5, clipnorm=0.001)`` is created for this call
            so optimizer state is never shared between models.
        lr_sched: Schedule function given to ``LearningRateScheduler``.
        prediction_model: Model handed to ``RedirectModel`` for the
            ``Evaluate`` callback. Defaults to ``model``.

    Returns:
        None. The model is trained in place and the best checkpoint (by
        ``val_loss``) is saved to ``model_dir/model_name.h5``.
    """
    # A default optimizer instance in the signature would be created once and
    # reused by every call; build it per call instead.
    if optimizer is None:
        optimizer = Adam(lr=1e-5, clipnorm=0.001)

    if prediction_model is None:
        prediction_model = model

    def _frame_count(data_dict):
        """Total number of frames in ``data_dict['X']``.

        5D arrays are treated as (movies, frames, ...) and 4D arrays as
        (frames, ...), matching the shapes printed by the training cell.
        """
        shape = data_dict['X'].shape
        if len(shape) == 5:
            return shape[0] * shape[1]
        return shape[0]

    # Arguments shared by the training and validation flows.
    flow_kwargs = {
        'batch_size': batch_size,
        'include_masks': True,
        'include_final_detection_layer': True,
        'pyramid_levels': pyramid_levels,
        'anchor_params': anchor_params,
    }

    if fpb == 1:
        generator_cls = RetinaNetGenerator
    else:
        generator_cls = RetinaMovieDataGenerator
        flow_kwargs['frames_per_batch'] = fpb

    # Augmentation is applied to the training data only.
    datagen = generator_cls(
        rotation_range=180,
        zoom_range=(0.8, 1.2),
        horizontal_flip=True,
        vertical_flip=True)
    datagen_val = generator_cls()

    train_data = datagen.flow(train_dict, **flow_kwargs)
    val_data = datagen_val.flow(test_dict, **flow_kwargs)

    retinanet_losses = losses.RetinaNetLosses(
        sigma=3.0,
        alpha=0.25,
        gamma=2.0,
        iou_threshold=0.5,
        mask_size=(28, 28))

    loss = {
        'regression': retinanet_losses.regress_loss,
        'classification': retinanet_losses.classification_loss,
        'masks': retinanet_losses.mask_loss,
        'final_detection': retinanet_losses.final_detection_loss
    }

    model.compile(loss=loss, optimizer=optimizer)

    # Settings for the Evaluate callback.
    iou_threshold = 0.5
    score_threshold = 0.01
    max_detections = 100

    # An epoch is sized by the total number of frames, independent of fpb
    # (same arithmetic as before, now computed from the function's inputs).
    steps_per_epoch = _frame_count(train_dict) // batch_size
    validation_steps = _frame_count(test_dict) // batch_size

    model.fit_generator(
        train_data,
        steps_per_epoch=steps_per_epoch,
        epochs=n_epoch,
        validation_data=val_data,
        validation_steps=validation_steps,
        callbacks=[
            callbacks.LearningRateScheduler(lr_sched),
            callbacks.ModelCheckpoint(
                os.path.join(model_dir, model_name + '.h5'),
                monitor='val_loss',
                verbose=1,
                save_best_only=True,
                save_weights_only=False),
            RedirectModel(
                Evaluate(val_data,
                         iou_threshold=iou_threshold,
                         score_threshold=score_threshold,
                         max_detections=max_detections,
                         frames_per_batch=fpb,
                         weighted_average=True),
                prediction_model)
        ])

    return None

# Train models

In [None]:
# download_datasets()

DATA_DIR = '/data/training_data/tracking_benchmark_data'

# Sweep configuration: every backbone x frames-per-batch x temporal mode
# combination below is trained on every dataset.
backbones = ['resnet50']
fpbs = [5, 3, 1]
all_data = '3T3_HeLa_HEK_RAW_cropped.npz'
datasets = [all_data]
temporal_modes = ['conv', 'gru', 'lstm', None]
shape_mask = False

n_epoch = 4
seed = 808


for dataset in datasets:
    num_classes = 1
    test_size = 0.1  # fraction of the data held out as the test split
    test_seed = 10

    filename = os.path.join(DATA_DIR, dataset)
    train_dict, test_dict = get_data(filename, seed=seed, test_size=test_size)
    print(' -\nX.shape: {}\ny.shape: {}'.format(train_dict['X'].shape, train_dict['y'].shape))

    X_train, y_train = train_dict['X'], train_dict['y']
    X_test, y_test = test_dict['X'], test_dict['y']

    # Collapse the two leading axes so the anchor parameters are estimated
    # from individual frames.
    y_train_reshaped = y_train.reshape((-1,) + X_train.shape[2:])
    print("y_train_reshaped shape:", y_train_reshaped.shape)

    optimal_params = get_anchor_parameters(y_train_reshaped)
    backbone_levels, pyramid_levels, anchor_params = optimal_params
    norm_method = 'whole_image'
    print("optimal_params: ", optimal_params)

    for backbone in backbones:
        # Every backbone except featurenet uses the imagenet option.
        use_imagenet = backbone != 'featurenet'

        for fpb in fpbs:
            if fpb == 1:
                # Single-frame training: merge the two leading axes.
                train_dict = {
                    'X': X_train.reshape((-1,) + X_train.shape[2:]),
                    'y': y_train.reshape((-1,) + y_train.shape[2:]),
                }
                test_dict = {
                    'X': X_test.reshape((-1,) + X_test.shape[2:]),
                    'y': y_test.reshape((-1,) + y_test.shape[2:]),
                }
            else:
                train_dict = {'X': X_train, 'y': y_train}
                test_dict = {'X': X_test, 'y': y_test}
            print(' -\nX.shape: {}\ny.shape: {}'.format(train_dict['X'].shape, train_dict['y'].shape))

            for temporal_mode in temporal_modes:
                model = model_zoo.RetinaMask(
                    backbone=backbone,
                    use_imagenet=use_imagenet,
                    panoptic=False,
                    frames_per_batch=fpb,
                    temporal_mode=temporal_mode,
                    num_classes=num_classes,
                    input_shape=X_train.shape[2:],
                    anchor_params=anchor_params,
                    class_specific_filter=False,
                    backbone_levels=backbone_levels,
                    pyramid_levels=pyramid_levels,
                    norm_method=norm_method)
                prediction_model = model

                model_dir = '/data/models'
                model_name = '{}_fpb{}_{}_{}'.format(backbone, fpb, temporal_mode, dataset)

                # Train model
                print("Training model: ", model_name)
                trained_model = train_model(
                    model,
                    model_dir=model_dir,
                    model_name=model_name,
                    train_dict=train_dict,
                    test_dict=test_dict,
                    fpb=fpb,
                    backbone_levels=backbone_levels,
                    pyramid_levels=pyramid_levels,
                    anchor_params=anchor_params,
                    n_epoch=n_epoch)


 -
X.shape: (722, 30, 135, 160, 1)
y.shape: (722, 30, 135, 160, 1)
y_train_reshaped shape: (21660, 135, 160, 1)


W0111 03:11:13.316836 140242619287360 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


optimal_params:  (['C1', 'C2', 'C3'], ['P1', 'P2', 'P3'], <deepcell.utils.retinanet_anchor_utils.AnchorParameters object at 0x7f8ca4023828>)
 -
X.shape: (722, 30, 135, 160, 1)
y.shape: (722, 30, 135, 160, 1)
Downloading data from https://github.com/keras-team/keras-applications/releases/download/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


W0111 03:11:59.796795 140242619287360 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/initializers.py:143: calling RandomNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0111 03:12:14.272167 140242619287360 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py:255: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0111 03:12:15.545627 140242619287360 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py:255: calling crop_and_resize_v1 (from tensorflow.python.ops.image_ops_impl) with box_i

Training model:  resnet50_fpb5_conv_3T3_HeLa_HEK_RAW_cropped.npz


W0111 03:12:26.018308 140242619287360 retinanet.py:619] Removing 77 of 722 images with fewer than 3 objects.
W0111 03:12:29.030066 140242619287360 retinanet.py:619] Removing 7 of 81 images with fewer than 3 objects.
W0111 03:12:29.264892 140242619287360 training_utils.py:1101] Output filtered_detections missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to filtered_detections.
W0111 03:12:29.265924 140242619287360 training_utils.py:1101] Output filtered_detections_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to filtered_detections_1.
W0111 03:12:29.266691 140242619287360 training_utils.py:1101] Output filtered_detections_2 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to filtered_detections_2.
W0111 03:12:29.267596 1402426

Epoch 1/4
  472/21660 [..............................] - ETA: 5:54:43 - loss: 3.8637 - regression_loss: 2.2617 - classification_loss: 0.6959 - masks_loss: 0.4054 - final_detection_loss: 0.5006

# Benchmark

In [None]:
# Define data to load (raw images from trk test files)
RAW_BASE_DIR = '/data/training_data/tracking_benchmark_data/test'

raw_trks_3T3  = os.path.join(RAW_BASE_DIR, '3T3_NIH_test_BData.trks')
raw_trks_HEK  = os.path.join(RAW_BASE_DIR, 'HEK293_generic_test_BData.trks')
raw_trks_HeLa = os.path.join(RAW_BASE_DIR, 'HeLa_S3_test_BData.trks')
raw_trks_RAW  = os.path.join(RAW_BASE_DIR, 'RAW264_generic_test_BData.trks')

# raw_trks_files = [raw_trks_3T3, raw_trks_HEK, raw_trks_HeLa]
raw_trks_files = [raw_trks_3T3, raw_trks_HEK, raw_trks_HeLa, raw_trks_RAW]

model_dir = '/data/models'
# model_name is deliberately not assigned in this cell. It depends on one
# specific backbone / fpb / temporal_mode / dataset combination, and none of
# those are defined here, so building it in this cell only worked when the
# names had leaked from an earlier cell. Build it wherever a single
# combination is selected, e.g.:
#   backbone + '_' + 'fpb' + str(fpb) + '_' + str(temporal_mode) + '_' + dataset

# Model combinations to benchmark.
backbones = ['featurenet', 'mobilenetv2', 'resnet50']
fpbs = [5, 3, 1]
all_data = '3T3_HeLa_HEK_RAW_cropped.npz'
datasets = [all_data]
temporal_modes = ['conv', 'gru', 'lstm', None]
shape_mask = False
norm_method = 'whole_image'
seed = 808

In [None]:
# Make predictions on test data
from deepcell.utils.tracking_utils import load_trks
from skimage.morphology import remove_small_objects
import pandas as pd
from deepcell import metrics
from skimage.measure import label
from skimage import morphology
from skimage.morphology import watershed
from skimage.feature import peak_local_max


# Go through each Dataset (3T3, HEK293, HeLa, RAW264.7)
for set_num, dataset in enumerate(raw_trks_files):
    # Load the trk file       
    trks = load_trks(dataset)
    lineages, raw, tracked = trks['lineages'], trks['X'], trks['y']
    optimal_params = determine_anchor_params(trks['y'].reshape((-1,  trks['X'].shape[2], trks['X'].shape[3], trks['X'].shape[4])))
    backbone_levels, pyramid_levels, anchor_params = optimal_params
    
    for backbone in backbones:
        if backbone == 'featurenet':
                use_imagenet=False
            else:
                use_imagenet=True

            for fpb in fpbs:
                for temporal_mode in temporal_modes: