### References
- [deepcell-tf example notebook: RetinaNet - Movie](https://github.com/vanvalenlab/deepcell-tf/blob/master/scripts/feature_pyramids/RetinaNet%20-%20Movie.ipynb)

In [1]:
import os
import datetime
import errno
import argparse

import numpy as np

import deepcell

# Load data

In [2]:
from deepcell.utils.data_utils import get_data
from deepcell.utils.tracking_utils import load_trks

# Location of the annotated 3T3 nuclear movie archive (.trks).
DATA_DIR = '/data/training_data/cells/3T3/NIH/movie'
DATA_FILE = os.path.join(DATA_DIR, 'nuclear_movie_3T3_0-2_same.trks')

# Split settings forwarded to get_data, which returns the train and test dicts.
seed = 1
test_size = .2

train_dict, test_dict = get_data(
    DATA_FILE,
    mode='siamese_daughters',
    seed=seed,
    test_size=test_size,
)

# Unpack raw images (X) and label images (y) for each split.
X_train = train_dict['X']
y_train = train_dict['y']
X_test = test_dict['X']
y_test = test_dict['y']

print(' -\nX.shape: {}\ny.shape: {}'.format(X_train.shape, y_train.shape))

 -
X.shape: (192, 30, 154, 182, 1)
y.shape: (192, 30, 154, 182, 1)


# File Constants

In [3]:
# Set up other required filepaths.
# PREFIX is the folder of DATA_FILE relative to DATA_DIR. DATA_FILE sits
# directly inside DATA_DIR, so PREFIX is '.' and outputs land directly under
# ROOT_DIR/models and ROOT_DIR/logs.
PREFIX = os.path.relpath(os.path.dirname(DATA_FILE), DATA_DIR)
ROOT_DIR = '/data'  # mounted volume
MODEL_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'models', PREFIX))
LOG_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'logs', PREFIX))

# Create the output folders up front. ModelCheckpoint later saves
# MODEL_DIR/<model_name>.h5, and the save can fail if the folder is missing.
# exist_ok=True keeps this cell safe to re-run.
for output_dir in (MODEL_DIR, LOG_DIR):
    os.makedirs(output_dir, exist_ok=True)

# Model Parameters

In [4]:
# Each output head of the model is trained with its own loss function.
from deepcell.losses import RetinaNetLosses
from deepcell.losses import discriminative_instance_loss


# Hyperparameters forwarded to RetinaNetLosses below.
sigma = 3.0
alpha = 0.25
gamma = 2.0
iou_threshold = 0.5
max_detections = 100  # not passed to RetinaNetLosses; redefined before training
mask_size = (28, 28)

retinanet_losses = RetinaNetLosses(
    sigma=sigma,
    alpha=alpha,
    gamma=gamma,
    iou_threshold=iou_threshold,
    mask_size=mask_size,
)

# Keys are matched against the model's output names. Outputs with no entry here
# get no loss (Keras warns about them at compile time).
loss = dict(
    regression=retinanet_losses.regress_loss,
    classification=retinanet_losses.classification_loss,
    association_features=discriminative_instance_loss,
    masks=retinanet_losses.mask_loss,
    final_detection=retinanet_losses.final_detection_loss,
)

# Create RetinaMask Model

In [5]:
from tensorflow.keras.optimizers import SGD, Adam
from deepcell.utils.train_utils import rate_scheduler

# Model identity.
model_name = 'trackrcnn_model'
# Other options: vgg16, vgg19, resnet50, densenet121, densenet169, densenet201
backbone = 'resnet50'

# Training schedule.
n_epoch = 10   # number of training epochs
lr = 1e-5      # initial learning rate, shared by the optimizer and the scheduler
batch_size = 1

# NOTE(review): clipnorm=0.001 clips any gradient whose norm exceeds 0.001,
# which is very aggressive. Confirm this value is intentional.
optimizer = Adam(lr=lr, clipnorm=0.001)
lr_sched = rate_scheduler(lr=lr, decay=0.99)

num_classes = 1  # "object" is the only class

In [6]:
from deepcell.utils.retinanet_anchor_utils import get_anchor_parameters

# Merge the first two axes of the label array so every frame becomes its own
# label image. Assumes y_train is (movies, frames, rows, cols, channels);
# TODO confirm.
flat_shape = (y_train.shape[0] * y_train.shape[1],) + tuple(y_train.shape[2:])
flat_y = y_train.reshape(flat_shape).astype('int')

# Derive the backbone/pyramid levels and anchor settings from the label data.
backbone_levels, pyramid_levels, anchor_params = get_anchor_parameters(flat_y)

fpb = 3  # number of frames in each training batch

In [7]:
from deepcell import model_zoo

# A frames_per_batch value greater than 1 switches RetinaMask into 3D (movie) mode.
model = model_zoo.RetinaMask(
    backbone=backbone,
    num_classes=num_classes,
    input_shape=X_train.shape[2:],  # per-frame shape: drops the first two axes
    frames_per_batch=fpb,
    backbone_levels=backbone_levels,
    pyramid_levels=pyramid_levels,
    anchor_params=anchor_params,
    class_specific_filter=False,
)

# NOTE(review): the training model is also used as the prediction model by the
# Evaluate callback. A separate inference model built with
# model_zoo.retinanet_bbox(model, panoptic=False, frames_per_batch=fpb,
# max_detections=100, anchor_params=anchor_params) was tried here and disabled.
# TODO confirm whether Evaluate needs that wrapper.
prediction_model = model

W0204 23:43:46.667853 140457387505472 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0204 23:44:03.054301 140457387505472 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/initializers.py:143: calling RandomNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0204 23:44:14.536428 140457387505472 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py:255: calling crop_and_resize_v1 (from tensorflow.python.ops.im

rois.shape (?, ?, ?, ?, ?, ?)


In [8]:
# Print the layer-by-layer architecture, output shapes and parameter counts.
model.summary()

Model: "resnet50_retinanet_mask"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image_input (InputLayer)        [(None, 3, 154, 182, 0                                            
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, 3, 154, 182,  0           image_input[0][0]                
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, 3, 154, 182,  6           time_distributed[0][0]           
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, 3, 39, 46, 25 229760      time_distributed_1[0][0]         
____________________________________________________________________________

In [9]:
# Attach the optimizer and the per-output losses. Outputs absent from `loss`
# (Keras warns about mask_submodel and two time_distributed layers) get no loss,
# and fit/evaluate expect no targets for them.
model.compile(optimizer=optimizer, loss=loss)

W0204 23:44:16.171706 140457387505472 training_utils.py:1101] Output mask_submodel missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to mask_submodel.
W0204 23:44:16.172880 140457387505472 training_utils.py:1101] Output time_distributed_9 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to time_distributed_9.
W0204 23:44:16.173635 140457387505472 training_utils.py:1101] Output time_distributed_14 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to time_distributed_14.
W0204 23:44:16.288187 140457387505472 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/deepcell/losses.py:332: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructi

# Train RetinaMask Model

### Training Parameters

In [10]:
from deepcell.image_generators import RetinaMovieDataGenerator

# Training generator with augmentation: rotations, zoom, and flips on both axes.
datagen = RetinaMovieDataGenerator(
    horizontal_flip=True,
    vertical_flip=True,
    rotation_range=180,
    zoom_range=(0.8, 1.2),
)

# Validation generator: default settings, no augmentation arguments.
datagen_val = RetinaMovieDataGenerator()

In [11]:
# Keyword arguments shared by the training and validation iterators.
# `batch_size` comes from the config cell instead of a literal 1. This keeps the
# iterators in sync with steps_per_epoch / validation_steps, which the fit call
# computes from the same constant.
flow_kwargs = dict(
    batch_size=batch_size,
    include_masks=True,
    include_final_detection_layer=True,
    assoc_head=True,          # presumably emits association-head targets; confirm
    frames_per_batch=fpb,
    pyramid_levels=pyramid_levels,
    anchor_params=anchor_params,
)

train_data = datagen.flow(train_dict, **flow_kwargs)
val_data = datagen_val.flow(test_dict, **flow_kwargs)

train_dict keys :
X
y
daughters


W0204 23:44:20.387309 140457387505472 retinanet.py:651] Removing 2 of 192 images with fewer than 3 objects.


train_dict keys :
X
y
daughters


W0204 23:44:21.851473 140457387505472 retinanet.py:651] Removing 1 of 48 images with fewer than 3 objects.


In [12]:
from tensorflow.keras import callbacks
from deepcell.callbacks import RedirectModel, Evaluate

# Settings for the Evaluate callback.
iou_threshold = 0.5
score_threshold = 0.01
max_detections = 100

checkpoint_path = os.path.join(MODEL_DIR, model_name + '.h5')

training_callbacks = [
    # Apply the learning-rate schedule at each epoch.
    callbacks.LearningRateScheduler(lr_sched),
    # Save the full model whenever val_loss improves.
    callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=False),
    # Compute detection metrics on val_data using prediction_model.
    RedirectModel(
        Evaluate(val_data,
                 iou_threshold=iou_threshold,
                 score_threshold=score_threshold,
                 max_detections=max_detections,
                 frames_per_batch=fpb,
                 weighted_average=True),
        prediction_model),
]

# NOTE(review): the iterators drop images with too few objects (e.g. "Removing 2
# of 192 images"), so the step counts below, taken from the raw arrays, slightly
# overcount the available samples.
# NOTE(review): the last run failed inside the training generator with
# "could not broadcast (2366) into (1456)". The association-head batch buffer
# appears fixed at 1456 (= 182 * 8) columns while a frame produced
# 182 * 13 = 2366. TODO fix the buffer sizing in the generator.
model.fit_generator(
    train_data,
    steps_per_epoch=X_train.shape[0] // batch_size,
    epochs=n_epoch,
    validation_data=val_data,
    validation_steps=X_test.shape[0] // batch_size,
    callbacks=training_callbacks)

y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 13)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 13)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 13)
regressions.shape:  (1, 3, 83349, 5)
labels.shape:  (1, 3, 83349, 2)
assoc_heads_batch_shape:  (1, 3, 154, 1456)
ann['assoc_head'].shape:  (154, 182, 13)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 27)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 27)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 27)
regressions.shape:  (1, 3, 83349, 5)
labels.shape:  (1, 3, 83349, 2)
assoc_heads_batch_shape:  (1, 3, 154, 1456)
ann['assoc_head'].shape:  (154, 182, 27)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 7)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 6)
y shape:  (154, 182, 1)
y_transform shape:  (154, 182, 7)
regressions.shape:  (1, 3, 83349, 5)
labels.shape:  (1, 3, 83349, 2)
assoc_heads_batch_shape:  (1, 3, 154, 1456)
ann['assoc_head'].shape:  (154, 182, 7)
y shape:  (

ValueError: could not broadcast input array from shape (2366) into shape (1456)

In [None]:
# Debugging: draw one batch from the training iterator so the shapes of the
# model inputs and targets can be inspected in the following cells.
batch_inputs, batch_outputs = train_data.next()

In [None]:
# Shape of every model input in the sampled batch.
for input_array in batch_inputs:
    print(input_array.shape)

In [None]:
# Shape of every training target in the sampled batch.
for target_array in batch_outputs:
    print(target_array.shape)

In [21]:
# Display the first target array of the sampled batch. It is presumably the
# regression targets (last axis of length 5, matching the printed
# regressions.shape); confirm.
batch_outputs[0]

array([[[[ 7.36526184e+01,  5.15165043e+00,  7.74914551e+01,
           3.63908730e+01,  0.00000000e+00],
         [ 5.89738693e+01,  4.60461617e+00,  6.09892578e+01,
           2.83677044e+01,  0.00000000e+00],
         [ 4.73233414e+01,  4.17043495e+00,  4.78914604e+01,
           2.19997158e+01,  0.00000000e+00],
         ...,
         [-1.47747576e+00, -1.42937860e+01, -2.05805826e+00,
          -1.47638836e+01, -1.00000000e+00],
         [-6.56924367e-01, -1.08292360e+01, -2.14923072e+00,
          -1.22338505e+01, -1.00000000e+00],
         [-5.65267634e-03, -8.07942200e+00, -2.22159410e+00,
          -1.02257624e+01, -1.00000000e+00]],

        [[ 2.05805826e+00,  1.61611652e+00,  2.80330086e+00,
           1.87132034e+01,  0.00000000e+00],
         [ 2.14923072e+00,  1.79846120e+00,  1.70923245e+00,
           1.43369303e+01,  0.00000000e+00],
         [ 2.22159410e+00,  1.94318831e+00,  8.40870261e-01,
           1.08634806e+01,  0.00000000e+00],
         ...,
         [-1.097