In [1]:
import os
import errno
import numpy as np

import tensorflow as tf
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks

import deepcell
from deepcell import model_zoo
from deepcell import losses
from deepcell.utils.data_utils import get_data
from deepcell.utils.train_utils import rate_scheduler
from deepcell.utils.tracking_utils import load_trks
from deepcell.utils.tracking_utils import save_trks
from deepcell.utils.retinanet_anchor_utils import get_anchor_parameters
from deepcell.callbacks import RedirectModel, Evaluate
from deepcell.image_generators import RetinaMovieDataGenerator, RetinaNetGenerator
# from deepcell.model_zoo import shapemask_box

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

# Helper Functions
From `shape_mask` branch

In [2]:
def train_model(model,
                model_dir=None,
                model_name=None,
                train_dict=None,
                test_dict=None,
                batch_size=1,
                num_classes=1,
                fpb=1,
                backbone_levels=None,
                pyramid_levels=None,
                anchor_params=None,
                n_epoch=16,
                optimizer=Adam(lr=1e-5, clipnorm=0.001),
                lr_sched = rate_scheduler(lr=1e-5, decay=0.99)
                ):
    
    if fpb == 1:
        datagen = RetinaNetGenerator(
            rotation_range=180,
            zoom_range=(0.8, 1.2),
            horizontal_flip=True,
            vertical_flip=True)

        datagen_val = RetinaNetGenerator()

        train_data = datagen.flow(
            train_dict,
            batch_size=batch_size,
            include_masks=True,
            include_final_detection_layer=True,
            pyramid_levels=pyramid_levels,
            anchor_params=anchor_params)

        val_data = datagen_val.flow(
            test_dict,
            batch_size=batch_size,
            include_masks=True,
            include_final_detection_layer=True,
            pyramid_levels=pyramid_levels,
            anchor_params=anchor_params)
    
    else:
        datagen = RetinaMovieDataGenerator(
            rotation_range=180,
            zoom_range=(0.8, 1.2),
            horizontal_flip=True,
            vertical_flip=True)

        datagen_val = RetinaMovieDataGenerator()

        train_data = datagen.flow(
            train_dict,
            batch_size=batch_size,
            include_masks=True,
            include_final_detection_layer=True,
            frames_per_batch=fpb,
            pyramid_levels=pyramid_levels,
            anchor_params=anchor_params)

        val_data = datagen_val.flow(
            test_dict,
            batch_size=batch_size,
            include_masks=True,
            include_final_detection_layer=True,
            frames_per_batch=fpb,
            pyramid_levels=pyramid_levels,
            anchor_params=anchor_params)

    retinanet_losses = losses.RetinaNetLosses(
        sigma=3.0,
        alpha=0.25,
        gamma=2.0,
        iou_threshold=0.5,
        mask_size=(28,28))

    loss = {
        'regression': retinanet_losses.regress_loss,
        'classification': retinanet_losses.classification_loss,
        'masks': retinanet_losses.mask_loss,
        'final_detection': retinanet_losses.final_detection_loss
        }

    model.compile(loss=loss, optimizer=optimizer)

    iou_threshold = 0.5
    score_threshold = 0.01
    max_detections = 100

    model.fit_generator(
        train_data,
        steps_per_epoch=X_train.shape[0] * X_train.shape[1]// batch_size,
        epochs=n_epoch,
        validation_data=val_data,
        validation_steps=X_test.shape[0] * X_test.shape[1]// batch_size,
        callbacks=[
            callbacks.LearningRateScheduler(lr_sched),
            callbacks.ModelCheckpoint(
                os.path.join(model_dir, model_name + '.h5'),
                monitor='val_loss',
                verbose=1,
                save_best_only=True,
                save_weights_only=False),
            RedirectModel(
                Evaluate(val_data,
                         iou_threshold=iou_threshold,
                         score_threshold=score_threshold,
                         max_detections=max_detections,
                         frames_per_batch=fpb,
                         weighted_average=True),
                prediction_model)
        ])

    return None

# Train models

In [None]:
# download_datasets()

DATA_DIR = '/data/training_data/tracking_benchmark_data'

backbones = ['featurenet'] #, 'mobilenetv2', 'resnet50']
fpbs = [5, 3, 1]
all_data = '3T3_HeLa_HEK_RAW_cropped.npz'
datasets = [all_data]
temporal_modes = ['conv', 'gru', 'lstm', None]
shape_mask = False

n_epoch = 4
seed = 808


for dataset in datasets:
    num_classes=1
    test_size = 0.1 # % of data saved as test
    test_seed = 10

    filename = os.path.join(DATA_DIR, dataset)
    train_dict, test_dict = get_data(filename, seed=seed, test_size=test_size)
    print(' -\nX.shape: {}\ny.shape: {}'.format(train_dict['X'].shape, train_dict['y'].shape))
    X_train, y_train = train_dict['X'], train_dict['y']
    X_test, y_test = test_dict['X'], test_dict['y']
    y_train_reshaped = y_train.reshape((-1,  X_train.shape[2], X_train.shape[3], X_train.shape[4]))
    print("y_train_reshaped shape:", y_train_reshaped.shape)
    optimal_params = get_anchor_parameters(y_train.reshape((-1,  X_train.shape[2], X_train.shape[3], X_train.shape[4])))
    backbone_levels, pyramid_levels, anchor_params = optimal_params
    norm_method='whole_image'
    print("optimal_params: ", optimal_params)

    for backbone in backbones:
        if backbone == 'featurenet':
            use_imagenet=False
        else:
            use_imagenet=True

        for fpb in fpbs:
            if fpb == 1:
                train_dict = {'X':X_train.reshape((-1,X_train.shape[2], X_train.shape[3], X_train.shape[4])), 
                              'y': y_train.reshape((-1,y_train.shape[2], y_train.shape[3], y_train.shape[4]))}
                test_dict = {'X':X_test.reshape((-1, X_test.shape[2], X_test.shape[3], X_test.shape[4])), 
                            'y': y_test.reshape((-1, y_test.shape[2], y_test.shape[3], y_test.shape[4]))}
            else:
                train_dict = {'X':X_train, 'y':y_train}
                test_dict = {'X':X_test, 'y':y_test}
            print(' -\nX.shape: {}\ny.shape: {}'.format(train_dict['X'].shape, train_dict['y'].shape))


            for temporal_mode in temporal_modes:
                model = model_zoo.RetinaMask(backbone=backbone,
                                        use_imagenet=use_imagenet,
                                        panoptic=False,
                                        frames_per_batch=fpb,
                                        temporal_mode=temporal_mode,
                                        num_classes=num_classes,
                                        input_shape=X_train.shape[2:],
                                        anchor_params=anchor_params,
                                        class_specific_filter=False,
                                        backbone_levels=backbone_levels,
                                        pyramid_levels=pyramid_levels,
                                        norm_method=norm_method)
                prediction_model = model

                model_dir = '/data/models'
                model_name = backbone + '_' + 'fpb' + str(fpb) + '_' + str(temporal_mode) + '_' + dataset

                # Train model
                print("Training model: ", model_name)
                trained_model = train_model(model,
                            model_dir=model_dir,
                            model_name=model_name,
                            train_dict=train_dict,
                            test_dict=test_dict,
                            fpb=fpb,
                            backbone_levels=backbone_levels,
                            pyramid_levels=pyramid_levels,
                            anchor_params=anchor_params,
                            n_epoch=n_epoch,
                          )


# Benchmark

In [5]:
# Define data to load (raw images from trk test files)
RAW_BASE_DIR = '/data/training_data/tracking_benchmark_data/test'

raw_trks_3T3  = os.path.join(RAW_BASE_DIR, '3T3_NIH_test_BData.trks')
raw_trks_HEK  = os.path.join(RAW_BASE_DIR, 'HEK293_generic_test_BData.trks')
raw_trks_HeLa = os.path.join(RAW_BASE_DIR, 'HeLa_S3_test_BData.trks')
raw_trks_RAW  = os.path.join(RAW_BASE_DIR, 'RAW264_generic_test_BData.trks')

# raw_trks_files = [raw_trks_3T3, raw_trks_HEK, raw_trks_HeLa]
raw_trks_files = [raw_trks_3T3, raw_trks_HEK, raw_trks_HeLa, raw_trks_RAW]

model_dir = '/data/models'

DATA_DIR = '/data/training_data/tracking_benchmark_data'
dataset = '3T3_HeLa_HEK_RAW_cropped.npz'
filename = os.path.join(DATA_DIR, dataset)
test_size = 0.1 # % of data saved as test
seed = 808
train_dict, test_dict = get_data(filename, seed=seed, test_size=test_size)
X_train, y_train = train_dict['X'], train_dict['y']

In [None]:
# Make predictions on test data
from deepcell.utils.tracking_utils import load_trks
from deepcell.utils.retinanet_anchor_utils import evaluate

from skimage.morphology import remove_small_objects
import pandas as pd
from deepcell import metrics
from skimage.measure import label
from skimage import morphology
from skimage.morphology import watershed
from skimage.feature import peak_local_max


iou_threshold = 0.5
score_threshold = 0.01
max_detections = 100
num_classes=1

backbones = ['featurenet'] #, 'mobilenetv2', 'resnet50']
fpbs = [5, 3, 1]

temporal_modes = ['conv', 'gru', 'lstm', None]
shape_mask = False
training_optimal_params = get_anchor_parameters(y_train.reshape((-1,  X_train.shape[2], X_train.shape[3], X_train.shape[4])))

# Go through each Dataset (3T3, HEK293, HeLa, RAW264.7)
for set_num, dataset in enumerate(raw_trks_files):
    print("dataset: ", dataset)
    # Load the trk file       
    trks = load_trks(dataset)
    lineages, raw, tracked = trks['lineages'], trks['X'], trks['y']
    norm_method='whole_image'
    backbone_levels, pyramid_levels, anchor_params = training_optimal_params
    datagen_val = RetinaMovieDataGenerator()
    
    for backbone in backbones:
        if backbone == 'featurenet':
                use_imagenet=False
        else:
            use_imagenet=True
                
        for fpb in fpbs:
            print("frames per batch: ", fpb)
            for temporal_mode in temporal_modes:
                prediction_model = model_zoo.RetinaMask(backbone=backbone,
                                        use_imagenet=use_imagenet,
                                        panoptic=False,
                                        frames_per_batch=fpb,
                                        temporal_mode=temporal_mode,
                                        num_classes=num_classes,
                                        input_shape=trks['X'].shape[2:],
                                        anchor_params=anchor_params,
                                        class_specific_filter=False,
                                        backbone_levels=backbone_levels,
                                        pyramid_levels=pyramid_levels,
                                        norm_method=norm_method)
                # print(prediction_model.summary())

                model_dir = '/data/models/'
                model_name = backbone + '_' + 'fpb' + str(fpb) + '_' + str(temporal_mode) + '_' + '3T3_HeLa_HEK_RAW_cropped.npz'
                print(backbone + '_' + 'fpb' + str(fpb) + '_' + str(temporal_mode))
                prediction_model.load_weights(os.path.join(model_dir, model_name + '.h5'))

                Model_DF = pd.DataFrame(columns=['total_instances', 'mAP'])


                # Go through each batch (movie) in each dataset
                for batch_num, movie in enumerate(trks['X']):
                    print("batch_num: ", batch_num)
                    Lstats = []

                    # Predict on the raw data
                    X_test_temp = np.expand_dims(movie, axis=0)
                    y_test_temp = np.expand_dims(trks['y'][batch_num], axis=0)
                    # print("X_test_temp.shape", X_test_temp.shape)
                    
                    val_data = datagen_val.flow(
                                    {'X': X_test_temp, 'y': y_test_temp},
                                    batch_size=1,
                                    include_masks=True,
                                    include_final_detection_layer=True,
                                    frames_per_batch=fpb,
                                    pyramid_levels=pyramid_levels,
                                    anchor_params=anchor_params)
                    
                    recall, precision, average_precisions = evaluate(
                                                                val_data,
                                                                prediction_model,
                                                                frames_per_batch=fpb,
                                                                iou_threshold=iou_threshold,
                                                                score_threshold=score_threshold,
                                                                max_detections=max_detections,
                                                            )
                    # print(recall, precision, average_precisions)
                    # print("Mean recall: ", np.mean(recall))
                    # print("Mean precision: ", np.mean(precision))
                    total_instances = []
                    precisions = []

                    for label, (average_precision, num_annotations) in average_precisions.items():
#                         print('{:.0f} instances of class'.format(num_annotations),
#                               label, 'with average precision: {:.4f}'.format(average_precision))
                        total_instances.append(num_annotations)
                        precisions.append(average_precision)

                    if sum(total_instances) == 0:
                        pass
                        # print('No test instances found.')
                    else:
#                         print('mAP using the weighted average of precisions among classes: {:.4f}'.format(
#                             sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances)))
#                         print('mAP: {:.4f}'.format(sum(precisions) / sum(x > 0 for x in total_instances)))
                        Model_DF = Model_DF.append({'total_instances': sum(total_instances),
                                              'mAP': sum(precisions) / sum(x > 0 for x in total_instances)},
                                             ignore_index=True)
                print('\n\n')
                print(Model_DF)
                print('total_instances', Model_DF.sum(axis = 0, skipna = True)['total_instances'] )
                print('mAP', Model_DF.mean(axis = 0, skipna = True)['mAP'] )
                print('\n\n')      
                        

dataset:  /data/training_data/tracking_benchmark_data/test/3T3_NIH_test_BData.trks
frames per batch:  5
featurenet_fpb5_conv
batch_num:  0
batch_num:  1
batch_num:  2
batch_num:  3
batch_num:  4
batch_num:  5
batch_num:  6
batch_num:  7
batch_num:  8
batch_num:  9
batch_num:  10
batch_num:  11
batch_num:  12
batch_num:  13
batch_num:  14
batch_num:  15
batch_num:  16
batch_num:  17
batch_num:  18
batch_num:  19
batch_num:  20
batch_num:  21
batch_num:  22
batch_num:  23



    total_instances       mAP
0               9.0  1.000000
1               6.0  1.000000
2               9.0  0.988889
3              12.0  0.944205
4               6.0  1.000000
5               6.0  1.000000
6               7.0  1.000000
7               9.0  1.000000
8              10.0  1.000000
9              11.0  0.975758
10             10.0  1.000000
11              9.0  0.864198
12             12.0  1.000000
13             13.0  0.961767
14              6.0  1.000000
15              7.0  1.000000
16          

In [11]:
backbone = 'featurenet'
fpbs = 5
temporal_mode = 'conv'
shape_mask = False

prediction_model = model_zoo.RetinaMask(backbone=backbone,
                                        use_imagenet=use_imagenet,
                                        panoptic=False,
                                        frames_per_batch=fpb,
                                        temporal_mode=temporal_mode,
                                        num_classes=num_classes,
                                        input_shape=trks['X'].shape[2:],
                                        anchor_params=anchor_params,
                                        class_specific_filter=False,
                                        backbone_levels=backbone_levels,
                                        pyramid_levels=pyramid_levels,
                                        norm_method=norm_method)
print(prediction_model.summary())

model_dir = '/data/models/'
model_name = backbone + '_' + 'fpb' + str(fpb) + '_' + str(temporal_mode) + '_' + '3T3_HeLa_HEK_RAW_cropped.npz'
prediction_model.load_weights(model_dir + model_name + '.h5')

Model: "featurenet_retinanet_mask"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_25 (InputLayer)           [(None, 5, 154, 182, 0                                            
__________________________________________________________________________________________________
time_distributed_56 (TimeDistri (None, 5, 154, 182,  0           input_25[0][0]                   
__________________________________________________________________________________________________
time_distributed_57 (TimeDistri (None, 5, 154, 182,  6           time_distributed_56[0][0]        
__________________________________________________________________________________________________
time_distributed_59 (TimeDistri (None, 5, 39, 46, 32 29152       time_distributed_57[0][0]        
__________________________________________________________________________

In [15]:
outputs = prediction_model.predict(X_test_temp[0])
final_scores = outputs[-1]
selection = np.where(final_scores[0, i] > 0.75)

ValueError: Error when checking input: expected input_25 to have 5 dimensions, but got array with shape (30, 154, 182, 1)

In [32]:
training_optimal_params = get_anchor_parameters(y_train.reshape((-1,  X_train.shape[2], X_train.shape[3], X_train.shape[4])))


In [31]:
print(training_optimal_params[2].sizes)
print(training_optimal_params[2].strides)
print(training_optimal_params[2].ratios)
print(training_optimal_params[2].scales)

[ 8. 16. 32.]
[2. 4. 8.]
[0.25 0.5  1.   2.   4.  ]
[1, 1.2599210498948732, 1.5874010519681994]
