### References
https://github.com/vanvalenlab/deepcell-tf/blob/master/scripts/feature_pyramids/RetinaNet%20-%20Movie.ipynb

In [1]:
import os
import datetime
import errno
import argparse

import numpy as np

import deepcell

# Load data

In [2]:
from deepcell.utils.data_utils import get_data
from deepcell.utils.tracking_utils import load_trks

DATA_DIR = '/data/training_data/cells/3T3/NIH/movie'
DATA_FILE = os.path.join(DATA_DIR, 'nuclear_movie_3T3_0-2_same.trks')

# Load Information for hardcoded image size training
seed = 1
test_size = .2
train_dict, test_dict = get_data(DATA_FILE, mode='siamese_daughters', seed=seed, test_size=test_size)
X_train, y_train = train_dict['X'], train_dict['y']
X_test, y_test = test_dict['X'], test_dict['y']

print(' -\nX.shape: {}\ny.shape: {}'.format(train_dict['X'].shape, train_dict['y'].shape))

 -
X.shape: (192, 30, 154, 182, 1)
y.shape: (192, 30, 154, 182, 1)


# File Contants

In [3]:
# Set up other required filepaths
PREFIX = os.path.relpath(os.path.dirname(DATA_FILE), DATA_DIR)
ROOT_DIR = '/data' # mounted volume
MODEL_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'models', PREFIX))
LOG_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'logs', PREFIX))

# Loss

In [4]:
def discriminative_instance_loss(y_true, 
                                 y_pred,
                                 delta_v=0.5,
                                 delta_d=1.5,
                                 gamma=1e-3):
    """Discriminative loss between an output tensor and a target tensor.

    Args:
        y_true: A tensor of the same shape as y_pred.
        y_pred: A tensor of the vector embedding

    Returns:
        tensor: Output tensor.
    """

    def temp_norm(ten, axis=None):
        if axis is None:
            axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(ten) - 1
        return K.sqrt(K.epsilon() + K.sum(K.square(ten), axis=axis))

    if K.ndim(y_pred) == 4:
        y_pred_shape = tf.shape(y_pred)
        new_y_pred_shape = [y_pred_shape[0] * y_pred_shape[1],
                            y_pred_shape[2], y_pred_shape[3]]
        y_pred = tf.reshape(y_pred, new_y_pred_shape)
        print("new_y_pred_shape", y_pred.shape)

        y_true_shape = tf.shape(y_true)
        new_y_true_shape = [y_true_shape[0] * y_true_shape[1],
                            y_true_shape[2], y_true_shape[3]]
        y_true = tf.reshape(y_true, new_y_true_shape)
        print("new_y_true_shape", y_true.shape)
    
    # split up the different predicted blobs
    assoc_feature_channel_shape = y_pred_shape[-1] - 5
    boxes = y_pred[:, :, :4]
    assoc_heads = y_pred[:, :, 4:assoc_feature_channel_shape]
    channel_dim = y_pred.shape[-1]
    final_detection_scores = y_pred[:, :, -1]

    # split up the different blobs
    annotations = y_true[:, :, :4]
    labels = K.cast(y_true[:, :, 4:5], dtype='int32')
    width = K.cast(y_true[0, 0, 5], dtype='int32')
    height = K.cast(y_true[0, 0, 6], dtype='int32')
    max_N = K.cast(y_true[0, 0, 7], dtype='int32')
    assoc_heads_target = y_true[:, :, 8:]
    n_detections = y_true_shape[2]
    print("n_detections", n_detections)
    
    print("boxes shape", boxes.shape)
    print("annotations shape", annotations.shape)
    
    # reshape the assoc_heads back to their original size
    assoc_heads_target = K.reshape(assoc_heads_target, (K.shape(assoc_heads_target)[0],
                                            K.shape(assoc_heads_target)[1], max_N))
    assoc_heads = K.reshape(assoc_heads, (K.shape(assoc_heads)[0], K.shape(assoc_heads)[1], 
                                          K.shape(assoc_heads)[2]))
    print("assoc_heads shape", assoc_heads.shape)

   
    # temp = final_detection_scores[0,...,0]
    temp = final_detection_scores[0,...]
    top_vals, top_indices = tf.math.top_k(temp, k=n_detections, sorted=False)

    top_indices_shape = top_indices.get_shape().as_list()
    print("top_indices_shape", top_indices_shape)
    
    top_indices = tf.stack([top_indices, top_indices], axis=-1)
    frames_per_batch = 3
    top_indices = tf.stack([top_indices for l in range(frames_per_batch)], axis=0)
    top_indices_shape = top_indices.get_shape().as_list()
    print("top_indices_shape", top_indices_shape)
    
    filtered_y_pred = tf.gather_nd(assoc_heads, top_indices, batch_dims=0)
    print("gather filtered_y_pred shape", filtered_y_pred.shape)
    y_pred = filtered_y_pred
    
#     for i in range(top_indices_shape[0]):
#         for j in range(top_indices_shape[1]):
#             filtered_y_pred[i, j, :] = K.eval(assoc_heads[i, top_indices[i, j], :])

#     y_pred = tf.convert_to_tensor(filtered_y_pred, dtype='float32')
    print("filtered_y_pred shape", y_pred.shape)
    print("y_true shape", y_true.shape)
     
    rank = K.ndim(y_pred)
    channel_axis = 1 if K.image_data_format() == 'channels_first' else rank - 1
    axes = [x for x in list(range(rank)) if x != channel_axis]

    # Compute variance loss
    cells_summed = tf.tensordot(y_true, y_pred, axes=[axes, axes])
    n_pixels = K.cast(tf.count_nonzero(y_true, axis=axes), dtype=K.floatx()) + K.epsilon()
    n_pixels_expand = K.expand_dims(n_pixels, axis=1) + K.epsilon()
    mu = tf.divide(cells_summed, n_pixels_expand)

    delta_v = K.constant(delta_v, dtype=K.floatx())
    mu_tensor = tf.tensordot(y_true, mu, axes=[[channel_axis], [0]])
    L_var_1 = y_pred - mu_tensor
    L_var_2 = K.square(K.relu(temp_norm(L_var_1) - delta_v))
    L_var_3 = tf.tensordot(L_var_2, y_true, axes=[axes, axes])
    L_var_4 = tf.divide(L_var_3, n_pixels)
    L_var = K.mean(L_var_4)

    # Compute distance loss
    mu_a = K.expand_dims(mu, axis=0)
    mu_b = K.expand_dims(mu, axis=1)

    diff_matrix = tf.subtract(mu_b, mu_a)
    L_dist_1 = temp_norm(diff_matrix)
    L_dist_2 = K.square(K.relu(K.constant(2 * delta_d, dtype=K.floatx()) - L_dist_1))
    diag = K.constant(0, dtype=K.floatx()) * tf.diag_part(L_dist_2)
    L_dist_3 = tf.matrix_set_diag(L_dist_2, diag)
    L_dist = K.mean(L_dist_3)

    # Compute regularization loss
    L_reg = gamma * temp_norm(mu)
    L = L_var + L_dist + K.mean(L_reg)

    return L

# Model Parameters

In [5]:
# Each head of the model uses its own loss
from deepcell.losses import RetinaNetLosses
#from deepcell.losses import discriminative_instance_loss
from tensorflow.keras import losses

sigma = 3.0
alpha = 0.25
gamma = 2.0
iou_threshold = 0.5
max_detections = 100
mask_size = (28, 28)

retinanet_losses = RetinaNetLosses(
    sigma=sigma, alpha=alpha, gamma=gamma,
    iou_threshold=iou_threshold,
    mask_size=mask_size)

loss = {
    'regression': retinanet_losses.regress_loss,
    'classification': retinanet_losses.classification_loss,
#     'association_features': discriminative_instance_loss, #losses.kullback_leibler_divergence,
    'association_features_cat': discriminative_instance_loss, #losses.kullback_leibler_divergence,
    'masks': retinanet_losses.mask_loss,
    'final_detection': retinanet_losses.final_detection_loss,
}

# Create RetinaMask Model

In [6]:
from tensorflow.keras.optimizers import SGD, Adam
from deepcell.utils.train_utils import rate_scheduler

model_name = 'trackrcnn_model'
backbone = 'resnet50'  # vgg16, vgg19, resnet50, densenet121, densenet169, densenet201

n_epoch = 10  # Number of training epochs
lr = 1e-5

optimizer = Adam(lr=lr, clipnorm=0.001)

lr_sched = rate_scheduler(lr=lr, decay=0.99)

batch_size = 1

num_classes = 1  # "object" is the only class

In [7]:
from deepcell.utils.retinanet_anchor_utils import get_anchor_parameters

flat_shape = [y_train.shape[0] * y_train.shape[1]] + list(y_train.shape[2:])
flat_y = np.reshape(y_train, tuple(flat_shape)).astype('int')

# Generate backbone information from the data
backbone_levels, pyramid_levels, anchor_params = get_anchor_parameters(flat_y)

fpb = 3  # number of frames in each training batch

# Instantiate Model

In [8]:
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TrackRCNN models adapted from MaskRCNN and https://github.com/fizyr/keras-maskrcnn"""

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Add, Activation, Flatten, Dense
from tensorflow.python.keras.layers import Input, Concatenate
from tensorflow.python.keras.layers import TimeDistributed, Conv2D, Conv3D
from tensorflow.python.keras.layers import AveragePooling2D, AveragePooling3D
from tensorflow.python.keras.layers import GlobalAveragePooling2D, GlobalAveragePooling3D
from tensorflow.python.keras.layers import MaxPool2D, MaxPool3D, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.initializers import normal

from deepcell.layers import Cast, Shape, UpsampleLike
from deepcell.layers import Upsample, RoiAlign, ConcatenateBoxes
from deepcell.layers import ClipBoxes, RegressBoxes, FilterDetections
from deepcell.layers import TensorProduct, ImageNormalization2D, Location2D
from deepcell.layers import ImageNormalization3D, Location3D
from deepcell.model_zoo.retinanet import retinanet, __build_anchors
from deepcell.utils.retinanet_anchor_utils import AnchorParameters
from deepcell.utils.backbone_utils import get_backbone


def default_mask_model(num_classes,
                       pyramid_feature_size=256,
                       mask_feature_size=256,
                       roi_size=(14, 14),
                       mask_size=(28, 28),
                       name='mask_submodel',
                       mask_dtype=K.floatx(),
                       retinanet_dtype=K.floatx()):
    """Creates the default mask submodel.

    Args:
        num_classes (int): Number of classes to predict a score for at each
            feature level.
        pyramid_feature_size (int): The number of filters to expect from the
            feature pyramid levels.
        mask_feature_size (int): The number of filters to expect from the masks.
        roi_size (tuple): The number of filters to use in the Roi Layers.
        mask_size (tuple): The size of the masks.
        mask_dtype (str): Dtype to use for mask tensors.
        retinanet_dtype (str): Dtype retinanet models expect.
        name (str): The name of the submodel.

    Returns:
        tensorflow.keras.Model: a Model that predicts classes for
            each anchor.
    """
    options = {
        'kernel_size': 3,
        'strides': 1,
        'padding': 'same',
        'kernel_initializer': normal(mean=0.0, stddev=0.01, seed=None),
        'bias_initializer': 'zeros',
        'activation': 'relu',
    }

    inputs = Input(shape=(None, roi_size[0], roi_size[1], pyramid_feature_size))
    outputs = inputs

    # casting to the desidered data type, which may be different than
    # the one used for the underlying keras-retinanet model
    if mask_dtype != retinanet_dtype:
        outputs = TimeDistributed(
            Cast(dtype=mask_dtype),
            name='cast_masks')(outputs)

    for i in range(4):
        outputs = TimeDistributed(Conv2D(
            filters=mask_feature_size,
            **options
        ), name='roi_mask_{}'.format(i))(outputs)

    # perform upsampling + conv instead of deconv as in the paper
    # https://distill.pub/2016/deconv-checkerboard/
    outputs = TimeDistributed(
        Upsample(mask_size),
        name='roi_mask_upsample')(outputs)
    outputs = TimeDistributed(Conv2D(
        filters=mask_feature_size,
        **options
    ), name='roi_mask_features')(outputs)

    outputs = TimeDistributed(Conv2D(
        filters=num_classes,
        kernel_size=1,
        activation='sigmoid'
    ), name='roi_mask')(outputs)

    # casting back to the underlying keras-retinanet model data type
    if mask_dtype != retinanet_dtype:
        outputs = TimeDistributed(
            Cast(dtype=retinanet_dtype),
            name='recast_masks')(outputs)

    return Model(inputs=inputs, outputs=outputs, name=name)


def default_final_detection_model(pyramid_feature_size=256,
                                  final_detection_feature_size=256,
                                  roi_size=(14, 14),
                                  name='final_detection_submodel'):
    options = {
        'kernel_size': 3,
        'strides': 1,
        'padding': 'same',
        'kernel_initializer': normal(mean=0.0, stddev=0.01, seed=None),
        'bias_initializer': 'zeros',
        'activation': 'relu'
    }

    inputs = Input(shape=(None, roi_size[0], roi_size[1], pyramid_feature_size))
    outputs = inputs

    for i in range(2):
        outputs = TimeDistributed(Conv2D(
            filters=final_detection_feature_size,
            **options
        ), name='final_detection_submodel_conv1_block{}'.format(i))(outputs)
        outputs = TimeDistributed(Conv2D(
            filters=final_detection_feature_size,
            **options
        ), name='final_detection_submodel_conv2_block{}'.format(i))(outputs)
        outputs = TimeDistributed(MaxPool2D(
        ), name='final_detection_submodel_pool1_block{}'.format(i))(outputs)

    outputs = TimeDistributed(Conv2D(filters=final_detection_feature_size,
                                     kernel_size=3,
                                     padding='valid',
                                     kernel_initializer=normal(mean=0.0, stddev=0.01, seed=None),
                                     bias_initializer='zeros',
                                     activation='relu'))(outputs)

    outputs = TimeDistributed(Conv2D(filters=1,
                                     kernel_size=1,
                                     activation='sigmoid'))(outputs)

    outputs = Lambda(lambda x: tf.squeeze(x, axis=[2, 3]))(outputs)

    return Model(inputs=inputs, outputs=outputs, name=name)


def default_roi_submodels(num_classes,
                          num_association_features,
                          roi_size=(14, 14),
                          mask_size=(28, 28),
                          frames_per_batch=1,
                          mask_dtype=K.floatx(),
                          retinanet_dtype=K.floatx()):
    """Create a list of default roi submodels.

    The default submodels contains a single mask model.

    Args:
        num_classes (int): Number of classes to use.
        roi_size (tuple): The number of filters to use in the Roi Layers.
        mask_size (tuple): The size of the masks.
        mask_dtype (str): Dtype to use for mask tensors.
        retinanet_dtype (str): Dtype retinanet models expect.

    Returns:
        list: A list of tuple, where the first element is the name of the
            submodel and the second element is the submodel itself.
    """
    if frames_per_batch > 1:
        return [
            ('masks', TimeDistributed(
                default_mask_model(num_classes,
                                   name='mask_submodel_0',
                                   roi_size=roi_size,
                                   mask_size=mask_size,
                                   mask_dtype=mask_dtype,
                                   retinanet_dtype=retinanet_dtype), name='mask_submodel')),
            ('final_detection', TimeDistributed(default_final_detection_model(roi_size=roi_size), 
                                                name = 'final_detection_submodel')),
            ('association_features', TimeDistributed(
                                     association_vector_model(num_association_features,
                                         roi_size=roi_size,
                                         name='assoc_vec_submodel_0',
                                         frames_per_batch=frames_per_batch),
                                     name='assoc_head_submodel'))
        ]
    return [
        ('masks', default_mask_model(num_classes,
                                     roi_size=roi_size,
                                     mask_size=mask_size,
                                     mask_dtype=mask_dtype,
                                     retinanet_dtype=retinanet_dtype))
        # ('final_detection', default_final_detection_model(roi_size=roi_size))
        ]


def association_vector_model(num_association_features,
                             roi_size=(14, 14),
                             pyramid_feature_size=256,
                             frames_per_batch=1,
                             name='assoc_head_submodel'):
    options = {
        'kernel_size': 3,
        'strides': 1,
        'padding': 'same',
        'kernel_initializer': normal(mean=0.0, stddev=0.01, seed=None),
        'bias_initializer': 'zeros',
        'activation': 'relu'
    }

    inputs = Input(shape=(None, roi_size[0], roi_size[1], pyramid_feature_size))
    # inputs = Input(shape=(None, None, None, num_association_features))

    conv1 = TimeDistributed(Conv2D(
        filters=pyramid_feature_size,
        **options
    ), name='association_vector_submodel_conv1')(inputs)
    conv2 = TimeDistributed(Conv2D(
        filters=pyramid_feature_size,
        **options
    ), name='association_vector_submodel_conv2')(conv1)
    x = conv2
    x = TimeDistributed(MaxPool2D(
    ), name='association_vector_submodel_pool1')(conv2)

    # Residuals
    for i in range(2):
        x = TimeDistributed(Conv2D(filters=pyramid_feature_size,
                                   kernel_size=3,
                                   padding='valid',
                                   kernel_initializer=normal(mean=0.0, stddev=0.01, seed=None),
                                   bias_initializer='zeros',
                                   activation='relu', 
                                   name='association_vector_residual_conv1_block{}'.format(i)))(x)
        y = TimeDistributed(Conv2D(filters=pyramid_feature_size,
                                   kernel_size=3,
                                   padding='same',
                                   kernel_initializer=normal(mean=0.0, stddev=0.01, seed=None),
                                   bias_initializer='zeros',
                                   activation='relu',
                                   name='association_vector_residual_conv2_block{}'.format(i)))(x)
#         x = TimeDistributed(Add())([x, y])
#         x = TimeDistributed(Activation('relu', name='association_vector_residual_relu_block{}'.format(i)))(x)     
        x = Add(name='association_vector_residual_add_block{}'.format(i))([x, y])
        x = Activation('relu', name='association_vector_residual_relu_block{}'.format(i))(x)     


    y = TimeDistributed(AveragePooling2D(pool_size=3,
                        name='association_vector_averagepooling'))(x)
#     y = TimeDistributed(Flatten(data_format='channels_last',
#                                 name='association_vector_flatten'))(y)
#     y = TimeDistributed(GlobalAveragePooling2D(data_format='channels_last',
#                               name='association_vector_globalavgpooling'))(x)
#     print("GlobalAveragePooling3D.shape: ", y.shape)
    
    outputs = TimeDistributed(Dense(num_association_features,
                    activation='softmax',
#                     kernel_initializer='he_normal',
                    name='association_vector_dense_output'))(y)
    outputs = Lambda(lambda x: tf.squeeze(x, axis=[2, 3]))(outputs)

    print("outputs.shape", outputs.shape)
        
    return Model(inputs=inputs, outputs=outputs, name=name)


def retinanet_mask(inputs,
                   backbone_dict,
                   num_classes,
                   num_association_features,
                   frames_per_batch=1,
                   backbone_levels=['C3', 'C4', 'C5'],
                   pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'],
                   retinanet_model=None,
                   anchor_params=None,
                   nms=True,
                   panoptic=False,
                   use_assoc_head=False,
                   class_specific_filter=True,
                   crop_size=(14, 14),
                   mask_size=(28, 28),
                   name='retinanet-mask',
                   roi_submodels=None,
                   max_detections=100,
                   score_threshold=0.05,
                   nms_threshold=0.5,
                   mask_dtype=K.floatx(),
                   **kwargs):
    """Construct a RetinaNet mask model on top of a retinanet bbox model.
    Uses the retinanet bbox model and appends layers to compute masks.

    Args:
        inputs (tensor): List of tensorflow.keras.layers.Input.
            The first input is the image, the second input the blob of masks.
        num_classes (int): Integer, number of classes to classify.
        retinanet_model (tensorflow.keras.Model): RetinaNet model that predicts
            regression and classification values.
        anchor_params (AnchorParameters): Struct containing anchor parameters.
        nms (bool): Whether to use NMS.
        class_specific_filter (bool): Use class specific filtering.
        roi_submodels (list): Submodels for processing ROIs.
        name (str): Name of the model.
        mask_dtype (str): Dtype to use for mask tensors.
        kwargs (dict): Additional kwargs to pass to the retinanet bbox model.

    Returns:
        tensorflow.keras.Model: Model with inputs as input and as output
            the output of each submodel for each pyramid level and the
            detections. The order is as defined in submodels.

            ```
            [
                regression, classification, other[0], ...,
                boxes_masks, boxes, scores, labels, masks, other[0], ...
            ]
            ```

    """
    if anchor_params is None:
        anchor_params = AnchorParameters.default

    if roi_submodels is None:
        retinanet_dtype = K.floatx()
        K.set_floatx(mask_dtype)
        roi_submodels = default_roi_submodels(
            num_classes, num_association_features, crop_size, mask_size,
            frames_per_batch, mask_dtype, retinanet_dtype)
        K.set_floatx(retinanet_dtype)

    image = inputs
    image_shape = Shape()(image)

    if retinanet_model is None:
        retinanet_model = retinanet(
            inputs=image,
            backbone_dict=backbone_dict,
            num_classes=num_classes,
            backbone_levels=backbone_levels,
            pyramid_levels=pyramid_levels,
            panoptic=panoptic,
            num_anchors=anchor_params.num_anchors(),
            frames_per_batch=frames_per_batch,
            **kwargs
        )

    # parse outputs
    regression = retinanet_model.outputs[0]
    classification = retinanet_model.outputs[1]

    if panoptic:
        # Determine the number of semantic heads
        n_semantic_heads = len([1 for layer in retinanet_model.layers if 'semantic' in layer.name])

        # The  panoptic output should not be sent to filter detections
        other = retinanet_model.outputs[2:-n_semantic_heads]
        semantic = retinanet_model.outputs[-n_semantic_heads:]
    else:
        other = retinanet_model.outputs[2:]


    features = [retinanet_model.get_layer(name).output
                for name in pyramid_levels]

    # build boxes
    anchors = __build_anchors(anchor_params, features,
                              frames_per_batch=frames_per_batch)
    boxes = RegressBoxes(name='boxes')([anchors, regression])
    boxes = ClipBoxes(name='clipped_boxes')([image, boxes])

    # filter detections (apply NMS / score threshold / select top-k)
    detections = FilterDetections(
        nms=nms,
        nms_threshold=nms_threshold,
        score_threshold=score_threshold,
        class_specific_filter=class_specific_filter,
        max_detections=max_detections,
        name='filtered_detections'
    )([boxes, classification] + other)

    # split up in known outputs and "other"
    boxes = detections[0]
    scores = detections[1]

    # get the region of interest features
    #
    # roi_input = [image_shape, boxes, classification] + features
    # rois = _RoiAlign(crop_size=crop_size)(roi_input)

    fpn = features[0]
    fpn = UpsampleLike()([fpn, image])
    rois = RoiAlign(crop_size=crop_size)([boxes, fpn])

    # execute trackrcnn submodels
    trackrcnn_outputs = [submodel(rois) for _, submodel in roi_submodels]
    print("trackrcnn_outputs:")
    for x in trackrcnn_outputs:
        print(x.name, x.shape, x.dtype)

    # concatenate boxes for loss computation
    trainable_outputs = [ConcatenateBoxes(name=name)([boxes, output])
                         for (name, _), output in zip(
                             roi_submodels, trackrcnn_outputs)]
    
    trainable_outputs[-1] = Concatenate(name='association_features_cat')([trainable_outputs[-1], trainable_outputs[-2]])


    print("roi_submodels names in roi_submodels:")
    for (name, _) in roi_submodels:
        print(name)
    
    print("trainable_outputs:")
    for x in trainable_outputs:
        print(x.name, x.shape, x.dtype)


    outputs = [regression, classification] + other + trainable_outputs + \
        detections + trackrcnn_outputs

    if panoptic:
        print("is panoptic")
        outputs += list(semantic)

    print("outputs:", outputs)

    model = Model(inputs=inputs, outputs=outputs, name=name)
    model.backbone_levels = backbone_levels
    model.pyramid_levels = pyramid_levels

    return model

def RetinaMask(backbone,
               num_classes,
               input_shape,
               inputs=None,
               backbone_levels=['C3', 'C4', 'C5'],
               pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'],
               norm_method='whole_image',
               location=False,
               use_imagenet=False,
               crop_size=(14, 14),
               pooling=None,
               mask_dtype=K.floatx(),
               required_channels=3,
               frames_per_batch=1,
               use_assoc_head=False,
               num_association_features=2,
               **kwargs):
    """Constructs a mrcnn model using a backbone from keras-applications.

    Args:
        backbone (str): Name of backbone to use.
        num_classes (int): Number of classes to classify.
        input_shape (tuple): The shape of the input data.
        weights (str): one of None (random initialization),
            'imagenet' (pre-training on ImageNet),
            or the path to the weights file to be loaded.
        pooling (str): optional pooling mode for feature extraction
            when include_top is False.
            - None means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - 'avg' means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - 'max' means that global max pooling will
                be applied.
        required_channels (int): The required number of channels of the
            backbone.  3 is the default for all current backbones.

    Returns:
        tensorflow.keras.Model: RetinaNet model with a backbone.
    """
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    if inputs is None:
        if frames_per_batch > 1:
            if channel_axis == 1:
                input_shape_with_time = tuple(
                    [input_shape[0], frames_per_batch] + list(input_shape)[1:])
            else:
                input_shape_with_time = tuple(
                    [frames_per_batch] + list(input_shape))
            inputs = Input(shape=input_shape_with_time, name='image_input')
        else:
            inputs = Input(shape=input_shape, name='image_input')

    if location:
        if frames_per_batch > 1:
            # TODO: TimeDistributed is incompatible with channels_first
            loc = TimeDistributed(Location2D(in_shape=input_shape))(inputs)
        else:
            loc = Location2D(in_shape=input_shape)(inputs)
        concat = Concatenate(axis=channel_axis)([inputs, loc])
    else:
        concat = inputs

    # force the channel size for backbone input to be `required_channels`
    if frames_per_batch > 1:
        norm = TimeDistributed(ImageNormalization2D(norm_method=norm_method))(concat)
        fixed_inputs = TimeDistributed(TensorProduct(required_channels))(norm)
    else:
        norm = ImageNormalization2D(norm_method=norm_method)(concat)
        fixed_inputs = TensorProduct(required_channels)(norm)

    # force the input shape
    axis = 0 if K.image_data_format() == 'channels_first' else -1
    fixed_input_shape = list(input_shape)
    fixed_input_shape[axis] = required_channels
    fixed_input_shape = tuple(fixed_input_shape)

    model_kwargs = {
        'include_top': False,
        'weights': None,
        'input_shape': fixed_input_shape,
        'pooling': pooling
    }

    _, backbone_dict = get_backbone(backbone, fixed_inputs,
                                    use_imagenet=use_imagenet,
                                    frames_per_batch=frames_per_batch,
                                    return_dict=True, **model_kwargs)

    # create the full model
    return retinanet_mask(
        inputs=inputs,
        num_classes=num_classes,
        backbone_dict=backbone_dict,
        crop_size=crop_size,
        backbone_levels=backbone_levels,
        pyramid_levels=pyramid_levels,
        name='{}_retinanet_mask'.format(backbone),
        mask_dtype=mask_dtype,
        frames_per_batch=frames_per_batch,
        use_assoc_head=use_assoc_head,
        num_association_features=num_association_features,
        **kwargs)

In [9]:
from deepcell import model_zoo

# Pass frames_per_batch > 1 to enable 3D mode!
model = RetinaMask(
    backbone=backbone,
    input_shape=X_train.shape[2:],
    frames_per_batch=fpb,
    class_specific_filter=False,
    panoptic=False,
    num_classes=num_classes,
    backbone_levels=backbone_levels,
    pyramid_levels=pyramid_levels,
    anchor_params=anchor_params,
    use_assoc_head=True,
)

prediction_model = model

W0329 00:31:46.013994 140505964918592 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0329 00:32:02.784235 140505964918592 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/initializers.py:143: calling RandomNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


outputs.shape (?, ?, 2)


W0329 00:32:14.947195 140505964918592 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py:255: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0329 00:32:16.302996 140505964918592 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py:255: calling crop_and_resize_v1 (from tensorflow.python.ops.image_ops_impl) with box_ind is deprecated and will be removed in a future version.
Instructions for updating:
box_ind is deprecated, use box_indices instead


trackrcnn_outputs:
mask_submodel/Reshape_1:0 (?, ?, ?, 28, 28, 1) <dtype: 'float32'>
final_detection_submodel/Reshape_1:0 (?, ?, ?, 1) <dtype: 'float32'>
assoc_head_submodel/Reshape_1:0 (?, ?, ?, 2) <dtype: 'float32'>
roi_submodels names in roi_submodels:
masks
final_detection
association_features
trainable_outputs:
masks/concat:0 (?, ?, 100, ?) <dtype: 'float32'>
final_detection/concat:0 (?, ?, 100, ?) <dtype: 'float32'>
association_features_cat/concat:0 (?, ?, 100, ?) <dtype: 'float32'>
outputs: [<tf.Tensor 'regression/concat:0' shape=(?, 3, ?, 4) dtype=float32>, <tf.Tensor 'classification/concat:0' shape=(?, 3, ?, 1) dtype=float32>, <tf.Tensor 'masks/concat:0' shape=(?, ?, 100, ?) dtype=float32>, <tf.Tensor 'final_detection/concat:0' shape=(?, ?, 100, ?) dtype=float32>, <tf.Tensor 'association_features_cat/concat:0' shape=(?, ?, 100, ?) dtype=float32>, <tf.Tensor 'filtered_detections/Reshape_2:0' shape=(?, ?, 100, 4) dtype=float32>, <tf.Tensor 'filtered_detections/Reshape_3:0' shape

# model.summary()

In [10]:
model.compile(loss=loss, optimizer=optimizer)

W0329 00:32:18.017716 140505964918592 training_utils.py:1101] Output filtered_detections missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to filtered_detections.
W0329 00:32:18.019028 140505964918592 training_utils.py:1101] Output filtered_detections_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to filtered_detections_1.
W0329 00:32:18.019932 140505964918592 training_utils.py:1101] Output filtered_detections_2 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to filtered_detections_2.
W0329 00:32:18.020729 140505964918592 training_utils.py:1101] Output mask_submodel missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to mask_submodel.
W0329 00

new_y_pred_shape (?, ?, ?)
new_y_true_shape (?, ?, ?)
n_detections Tensor("loss/association_features_cat_loss/strided_slice_18:0", shape=(), dtype=int32)
boxes shape (?, ?, ?)
annotations shape (?, ?, ?)
assoc_heads shape (?, ?, ?)
top_indices_shape [None]
top_indices_shape [3, None, 2]
gather filtered_y_pred shape (3, ?, ?)
filtered_y_pred shape (3, ?, ?)
y_true shape (?, ?, ?)


# Train RetinaMask Model

### Training Parameters

In [15]:
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image generators for training convolutional neural networks."""

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os

import numpy as np

from skimage.measure import regionprops
from skimage.segmentation import clear_border

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.preprocessing.image import array_to_img
from tensorflow.python.keras.preprocessing.image import Iterator
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.keras.utils import to_categorical

from deepcell.utils.retinanet_anchor_utils import anchor_targets_bbox
from deepcell.utils.retinanet_anchor_utils import anchors_for_shape
from deepcell.utils.retinanet_anchor_utils import guess_shapes

from deepcell.image_generators import _transform_masks
from deepcell.image_generators import ImageFullyConvDataGenerator
from deepcell.image_generators import MovieDataGenerator

class RetinaNetGenerator(ImageFullyConvDataGenerator):
    """Generates batches of tensor image data with real-time data augmentation.
    The data will be looped over (in batches).

    Args:
        featurewise_center: boolean, set input mean to 0 over the dataset,
            feature-wise.
        samplewise_center: boolean, set each sample mean to 0.
        featurewise_std_normalization: boolean, divide inputs by std
            of the dataset, feature-wise.
        samplewise_std_normalization: boolean, divide each input by its std.
        zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
        zca_whitening: boolean, apply ZCA whitening.
        rotation_range: int, degree range for random rotations.
        width_shift_range: float, 1-D array-like or int
            float: fraction of total width, if < 1, or pixels if >= 1.
            1-D array-like: random elements from the array.
            int: integer number of pixels from interval
                (-width_shift_range, +width_shift_range)
            With width_shift_range=2 possible values are ints [-1, 0, +1],
            same as with width_shift_range=[-1, 0, +1],
            while with width_shift_range=1.0 possible values are floats in
            the interval [-1.0, +1.0).
        shear_range: float, shear Intensity
            (Shear angle in counter-clockwise direction in degrees)
        zoom_range: float or [lower, upper], Range for random zoom.
            If a float, [lower, upper] = [1-zoom_range, 1+zoom_range].
        channel_shift_range: float, range for random channel shifts.
        fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
            Default is 'nearest'. Points outside the boundaries of the input
            are filled according to the given mode:
                'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
                'nearest':  aaaaaaaa|abcd|dddddddd
                'reflect':  abcddcba|abcd|dcbaabcd
                'wrap':  abcdabcd|abcd|abcdabcd
        cval: float or int, value used for points outside the boundaries
            when fill_mode = "constant".
        horizontal_flip: boolean, randomly flip inputs horizontally.
        vertical_flip: boolean, randomly flip inputs vertically.
        rescale: rescaling factor. Defaults to None. If None or 0, no rescaling
            is applied, otherwise we multiply the data by the value provided
            (before applying any other transformation).
        preprocessing_function: function that will be implied on each input.
            The function will run after the image is resized and augmented.
            The function should take one argument:
            one image (Numpy tensor with rank 3),
            and should output a Numpy tensor with the same shape.
        data_format: One of {"channels_first", "channels_last"}.
            "channels_last" mode means that the images should have shape
                (samples, height, width, channels),
            "channels_first" mode means that the images should have shape
                (samples, channels, height, width).
            It defaults to the image_data_format value found in your
                Keras config file at "~/.keras/keras.json".
            If you never set it, then it will be "channels_last".
        validation_split: float, fraction of images reserved for validation
            (strictly between 0 and 1).
    """

    def flow(self,
             train_dict,
             compute_shapes=guess_shapes,
             min_objects=3,
             num_classes=1,
             clear_borders=False,
             include_masks=False,
             include_final_detection_layer=False,
             panoptic=False,
             transforms=['watershed'],
             transforms_kwargs={},
             assoc_head=False,
             anchor_params=None,
             pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'],
             batch_size=32,
             shuffle=False,
             seed=None,
             save_to_dir=None,
             save_prefix='',
             save_format='png'):
        """Generates batches of augmented/normalized data with given arrays.

        Args:
            train_dict: dictionary of X and y tensors. Both should be rank 4.
            compute_shapes: function to determine the shapes of the anchors
            min_classes: images with fewer than 'min_objects' are ignored
            num_classes: number of classes to predict
            clear_borders: boolean, whether to use clear_border on y.
            include_masks: boolean, train on mask data (MaskRCNN).
            batch_size: int (default: 1).
            shuffle: boolean (default: True).
            seed: int (default: None).
            save_to_dir: None or str (default: None).
                This allows you to optionally specify a directory
                to which to save the augmented pictures being generated
                (useful for visualizing what you are doing).
            save_prefix: str (default: ""). Prefix to use for filenames of
                saved pictures (only relevant if save_to_dir is set).
            save_format: one of "png", "jpeg". Default: "png".
                (only relevant if save_to_dir is set)

        Returns:
            An Iterator yielding tuples of (x, y) where x is a numpy array
            of image data and y is a numpy array of labels of the same shape.
        """
        return RetinaNetIterator(
            train_dict,
            self,
            compute_shapes=compute_shapes,
            min_objects=min_objects,
            num_classes=num_classes,
            clear_borders=clear_borders,
            include_masks=include_masks,
            include_final_detection_layer=include_final_detection_layer,
            panoptic=panoptic,
            transforms=transforms,
            transforms_kwargs=transforms_kwargs,
            assoc_head=assoc_head,
            anchor_params=anchor_params,
            pyramid_levels=pyramid_levels,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            data_format=self.data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format)


class RetinaNetIterator(Iterator):
    """Iterator yielding data from Numpy arrayss (X and y).

    Adapted from https://github.com/fizyr/keras-retinanet.

    Args:
        train_dict: dictionary consisting of numpy arrays for X and y.
        image_data_generator: Instance of ImageDataGenerator
            to use for random transformations and normalization.
        compute_shapes: functor for generating shapes, based on the model.
        min_objects: Integer, image with fewer than min_objects are ignored.
        num_classes: Integer, number of classes for classification.
        clear_borders: Boolean, whether to call clear_border on y.
        include_masks: Boolean, whether to yield mask data.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        data_format: String, one of 'channels_first', 'channels_last'.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if save_to_dir is set).
        save_format: Format to use for saving sample images
            (if save_to_dir is set).
    """

    def __init__(self,
                 train_dict,
                 image_data_generator,
                 compute_shapes=guess_shapes,
                 anchor_params=None,
                 pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'],
                 min_objects=3,
                 num_classes=1,
                 clear_borders=False,
                 include_masks=False,
                 panoptic=False,
                 include_final_detection_layer=False,
                 transforms=['watershed'],
                 transforms_kwargs={},
                 assoc_head=False,
                 batch_size=32,
                 shuffle=False,
                 seed=None,
                 data_format='channels_last',
                 save_to_dir=None,
                 save_prefix='',
                 save_format='png'):
        X, y = train_dict['X'], train_dict['y']

        if X.shape[0] != y.shape[0]:
            raise ValueError('Training batches and labels should have the same'
                             'length. Found X.shape: {} y.shape: {}'.format(
                                 X.shape, y.shape))

        if X.ndim != 4:
            raise ValueError('Input data in `RetinaNetIterator` '
                             'should have rank 4. You passed an array '
                             'with shape', X.shape)

        self.x = np.asarray(X, dtype=K.floatx())
        self.y = np.asarray(y, dtype='int32')

        # `compute_shapes` changes based on the model backbone.
        self.compute_shapes = compute_shapes
        self.anchor_params = anchor_params
        self.pyramid_levels = [int(l[1:]) for l in pyramid_levels]
        self.min_objects = min_objects
        self.num_classes = num_classes
        self.include_masks = include_masks
        self.include_final_detection_layer = include_final_detection_layer
        self.panoptic = panoptic
        self.transforms = transforms
        self.transforms_kwargs = transforms_kwargs
        self.assoc_head = assoc_head
        self.channel_axis = 3 if data_format == 'channels_last' else 1
        self.image_data_generator = image_data_generator
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format

        self.y_semantic_list = []  # optional semantic segmentation targets

        # Add semantic segmentation targets if panoptic segmentation
        # flag is True
        if panoptic:
            # Create a list of all the semantic targets. We need to be able
            # to have multiple semantic heads
            # Add all the keys that contain y_semantic
            for key in train_dict:
                if 'y_semantic' in key:
                    self.y_semantic_list.append(train_dict[key])

            # Add transformed masks
            for transform in transforms:
                transform_kwargs = transforms_kwargs.get(transform, dict())
                y_transform = _transform_masks(y, transform,
                                               data_format=data_format,
                                               **transform_kwargs)
                y_transform = np.asarray(y_transform, dtype='int32')
                self.y_semantic_list.append(y_transform)

        invalid_batches = []
        # Remove images with small numbers of cells
        for b in range(self.x.shape[0]):
            y_batch = np.squeeze(self.y[b], axis=self.channel_axis - 1)
            y_batch = clear_border(y_batch) if clear_borders else y_batch
            y_batch = np.expand_dims(y_batch, axis=self.channel_axis - 1)

            self.y[b] = y_batch

            if len(np.unique(self.y[b])) - 1 < self.min_objects:
                invalid_batches.append(b)

        invalid_batches = np.array(invalid_batches, dtype='int')

        if invalid_batches.size > 0:
            logging.warning('Removing %s of %s images with fewer than %s '
                            'objects.', invalid_batches.size, self.x.shape[0],
                            self.min_objects)

        self.y = np.delete(self.y, invalid_batches, axis=0)
        self.x = np.delete(self.x, invalid_batches, axis=0)

        self.y_semantic_list = [np.delete(y, invalid_batches, axis=0)
                                for y in self.y_semantic_list]

        super(RetinaNetIterator, self).__init__(
            self.x.shape[0], batch_size, shuffle, seed)

    def filter_annotations(self, image, annotations):
        """Filter annotations by removing those that are outside of the
        image bounds or whose width/height < 0.

        Args:
            image: ndarray, the raw image data.
            annotations: dict of annotations including labels and bboxes
        """
        row_axis = 1 if self.data_format == 'channels_first' else 0
        invalid_indices = np.where(
            (annotations['bboxes'][:, 2] <= annotations['bboxes'][:, 0]) |
            (annotations['bboxes'][:, 3] <= annotations['bboxes'][:, 1]) |
            (annotations['bboxes'][:, 0] < 0) |
            (annotations['bboxes'][:, 1] < 0) |
            (annotations['bboxes'][:, 2] > image.shape[row_axis + 1]) |
            (annotations['bboxes'][:, 3] > image.shape[row_axis])
        )[0]

        # delete invalid indices
        if invalid_indices.size > 0:
            logging.warn('Image with shape {} contains the following invalid '
                         'boxes: {}.'.format(
                             image.shape,
                             annotations['bboxes'][invalid_indices, :]))

            for k in annotations.keys():
                filtered = np.delete(annotations[k], invalid_indices, axis=0)
                annotations[k] = filtered
        return annotations

    def load_annotations(self, y):
        """Generate bounding box and label annotations for a tensor

        Args:
            y: tensor to annotate

        Returns:
            dict: annotations of bboxes and labels
        """
        labels, bboxes, masks = [], [], []
        for prop in regionprops(np.squeeze(y.astype('int'))):
            y1, x1, y2, x2 = prop.bbox
            bboxes.append([x1, y1, x2, y2])
            labels.append(0)  # boolean object detection
            masks.append(np.where(y == prop.label, 1, 0))

        labels = np.array(labels)
        bboxes = np.array(bboxes)
        masks = np.array(masks).astype('uint8')

        # reshape bboxes in case it is empty.
        bboxes = np.reshape(bboxes, (bboxes.shape[0], 4))

        annotations = {'labels': labels, 'bboxes': bboxes}
        if self.include_masks:
            annotations['masks'] = masks

        annotations = self.filter_annotations(y, annotations)
        return annotations

    def _get_batches_of_transformed_samples(self, index_array):
        batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]))

        batch_y_semantic_list = []
        for y_sem in self.y_semantic_list:
            shape = tuple([len(index_array)] + list(y_sem.shape[1:]))
            batch_y_semantic_list.append(np.zeros(shape, dtype=y_sem.dtype))

        annotations_list = []

        max_shape = []

        for i, j in enumerate(index_array):
            x = self.x[j]
            y = self.y[j]

            y_semantic_list = [y_sem[j] for y_sem in self.y_semantic_list]

            # Apply transformation
            x, y_list = self.image_data_generator.random_transform(
                x, [y] + y_semantic_list)

            y = y_list[0]
            y_semantic_list = y_list[1:]

            # Find max shape of image data.  Used for masking.
            if not max_shape:
                max_shape = list(x.shape)
            else:
                for k in range(len(x.shape)):
                    if x.shape[k] > max_shape[k]:
                        max_shape[k] = x.shape[k]

            # Get the bounding boxes from the transformed masks!
            annotations = self.load_annotations(y)
            annotations_list.append(annotations)

            x = self.image_data_generator.standardize(x)

            batch_x[i] = x

            for k, y_sem in enumerate(y_semantic_list):
                batch_y_semantic_list[k][i] = y_sem

        anchors = anchors_for_shape(
            batch_x.shape[1:],
            pyramid_levels=self.pyramid_levels,
            anchor_params=self.anchor_params,
            shapes_callback=self.compute_shapes)

        regressions, labels = anchor_targets_bbox(
            anchors,
            batch_x,
            annotations_list,
            self.num_classes)

        max_shape = tuple(max_shape)  # was a list for max shape indexing

        print("annotations_list: ", annotations_list)          

        if self.include_masks:
            # masks_batch has shape: (batch size, max_annotations,
            #     bbox_x1 + bbox_y1 + bbox_x2 + bbox_y2 + label +
            #     width + height + max_image_dimension)
            max_annotations = max(len(a['masks']) for a in annotations_list)
            masks_batch_shape = (len(index_array), max_annotations,
                                 5 + 2 + max_shape[0] * max_shape[1])
            masks_batch = np.zeros(masks_batch_shape, dtype=K.floatx())

            for i, ann in enumerate(annotations_list):
                masks_batch[i, :ann['bboxes'].shape[0], :4] = ann['bboxes']
                masks_batch[i, :ann['labels'].shape[0], 4] = ann['labels']
                masks_batch[i, :, 5] = max_shape[1]  # width
                masks_batch[i, :, 6] = max_shape[0]  # height

                # add flattened mask
                for j, mask in enumerate(ann['masks']):
                    masks_batch[i, j, 7:] = mask.flatten()

        if self.save_to_dir:
            for i, j in enumerate(index_array):
                if self.data_format == 'channels_first':
                    img_x = np.expand_dims(batch_x[i, 0, ...], 0)
                else:
                    img_x = np.expand_dims(batch_x[i, ..., 0], -1)
                img = array_to_img(img_x, self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(1e4),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))

        batch_inputs = batch_x
        batch_outputs = [regressions, labels]
        
        if self.assoc_head:
            batch_outputs.append(masks_batch)
        if self.include_masks:
            batch_outputs.append(masks_batch)
        if self.include_final_detection_layer:
            batch_outputs.append(masks_batch)

        batch_outputs.extend(batch_y_semantic_list)

        print("batch_inputs: ", batch_inputs)
        print("batch_outputs: ", batch_outputs)

        return batch_inputs, batch_outputs


    def next(self):
        """For python 2.x. Returns the next batch.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)


    
class RetinaMovieIterator(Iterator):
    """Iterator yielding data from Numpy arrayss (`X and `y`).

    Adapted from https://github.com/fizyr/keras-retinanet.

    Args:
        train_dict: dictionary consisting of numpy arrays for `X` and `y`.
        image_data_generator: Instance of `ImageDataGenerator`
            to use for random transformations and normalization.
        compute_shapes: functor for generating shapes, based on the model.
        min_objects: Integer, image with fewer than `min_objects` are ignored.
        num_classes: Integer, number of classes for classification.
        clear_borders: Boolean, whether to call `clear_border` on `y`.
        include_masks: Boolean, whether to yield mask data.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
    """

    def __init__(self,
                 train_dict,
                 movie_data_generator,
                 compute_shapes=guess_shapes,
                 anchor_params=None,
                 pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'],
                 min_objects=3,
                 num_classes=1,
                 frames_per_batch=2,
                 clear_borders=False,
                 include_masks=False,
                 include_final_detection_layer=False,
                 assoc_head=False,
                 panoptic=False,
                 transforms=['watershed'],
                 transforms_kwargs={},
                 batch_size=32,
                 shuffle=False,
                 seed=None,
                 data_format='channels_last',
                 save_to_dir=None,
                 save_prefix='',
                 save_format='png'):
        X, y = train_dict['X'], train_dict['y']

        if X.shape[0] != y.shape[0]:
            raise ValueError('Training batches and labels should have the same'
                             'length. Found X.shape: {} y.shape: {}'.format(
                                 X.shape, y.shape))

        if X.ndim != 5:
            raise ValueError('Input data in `RetinaNetIterator` '
                             'should have rank 5. You passed an array '
                             'with shape', X.shape)

        self.x = np.asarray(X, dtype=K.floatx())
        self.y = np.asarray(y, dtype='int32')

        # `compute_shapes` changes based on the model backbone.
        self.compute_shapes = compute_shapes
        self.anchor_params = anchor_params
        self.pyramid_levels = [int(l[1:]) for l in pyramid_levels]
        self.min_objects = min_objects
        self.num_classes = num_classes
        self.frames_per_batch = frames_per_batch
        self.include_masks = include_masks
        self.include_final_detection_layer = include_final_detection_layer
        self.assoc_head = assoc_head
        self.panoptic = panoptic
        self.transforms = transforms
        self.transforms_kwargs = transforms_kwargs
        self.channel_axis = 4 if data_format == 'channels_last' else 1
        self.time_axis = 1 if data_format == 'channels_last' else 2
        self.row_axis = 2 if data_format == 'channels_last' else 3
        self.col_axis = 3 if data_format == 'channels_last' else 4
        self.movie_data_generator = movie_data_generator
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format

        self.y_semantic_list = []  # optional semantic segmentation targets

        if X.shape[self.time_axis] - frames_per_batch < 0:
            raise ValueError(
                'The number of frames used in each training batch should '
                'be less than the number of frames in the training data!')

        # Add semantic segmentation targets if panoptic segmentation
        # flag is True
        print("train_dict keys :")
        for key in train_dict:
            print(key)
        if panoptic:
            # Create a list of all the semantic targets. We need to be able
            # to have multiple semantic heads
            # Add all the keys that contain y_semantic
            for key in train_dict:
                if 'y_semantic' in key:
                    self.y_semantic_list.append(train_dict[key])

            # Add transformed masks
            for transform in transforms:
                transform_kwargs = transforms_kwargs.get(transform, dict())
                y_transforms = []
                for time in range(y.shape[self.time_axis]):
                    if data_format == 'channels_first':
                        y_temp = y[:, :, time, ...]
                    else:
                        y_temp = y[:, time, ...]
                    y_temp_transform = _transform_masks(
                        y_temp, transform,
                        data_format=data_format,
                        **transform_kwargs)
                    y_temp_transform = np.asarray(y_temp_transform, dtype='int32')
                    y_transforms.append(y_temp_transform)

                y_transform = np.stack(y_transforms, axis=self.time_axis)
                self.y_semantic_list.append(y_transform)

        invalid_batches = []
        # Remove images with small numbers of cells
        for b in range(self.x.shape[0]):
            y_batch = np.squeeze(self.y[b], axis=self.channel_axis - 1)
            y_batch = clear_border(y_batch) if clear_borders else y_batch
            y_batch = np.expand_dims(y_batch, axis=self.channel_axis - 1)

            self.y[b] = y_batch

            if len(np.unique(self.y[b])) - 1 < self.min_objects:
                invalid_batches.append(b)

        invalid_batches = np.array(invalid_batches, dtype='int')

        if invalid_batches.size > 0:
            logging.warning('Removing %s of %s images with fewer than %s '
                            'objects.', invalid_batches.size, self.x.shape[0],
                            self.min_objects)

        self.y = np.delete(self.y, invalid_batches, axis=0)
        self.x = np.delete(self.x, invalid_batches, axis=0)

        self.y_semantic_list = [np.delete(y, invalid_batches, axis=0)
                                for y in self.y_semantic_list]
        
        super(RetinaMovieIterator, self).__init__(
            self.x.shape[0], batch_size, shuffle, seed)

    def filter_annotations(self, image, annotations):
        """Filter annotations by removing those that are outside of the
        image bounds or whose width/height < 0.

        Args:
            image: ndarray, the raw image data.
            annotations: dict of annotations including `labels` and `bboxes`
        """
        row_axis = 1 if self.data_format == 'channels_first' else 0
        invalid_indices = np.where(
            (annotations['bboxes'][:, 2] <= annotations['bboxes'][:, 0]) |
            (annotations['bboxes'][:, 3] <= annotations['bboxes'][:, 1]) |
            (annotations['bboxes'][:, 0] < 0) |
            (annotations['bboxes'][:, 1] < 0) |
            (annotations['bboxes'][:, 2] > image.shape[row_axis + 1]) |
            (annotations['bboxes'][:, 3] > image.shape[row_axis])
        )[0]

        # delete invalid indices
        if invalid_indices.size > 0:
            logging.warn('Image with shape {} contains the following invalid '
                         'boxes: {}.'.format(
                             image.shape,
                             annotations['bboxes'][invalid_indices, :]))

            for k in annotations.keys():
                filtered = np.delete(annotations[k], invalid_indices, axis=0)
                annotations[k] = filtered
        return annotations

    def load_annotations(self, y):
        """Generate bounding box and label annotations for a tensor

        Args:
            y: tensor to annotate

        Returns:
            annotations: dict of `bboxes` and `labels`
        """
        labels, bboxes, masks = [], [], []
        channel_axis = 1 if self.data_format == 'channels_first' else -1

        max_width = 0
        max_height = 0
        
        for prop in regionprops(np.squeeze(y.astype('int'))):
            y1, x1, y2, x2 = prop.bbox
            if x2-x1 > max_width:
                max_width = x2-x1
            if y2-y1 > max_height:
                max_height = y2-y1
            bboxes.append([x1, y1, x2, y2])
            labels.append(0)  # boolean object detection
            masks.append(np.where(y == prop.label, 1, 0))
            
        labels = np.array(labels)
        bboxes = np.array(bboxes)
        masks = np.array(masks).astype('uint8')

        # reshape bboxes in case it is empty.
        bboxes = np.reshape(bboxes, (bboxes.shape[0], 4))

        annotations = {'labels': labels, 'bboxes': bboxes}

        if self.include_masks:
            annotations['masks'] = masks


        if self.assoc_head:
            print("self.num_unique", self.num_unique)
            print("y.squeeze(channel_axis) shape"m y.squeeze(channel_axis).shape)
            y_transform = to_categorical(y.squeeze(channel_axis), num_classes=self.num_unique)
            if self.data_format == 'channels_first':
                y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)
            N = y_transform.shape[-1]
            # print("y_transform.shape", y_transform.shape)
            
            assoc_head = np.zeros((bboxes.shape[0], N))
            # print("bboxes.shape[0]", bboxes.shape[0])
            for i in range(bboxes.shape[0]):
                x1, y1, x2, y2 = bboxes[i]
                assoc_head[i] = y_transform[int((y1+y2)/2), int((x1+x2)/2)]
                
            annotations['assoc_head'] = assoc_head

        annotations = self.filter_annotations(y, annotations)
        return annotations


    def _get_batches_of_transformed_samples(self, index_array):
        if self.data_format == 'channels_first':
            batch_x = np.zeros((len(index_array),
                                self.x.shape[1],
                                self.frames_per_batch,
                                self.x.shape[3],
                                self.x.shape[4]))
        else:
            batch_x = np.zeros(tuple([len(index_array), self.frames_per_batch] +
                                     list(self.x.shape)[2:]))

        if self.panoptic:
            if self.data_format == 'channels_first':
                batch_y_semantic_list = [np.zeros(tuple([len(index_array),
                                                         y_semantic.shape[1],
                                                         self.frames_per_batch,
                                                         y_semantic.shape[3],
                                                         y_semantic.shape[4]]))
                                         for y_semantic in self.y_semantic_list]
            else:
                batch_y_semantic_list = [
                    np.zeros(tuple([len(index_array), self.frames_per_batch] +
                                   list(y_semantic.shape[2:])))
                    for y_semantic in self.y_semantic_list
                ]

        annotations_list = [[] for _ in range(self.frames_per_batch)]
        self.num_unique = len(np.unique(self.y)) - 1

        max_shape = []

        for i, j in enumerate(index_array):
            last_frame = self.x.shape[self.time_axis] - self.frames_per_batch
            time_start = np.random.randint(0, high=last_frame)
            time_end = time_start + self.frames_per_batch
            times = list(np.arange(time_start, time_end))

            if self.time_axis == 1:
                x = self.x[j, time_start:time_end, ...]
                y = self.y[j, time_start:time_end, ...]
            elif self.time_axis == 2:
                x = self.x[j, :, time_start:time_end, ...]
                y = self.y[j, :, time_start:time_end, ...]

            if self.panoptic:
                if self.time_axis == 1:
                    y_semantic_list = [y_semantic[j, time_start:time_end, ...]
                                       for y_semantic in self.y_semantic_list]
                elif self.time_axis == 2:
                    y_semantic_list = [y_semantic[j, :, time_start:time_end, ...]
                                       for y_semantic in self.y_semantic_list]

            # Apply transformation
            if self.panoptic:
                x, y_list = self.movie_data_generator.random_transform(x, [y] + y_semantic_list)
                y = y_list[0]
                y_semantic_list = y_list[1:]
            else:
                x, y = self.movie_data_generator.random_transform(x, y)

            x = self.movie_data_generator.standardize(x)

            # Find max shape of image data.  Used for masking.
            if not max_shape:
                max_shape = list(x.shape)
            else:
                for k in range(len(x.shape)):
                    if x.shape[k] > max_shape[k]:
                        max_shape[k] = x.shape[k]

            # Get the bounding boxes from the transformed masks!
            for idx_time, time in enumerate(times):
                if self.time_axis == 1:
                    annotations = self.load_annotations(y[idx_time])
                elif self.time_axis == 2:
                    annotations = self.load_annotations(y[:, idx_time, ...])
                annotations_list[idx_time].append(annotations)

            batch_x[i] = x

            if self.panoptic:
                for k in range(len(y_semantic_list)):
                    batch_y_semantic_list[k][i] = y_semantic_list[k]

        if self.data_format == 'channels_first':
            batch_x_shape = [batch_x.shape[1], batch_x.shape[3], batch_x.shape[4]]
        else:
            batch_x_shape = batch_x.shape[2:]

        anchors = anchors_for_shape(
            batch_x_shape,
            pyramid_levels=self.pyramid_levels,
            anchor_params=self.anchor_params,
            shapes_callback=self.compute_shapes)

        regressions_list = []
        labels_list = []

        if self.data_format == 'channels_first':
            batch_x_frame = batch_x[:, :, 0, ...]
        else:
            batch_x_frame = batch_x[:, 0, ...]
        for idx, time in enumerate(times):
            regressions, labels = anchor_targets_bbox(
                anchors,
                batch_x_frame,
                annotations_list[idx],
                self.num_classes)
            regressions_list.append(regressions)
            labels_list.append(labels)

        regressions = np.stack(regressions_list, axis=self.time_axis)
        labels = np.stack(labels_list, axis=self.time_axis)

        # was a list for max shape indexing
        max_shape = tuple([max_shape[self.row_axis - 1],
                           max_shape[self.col_axis - 1]])

        if self.assoc_head:
            flatten = lambda l: [item for sublist in l for item in sublist]
            annotations_list_flatten = flatten(annotations_list)
            max_annotations = max(len(a['assoc_head']) for a in annotations_list_flatten)
            batch_N_unique = max(a['assoc_head'].shape[-1] for a in annotations_list_flatten)
            assoc_heads_batch_shape = (len(index_array), self.frames_per_batch, 
                                       max_annotations, 8 + batch_N_unique)
            assoc_heads_batch = np.zeros(assoc_heads_batch_shape, dtype=K.floatx())
            for idx_time, time in enumerate(times):
                annotations_frame = annotations_list[idx_time]
                for idx_batch, ann in enumerate(annotations_frame):
                    assoc_heads_batch[idx_batch, idx_time, :ann['bboxes'].shape[0], :4] = ann['bboxes']
                    assoc_heads_batch[idx_batch, idx_time, :ann['labels'].shape[0], 4] = ann['labels']
                    assoc_heads_batch[idx_batch, idx_time, :, 5] = max_shape[1]  # width
                    assoc_heads_batch[idx_batch, idx_time, :, 6] = max_shape[0]  # height
                    assoc_heads_batch[idx_batch, idx_time, :, 7] = batch_N_unique  # batch_N_unique
                    
                    # add flattened association head
                    for idx_mask, assoc_head in enumerate(ann['assoc_head']):
                        assoc_heads_batch[idx_batch, idx_time, idx_mask, 8:] = ann['assoc_head'][idx_mask]


        if self.include_masks:
            # masks_batch has shape: (batch size, max_annotations,
            #     bbox_x1 + bbox_y1 + bbox_x2 + bbox_y2 + label +
            #     width + height + max_image_dimension)

            flatten = lambda l: [item for sublist in l for item in sublist]
            annotations_list_flatten = flatten(annotations_list)
            max_annotations = max(len(a['masks']) for a in annotations_list_flatten)
            masks_batch_shape = (len(index_array), self.frames_per_batch, max_annotations,
                                 5 + 2 + max_shape[0] * max_shape[1])
            # print("masks_batch_shape: ", masks_batch_shape)
            masks_batch = np.zeros(masks_batch_shape, dtype=K.floatx())
            
            batch_x_bbox_shape = (len(index_array), self.frames_per_batch, max_annotations, 4)
            batch_x_bbox = np.zeros(batch_x_bbox_shape, dtype=K.floatx())

            for idx_time, time in enumerate(times):
                annotations_frame = annotations_list[idx_time]
                for idx_batch, ann in enumerate(annotations_frame):
                    batch_x_bbox[idx_batch, idx_time, :ann['bboxes'].shape[0], :4] = ann['bboxes']
                    masks_batch[idx_batch, idx_time, :ann['bboxes'].shape[0], :4] = ann['bboxes']
                    masks_batch[idx_batch, idx_time, :ann['labels'].shape[0], 4] = ann['labels']
                    masks_batch[idx_batch, idx_time, :, 5] = max_shape[1]  # width
                    masks_batch[idx_batch, idx_time, :, 6] = max_shape[0]  # height

                    # add flattened mask
                    for idx_mask, mask in enumerate(ann['masks']):
                        masks_batch[idx_batch, idx_time, idx_mask, 7:] = mask.flatten()

        if self.save_to_dir:
            for i, j in enumerate(index_array):
                for frame in range(batch_x.shape[self.time_axis]):
                    if self.time_axis == 2:
                        img = array_to_img(batch_x[i, :, frame], self.data_format, scale=True)
                    else:
                        img = array_to_img(batch_x[i, frame], self.data_format, scale=True)
                    fname = '{prefix}_{index}_{hash}.{format}'.format(
                        prefix=self.save_prefix,
                        index=j,
                        hash=np.random.randint(1e4),
                        format=self.save_format)
                    img.save(os.path.join(self.save_to_dir, fname))

        batch_inputs = batch_x
        batch_outputs = [regressions, labels]
        print("annotations['bboxes'].shape", annotations['bboxes'].shape)
    
        if self.include_masks:
            print("batch_outputs masks shape:", masks_batch.shape)
            batch_outputs.append(masks_batch)
        if self.include_final_detection_layer:
            batch_outputs.append(masks_batch)
        if self.panoptic:
            batch_outputs += batch_y_semantic_list
            print("batch_outputs batch_y_semantic_list shape:", batch_y_semantic_list[0].shape)
        if self.assoc_head:
            batch_inputs = [batch_x, batch_x_bbox]
            batch_outputs.append(assoc_heads_batch)
            print("batch_outputs assoc_heads_batch shape:", assoc_heads_batch.shape)

        return batch_inputs, batch_outputs

    def next(self):
        """For python 2.x. Returns the next batch.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)


class RetinaMovieDataGenerator(MovieDataGenerator):
    """Generates batches of tensor image data with real-time data augmentation.
    The data will be looped over (in batches).

    Args:
        featurewise_center: boolean, set input mean to 0 over the dataset,
            feature-wise.
        samplewise_center: boolean, set each sample mean to 0.
        featurewise_std_normalization: boolean, divide inputs by std
            of the dataset, feature-wise.
        samplewise_std_normalization: boolean, divide each input by its std.
        zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
        zca_whitening: boolean, apply ZCA whitening.
        rotation_range: int, degree range for random rotations.
        width_shift_range: float, 1-D array-like or int
            float: fraction of total width, if < 1, or pixels if >= 1.
            1-D array-like: random elements from the array.
            int: integer number of pixels from interval
                `(-width_shift_range, +width_shift_range)`
            With `width_shift_range=2` possible values are ints [-1, 0, +1],
            same as with `width_shift_range=[-1, 0, +1]`,
            while with `width_shift_range=1.0` possible values are floats in
            the interval [-1.0, +1.0).
        shear_range: float, shear Intensity
            (Shear angle in counter-clockwise direction in degrees)
        zoom_range: float or [lower, upper], Range for random zoom.
            If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
        channel_shift_range: float, range for random channel shifts.
        fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
            Default is 'nearest'. Points outside the boundaries of the input
            are filled according to the given mode:
                'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
                'nearest':  aaaaaaaa|abcd|dddddddd
                'reflect':  abcddcba|abcd|dcbaabcd
                'wrap':  abcdabcd|abcd|abcdabcd
        cval: float or int, value used for points outside the boundaries
            when `fill_mode = "constant"`.
        horizontal_flip: boolean, randomly flip inputs horizontally.
        vertical_flip: boolean, randomly flip inputs vertically.
        rescale: rescaling factor. Defaults to None. If None or 0, no rescaling
            is applied, otherwise we multiply the data by the value provided
            (before applying any other transformation).
        preprocessing_function: function that will be implied on each input.
            The function will run after the image is resized and augmented.
            The function should take one argument:
            one image (Numpy tensor with rank 3),
            and should output a Numpy tensor with the same shape.
        data_format: One of {"channels_first", "channels_last"}.
            "channels_last" mode means that the images should have shape
                `(samples, height, width, channels)`,
            "channels_first" mode means that the images should have shape
                `(samples, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
                Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        validation_split: float, fraction of images reserved for validation
            (strictly between 0 and 1).
    """

    def flow(self,
             train_dict,
             batch_size=1,
             frames_per_batch=5,
             compute_shapes=guess_shapes,
             num_classes=1,
             clear_borders=False,
             include_masks=False,
             include_final_detection_layer=False,
             panoptic=False,
             assoc_head=False,
             transforms=['watershed'],
             transforms_kwargs={},
             anchor_params=None,
             pyramid_levels=['P3', 'P4', 'P5', 'P6', 'P7'],
             shuffle=False,
             seed=None,
             save_to_dir=None,
             save_prefix='',
             save_format='png'):
        """Generates batches of augmented/normalized data with given arrays.

        Args:
            train_dict: dictionary of X and y tensors. Both should be rank 5.
            frames_per_batch: int (default: 10).
                size of z axis in generated batches
            batch_size: int (default: 1).
            shuffle: boolean (default: True).
            seed: int (default: None).
            save_to_dir: None or str (default: None).
                This allows you to optionally specify a directory
                to which to save the augmented pictures being generated
                (useful for visualizing what you are doing).
            save_prefix: str (default: `''`). Prefix to use for filenames of
                saved pictures (only relevant if `save_to_dir` is set).
            save_format: one of "png", "jpeg". Default: "png".
                (only relevant if `save_to_dir` is set)

        Returns:
            An Iterator yielding tuples of `(x, y)` where `x` is a numpy array
            of image data and `y` is a numpy array of labels of the same shape.
        """
        return RetinaMovieIterator(
            train_dict,
            self,
            compute_shapes=compute_shapes,
            num_classes=num_classes,
            clear_borders=clear_borders,
            include_masks=include_masks,
            include_final_detection_layer=include_final_detection_layer,
            assoc_head=assoc_head,
            panoptic=panoptic,
            transforms=transforms,
            transforms_kwargs=transforms_kwargs,
            anchor_params=anchor_params,
            pyramid_levels=pyramid_levels,
            batch_size=batch_size,
            frames_per_batch=frames_per_batch,
            shuffle=shuffle,
            seed=seed,
            data_format=self.data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format)


SyntaxError: invalid syntax (<ipython-input-15-b7bf07d1715f>, line 724)

In [16]:
#from deepcell.image_generators import RetinaMovieDataGenerator

datagen = RetinaMovieDataGenerator(
    rotation_range=180,
    zoom_range=(0.8, 1.2),
    horizontal_flip=True,
    vertical_flip=True)

datagen_val = RetinaMovieDataGenerator()

In [17]:
train_data = datagen.flow(
    train_dict,
    batch_size=1,
    include_masks=True,
    include_final_detection_layer=True,
    assoc_head=True,
    panoptic=False,
    frames_per_batch=fpb,
    pyramid_levels=pyramid_levels,
    anchor_params=anchor_params)

val_data = datagen_val.flow(
    test_dict,
    batch_size=1,
    include_masks=True,
    include_final_detection_layer=True,
    assoc_head=True,
    panoptic=False,
    frames_per_batch=fpb,
    pyramid_levels=pyramid_levels,
    anchor_params=anchor_params)

train_dict keys :
X
y
daughters


W0329 00:47:01.077329 140505964918592 <ipython-input-11-d08e8f86105f>:643] Removing 2 of 192 images with fewer than 3 objects.


train_dict keys :
X
y
daughters


W0329 00:47:02.474605 140505964918592 <ipython-input-11-d08e8f86105f>:643] Removing 1 of 48 images with fewer than 3 objects.


In [14]:
next_data_x, next_data_y = train_data.next()

annotations['bboxes'].shape (8, 4)
batch_outputs masks shape: (1, 3, 8, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 8, 40)


In [None]:
from tensorflow.keras import callbacks
from deepcell.callbacks import RedirectModel, Evaluate

iou_threshold = 0.5
score_threshold = 0.01
max_detections = 100

model.run_eagerly=False

model.fit_generator(
    train_data,
    steps_per_epoch=X_train.shape[0] // batch_size,
    epochs=n_epoch,
    validation_data=val_data,
    validation_steps=X_test.shape[0] // batch_size,
    callbacks=[
        callbacks.LearningRateScheduler(lr_sched),
        callbacks.ModelCheckpoint(
            os.path.join(MODEL_DIR, model_name + '.h5'),
            monitor='val_loss',
            verbose=3,
            save_best_only=True,
            save_weights_only=False),
        RedirectModel(
            Evaluate(val_data,
                     iou_threshold=iou_threshold,
                     score_threshold=score_threshold,
                     max_detections=max_detections,
                     frames_per_batch=fpb,
                     weighted_average=True),
            prediction_model)]
    )

Epoch 1/10
annotations['bboxes'].shape (3, 4)
batch_outputs masks shape: (1, 3, 3, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 3, 40)
  1/192 [..............................] - ETA: 12:07 - loss: 12.4184 - regression_loss: 2.3424 - classification_loss: 0.5087 - masks_loss: 0.2237 - final_detection_loss: 0.5705 - association_features_cat_loss: 8.7732annotations['bboxes'].shape (7, 4)
batch_outputs masks shape: (1, 3, 7, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 7, 40)
  2/192 [..............................] - ETA: 11:18 - loss: 6279804934.2092 - regression_loss: 2.3053 - classification_loss: 0.4409 - masks_loss: 0.2368 - final_detection_loss: 0.4484 - association_features_cat_loss: 6279804928.0000annotations['bboxes'].shape (12, 4)
batch_outputs masks shape: (1, 3, 12, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 12, 40)
  3/192 [..............................] - ETA: 11:02 - loss: 4730748718.8061 - regression_loss: 2.3397 - classification_loss: 0.4256 - mas

 23/192 [==>...........................] - ETA: 9:26 - loss: 3075167365.1341 - regression_loss: 2.3892 - classification_loss: 0.3876 - masks_loss: 0.3806 - final_detection_loss: 0.3300 - association_features_cat_loss: 3075166976.0000annotations['bboxes'].shape (9, 4)
batch_outputs masks shape: (1, 3, 9, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 9, 40)
 24/192 [==>...........................] - ETA: 9:22 - loss: 2966405300.9202 - regression_loss: 2.3891 - classification_loss: 0.3874 - masks_loss: 0.3835 - final_detection_loss: 0.3279 - association_features_cat_loss: 2966405120.0000annotations['bboxes'].shape (7, 4)
batch_outputs masks shape: (1, 3, 7, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 7, 40)
 25/192 [==>...........................] - ETA: 9:19 - loss: 2847749089.3653 - regression_loss: 2.3837 - classification_loss: 0.3865 - masks_loss: 0.3811 - final_detection_loss: 0.3281 - association_features_cat_loss: 2847748864.0000annotations['bboxes'].shape (9, 4)
ba

In [16]:
for x in next_data_y:
    print(x.shape)
print("\n")
for x in next_data_x:
    print(x.shape)

(1, 3, 83349, 5)
(1, 3, 83349, 2)
(1, 3, 8, 28035)
(1, 3, 8, 28035)
(1, 3, 8, 40)


(1, 3, 154, 182, 1)
(1, 3, 8, 4)


In [31]:
boxes_list = []
scores_list = []
labels_list = []

frames_per_batch = 3
score_threshold=0.05
max_detections=100

all_detections = [[None for i in range(train_data.num_classes)]
                          for j in range(batch_boxes.shape[0])]

all_masks = [[None for i in range(train_data.num_classes)]
             for j in range(batch_boxes.shape[0])]

for i in range(train_data.y.shape[0]):
    for j in range(0, train_data.y.shape[1], frames_per_batch):
        movie = train_data.x[[i], j:j + frames_per_batch, ...]
        results = model.predict_on_batch(movie)

        if train_data.panoptic:
            # Add logic for networks that have semantic heads
            pass
        else:
            if train_data.assoc_head:
                boxes = results[-6]
                scores = results[-5]
                labels = results[-4]
                masks = results[-3]
                final_scores = results[-2]
                association_head = results[-1]
            elif (train_data.include_masks and
                    not train_data.include_final_detection_layer):
                boxes = results[-5]
                scores = results[-4]
                labels = results[-3]
                masks = results[-2]
                association_head = results[-1]
            elif (train_data.include_masks and
                  train_data.include_final_detection_layer):
                boxes = results[-6]
                scores = results[-5]
                labels = results[-4]
                masks = results[-3]
                final_scores = results[-2]
                association_head = results[-1]
            else:
                boxes, scores, labels = results[0:3]

            for k in range(frames_per_batch):
                boxes_list.append(boxes[0, k])
                scores_list.append(scores[0, k])
                labels_list.append(labels[0, k])

batch_boxes = np.stack(boxes_list, axis=0)
batch_scores = np.stack(scores_list, axis=0)
batch_labels = np.stack(labels_list, axis=0)

print("batch_boxes.shape", batch_boxes.shape)
print("batch_scores.shape", batch_scores.shape)
print("batch_labels.shape", batch_labels.shape)

for i in range(batch_boxes.shape[0]):
    boxes = batch_boxes[[i]]
    scores = batch_scores[[i]]
    labels = batch_labels[[i]]

    # select indices which have a score above the threshold
    indices = np.where(scores[0, :] > score_threshold)[0]

    # select those scores
    scores = scores[0][indices]

    # find the order with which to sort the scores
    scores_sort = np.argsort(-scores)[:max_detections]

    # select detections
    image_boxes = boxes[0, indices[scores_sort], :]
    image_scores = scores[scores_sort]
    image_labels = labels[0, indices[scores_sort]]
    
    temp0 = np.expand_dims(image_boxes, axis=1)
    temp1 = np.expand_dims(image_scores, axis=1)
    temp2 = np.expand_dims(image_labels, axis=1)
    
#     print("image_boxes shape ", image_boxes.shape)
#     print("image_boxes expanded shape ", temp0.shape)
#     print("image_scores expanded shape ", temp1.shape)
#     print("image_labels expanded shape ", temp2.shape)

    image_detections = np.concatenate([
        image_boxes,
        temp1,
        temp2
    ], axis=1)
    print("image_detections shape", image_detections.shape)

    # copy detections to all_detections
    for label in range(train_data.num_classes):
        imd = image_detections[image_detections[:, -1] == label, :-1]
        all_detections[i][label] = imd

    if train_data.include_masks:
        image_masks = masks[0, :, indices[scores_sort], :, :, image_labels]
        for label in range(train_data.num_classes):
            imm = image_masks[image_detections[:, -1] == label, ...]
            print("i, label", i, label)
            print("imm shape", imm.shape)
            all_masks[i][label] = imm

batch_boxes.shape (5700, 100, 4)
batch_scores.shape (5700, 100)
batch_labels.shape (5700, 100)
image_detections shape (0, 6)
i, label 0 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 6 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 7 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 8 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 9 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 10 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 11 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 12 0
imm shape (0, 3, 28, 28)
image_detections s

image_detections shape (0, 6)
i, label 198 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 199 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 200 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 201 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 202 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 203 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 204 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 205 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 206 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 207 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 208 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 209 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 210 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 211 0
imm shape (0, 3, 28, 28)
image_detections sha

image_detections shape (0, 6)
i, label 421 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 422 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 423 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 424 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 425 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 426 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 427 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 428 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 429 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 430 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 431 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 432 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 433 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 434 0
imm shape (0, 3, 28, 28)
image_detections sha

i, label 640 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 641 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 642 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 643 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 644 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 645 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 646 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 647 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 648 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 649 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 650 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 651 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 652 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 653 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 654 0
imm s

imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 862 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 863 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 864 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 865 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 866 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 867 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 868 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 869 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 870 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 871 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 872 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 873 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 874 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 875 0
imm shape (0, 3, 28,

imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1081 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1082 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1083 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1084 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1085 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1086 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1087 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1088 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1089 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1090 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1091 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1092 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1093 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1094 0
imm sh

i, label 1302 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1303 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1304 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1305 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1306 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1307 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1308 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1309 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1310 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1311 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1312 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1313 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1314 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1315 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, lab

imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1527 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1528 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1529 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1530 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1531 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1532 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1533 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1534 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1535 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1536 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1537 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1538 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1539 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1540 0
imm sh

i, label 1744 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1745 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1746 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1747 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1748 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1749 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1750 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1751 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1752 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1753 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1754 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1755 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1756 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1757 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, lab

image_detections shape (0, 6)
i, label 1961 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1962 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1963 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1964 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1965 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1966 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1967 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1968 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1969 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1970 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1971 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1972 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1973 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 1974 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 2178 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2179 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2180 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2181 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2182 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2183 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2184 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2185 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2186 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2187 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2188 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2189 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2190 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2191 0
imm shape (0, 3, 28, 28)
image_

imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2398 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2399 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2400 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2401 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2402 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2403 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2404 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2405 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2406 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2407 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2408 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2409 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2410 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2411 0
imm sh

i, label 2617 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2618 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2619 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2620 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2621 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2622 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2623 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2624 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2625 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2626 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2627 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2628 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2629 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2630 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, lab

image_detections shape (0, 6)
i, label 2842 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2843 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2844 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2845 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2846 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2847 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2848 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2849 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2850 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2851 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2852 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2853 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2854 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 2855 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 3055 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3056 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3057 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3058 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3059 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3060 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3061 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3062 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3063 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3064 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3065 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3066 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3067 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3068 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 3274 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3275 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3276 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3277 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3278 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3279 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3280 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3281 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3282 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3283 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3284 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3285 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3286 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3287 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 3495 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3496 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3497 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3498 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3499 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3500 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3501 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3502 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3503 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3504 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3505 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3506 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3507 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3508 0
imm shape (0, 3, 28, 28)
image_

i, label 3698 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3699 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3700 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3701 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3702 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3703 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3704 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3705 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3706 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3707 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3708 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3709 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3710 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3711 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, lab

image_detections shape (0, 6)
i, label 3906 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3907 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3908 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3909 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3910 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3911 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3912 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3913 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3914 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3915 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3916 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3917 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3918 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 3919 0
imm shape (0, 3, 28, 28)
image_

imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4121 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4122 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4123 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4124 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4125 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4126 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4127 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4128 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4129 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4130 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4131 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4132 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4133 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4134 0
imm sh

image_detections shape (0, 6)
i, label 4339 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4340 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4341 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4342 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4343 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4344 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4345 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4346 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4347 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4348 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4349 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4350 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4351 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4352 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 4556 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4557 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4558 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4559 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4560 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4561 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4562 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4563 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4564 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4565 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4566 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4567 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4568 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4569 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 4775 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4776 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4777 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4778 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4779 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4780 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4781 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4782 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4783 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4784 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4785 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4786 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4787 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4788 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 4994 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4995 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4996 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4997 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4998 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 4999 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5000 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5001 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5002 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5003 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5004 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5005 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5006 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5007 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 5217 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5218 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5219 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5220 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5221 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5222 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5223 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5224 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5225 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5226 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5227 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5228 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5229 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5230 0
imm shape (0, 3, 28, 28)
image_

image_detections shape (0, 6)
i, label 5435 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5436 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5437 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5438 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5439 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5440 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5441 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5442 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5443 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5444 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5445 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5446 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5447 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5448 0
imm shape (0, 3, 28, 28)
image_

i, label 5652 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5653 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5654 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5655 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5656 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5657 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5658 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5659 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5660 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5661 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5662 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5663 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5664 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, label 5665 0
imm shape (0, 3, 28, 28)
image_detections shape (0, 6)
i, lab

In [28]:
print(len(all_detections), len(all_detections[0]))

190 1


# Playground

In [40]:
prediction = model.predict_generator(
    train_data,
    steps=1)

annotations['bboxes'].shape (6, 4)
batch_outputs masks shape: (1, 3, 6, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 6, 41)
annotations['bboxes'].shape (4, 4)
batch_outputs masks shape: (1, 3, 4, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 4, 41)
annotations['bboxes'].shape (10, 4)
batch_outputs masks shape: (1, 3, 10, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 10, 41)
annotations['bboxes'].shape (10, 4)
batch_outputs masks shape: (1, 3, 11, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 11, 41)
annotations['bboxes'].shape (8, 4)
batch_outputs masks shape: (1, 3, 8, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 8, 41)
annotations['bboxes'].shape (10, 4)
batch_outputs masks shape: (1, 3, 10, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 10, 41)
annotations['bboxes'].shape (3, 4)
batch_outputs masks shape: (1, 3, 3, 28035)
batch_outputs assoc_heads_batch shape: (1, 3, 3, 41)
annotations['bboxes'].shape (13, 4)
batch_outputs masks shape: (1, 

In [76]:
def discriminative_instance_loss(y_true, 
                                 y_pred,
                                 delta_v=0.5,
                                 delta_d=1.5,
                                 gamma=1e-3):
    """Discriminative loss between an output tensor and a target tensor.

    Args:
        y_true: A tensor of the same shape as y_pred.
        y_pred: A tensor of the vector embedding

    Returns:
        tensor: Output tensor.
    """

    def temp_norm(ten, axis=None):
        if axis is None:
            axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(ten) - 1
        return K.sqrt(K.epsilon() + K.sum(K.square(ten), axis=axis))

    if K.ndim(y_pred) == 4:
        y_pred_shape = tf.shape(y_pred)
        new_y_pred_shape = [y_pred_shape[0] * y_pred_shape[1],
                            y_pred_shape[2], y_pred_shape[3]]
        y_pred = tf.reshape(y_pred, new_y_pred_shape)
        print("new_y_pred_shape", y_pred.shape)

        y_true_shape = tf.shape(y_true)
        new_y_true_shape = [y_true_shape[0] * y_true_shape[1],
                            y_true_shape[2], y_true_shape[3]]
        y_true = tf.reshape(y_true, new_y_true_shape)
        print("new_y_true_shape", y_true.shape)
    
    # split up the different predicted blobs
    assoc_feature_channel_shape = y_pred_shape[-1] - 5
    frames_per_batch = y_pred_shape[0] * y_pred_shape[1]
    boxes = y_pred[:, :, :4]
    assoc_heads = y_pred[:, :, 4:assoc_feature_channel_shape]
    channel_dim = y_pred.shape[-1]
    final_detection_scores = y_pred[:, :, -1]

    # split up the different blobs
    annotations = y_true[:, :, :4]
    labels = K.cast(y_true[:, :, 4:5], dtype='int32')
    width = K.cast(y_true[0, 0, 5], dtype='int32')
    height = K.cast(y_true[0, 0, 6], dtype='int32')
    max_N = K.cast(y_true[0, 0, 7], dtype='int32')
    assoc_heads_target = y_true[:, :, 8:]
    n_detections = y_true_shape[2]
    print("n_detections", n_detections)
    
    print("boxes shape", boxes.shape)
    print("annotations shape", annotations.shape)
    
    # reshape the assoc_heads back to their original size
    assoc_heads_target = K.reshape(assoc_heads_target, (K.shape(assoc_heads_target)[0],
                                            K.shape(assoc_heads_target)[1], max_N))
    assoc_heads = K.reshape(assoc_heads, (K.shape(assoc_heads)[0], K.shape(assoc_heads)[1], 
                                          K.shape(assoc_heads)[2]))
    print("assoc_heads shape", assoc_heads.shape)

   
    # temp = final_detection_scores[0,...,0]
    temp = final_detection_scores[0,...]
    top_vals, top_indices = tf.math.top_k(temp, k=n_detections, sorted=False)

    top_indices_shape = top_indices.get_shape().as_list()
    print("top_indices_shape", top_indices_shape)
    
    top_indices = tf.stack([top_indices, top_indices], axis=-1)
    frames_per_batch = 3
    top_indices = tf.stack([top_indices for l in range(frames_per_batch)], axis=0)
    top_indices_shape = top_indices.get_shape().as_list()
    print("top_indices_shape", top_indices_shape)
    
    filtered_y_pred = tf.gather_nd(assoc_heads, top_indices, batch_dims=0)
    print("gather filtered_y_pred shape", filtered_y_pred.shape)
    y_pred = filtered_y_pred
    
#     for i in range(top_indices_shape[0]):
#         for j in range(top_indices_shape[1]):
#             filtered_y_pred[i, j, :] = K.eval(assoc_heads[i, top_indices[i, j], :])

#     y_pred = tf.convert_to_tensor(filtered_y_pred, dtype='float32')
    print("filtered_y_pred shape", y_pred.shape)
    print("y_true shape", y_true.shape)
     
    rank = K.ndim(y_pred)
    channel_axis = 1 if K.image_data_format() == 'channels_first' else rank - 1
    axes = [x for x in list(range(rank)) if x != channel_axis]

    # Compute variance loss
    cells_summed = tf.tensordot(y_true, y_pred, axes=[axes, axes])
    n_pixels = K.cast(tf.count_nonzero(y_true, axis=axes), dtype=K.floatx()) + K.epsilon()
    n_pixels_expand = K.expand_dims(n_pixels, axis=1) + K.epsilon()
    mu = tf.divide(cells_summed, n_pixels_expand)

    delta_v = K.constant(delta_v, dtype=K.floatx())
    mu_tensor = tf.tensordot(y_true, mu, axes=[[channel_axis], [0]])
    L_var_1 = y_pred - mu_tensor
    L_var_2 = K.square(K.relu(temp_norm(L_var_1) - delta_v))
    L_var_3 = tf.tensordot(L_var_2, y_true, axes=[axes, axes])
    L_var_4 = tf.divide(L_var_3, n_pixels)
    L_var = K.mean(L_var_4)

    # Compute distance loss
    mu_a = K.expand_dims(mu, axis=0)
    mu_b = K.expand_dims(mu, axis=1)

    diff_matrix = tf.subtract(mu_b, mu_a)
    L_dist_1 = temp_norm(diff_matrix)
    L_dist_2 = K.square(K.relu(K.constant(2 * delta_d, dtype=K.floatx()) - L_dist_1))
    diag = K.constant(0, dtype=K.floatx()) * tf.diag_part(L_dist_2)
    L_dist_3 = tf.matrix_set_diag(L_dist_2, diag)
    L_dist = K.mean(L_dist_3)

    # Compute regularization loss
    L_reg = gamma * temp_norm(mu)
    L = L_var + L_dist + K.mean(L_reg)

    return L

In [77]:
y_pred = tf.convert_to_tensor(prediction[4])
y_true = tf.convert_to_tensor(next_data_y[-1])
final_detection_pred = tf.convert_to_tensor(prediction[3])
y_pred = tf.concat([y_pred, final_detection_pred], axis=-1)
result = discriminative_instance_loss(y_true, y_pred)
print(tf.keras.backend.eval(result))

new_y_pred_shape (3, 100, 11)
new_y_true_shape (3, 9, 41)
n_detections Tensor("strided_slice_248:0", shape=(), dtype=int32)
boxes shape (3, 100, 4)
annotations shape (3, 9, 4)
assoc_heads shape (3, 100, 2)
top_indices_shape [9]
top_indices_shape [3, 9, 2]
gather filtered_y_pred shape (3, 9, 2)
filtered_y_pred shape (3, 9, 2)
y_true shape (3, 9, 41)
7537758000.0


### Model Outputs

### Generator Outputs