## Definition of the TensorFlow model function for the CNN used at each localization step

In [2]:
def _conv3d_bn_relu(inputs, filters, training):
    """Apply a 3x3x3 'same'-padded conv3d, batch normalization, then ReLU.

    Args:
        inputs: 5-D channels-last tensor [batch_size, d0, d1, d2, channels].
        filters: Number of output channels of the convolution.
        training: Python bool; passed to batch normalization so it uses batch
            statistics while training. Its moving-average update ops land in
            tf.GraphKeys.UPDATE_OPS, which the training op below depends on.

    Returns:
        Tensor with the same spatial size as `inputs` and `filters` channels.
    """
    net = tf.layers.conv3d(
        inputs=inputs,
        filters=filters,
        kernel_size=[3, 3, 3],
        padding="same")
    net = tf.layers.batch_normalization(inputs=net, training=training)
    return tf.nn.relu(features=net)


def cnn_regression_model_fn(features, labels, mode, params):
    """Model function for a 3-D CNN that regresses a 3-component position.

    Architecture: three stages of two conv3d-batchnorm-ReLU layers, each stage
    followed by 2x2x2 max pooling with stride 2 (32, 64 and 128 filters).
    These are followed by a 512-unit dense layer (batchnorm, ReLU, dropout)
    and a linear 3-unit output layer.

    Args:
        features: dict whose "image_data" entry is reshaped to
            [batch_size, S, S, S, 1] with S = params['image_size'].
        labels: regression targets compared against the [batch_size, 3]
            output with mean squared error (presumably the (x, y, z) position,
            shape [batch_size, 3] -- confirm against the input_fn). Not used
            in PREDICT mode.
        mode: a tf.estimator.ModeKeys value.
        params: dict with keys
            'image_size'    -- edge length S of the cubic input volume;
            'learning_rate' -- Adam learning rate (TRAIN only);
            'model_dir'     -- output_dir of the training SummarySaverHook
                               (TRAIN only);
            'dropout_rate'  -- optional dropout rate for the dense layer,
                               default 0.0 (no dropout).

    Returns:
        tf.estimator.EstimatorSpec for `mode`.
    """
    size_image = params['image_size']
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Input layer: [batch_size, S, S, S, 1] (single channel, channels-last).
    input_layer = tf.reshape(
        features["image_data"], [-1, size_image, size_image, size_image, 1])

    # Stage 1 (shapes shown for S = 64):
    # [B, 64, 64, 64, 1] -> conv1, conv2 -> [B, 64, 64, 64, 32]
    # -> pool1 -> [B, 32, 32, 32, 32]
    conv1 = _conv3d_bn_relu(input_layer, 32, is_training)
    conv2 = _conv3d_bn_relu(conv1, 32, is_training)
    # Pool the output of conv2 (the original pooled conv1, leaving conv2 unused).
    pool1 = tf.layers.max_pooling3d(inputs=conv2, pool_size=[2, 2, 2], strides=2)

    # Stage 2: -> conv3, conv4 -> [B, 32, 32, 32, 64] -> pool2 -> [B, 16, 16, 16, 64]
    conv3 = _conv3d_bn_relu(pool1, 64, is_training)
    conv4 = _conv3d_bn_relu(conv3, 64, is_training)
    # Pool the output of conv4 (the original pooled conv2, which bypassed
    # conv3/conv4 and left the final feature map 2x too large per axis).
    pool2 = tf.layers.max_pooling3d(inputs=conv4, pool_size=[2, 2, 2], strides=2)

    # Stage 3: -> conv5, conv6 -> [B, 16, 16, 16, 128] -> pool3 -> [B, 8, 8, 8, 128]
    conv5 = _conv3d_bn_relu(pool2, 128, is_training)
    conv6 = _conv3d_bn_relu(conv5, 128, is_training)
    pool3 = tf.layers.max_pooling3d(inputs=conv6, pool_size=[2, 2, 2], strides=2)

    # Flatten. Each stride-2, 'valid'-padded 2x2x2 pool maps a spatial size n
    # to n // 2, so three of them give S // 8 (8 for S = 64). Computing this
    # from S instead of hard-coding 8 keeps the model usable for other sizes.
    pooled_size = size_image // 8
    flat = tf.reshape(pool3, [-1, pooled_size ** 3 * 128])

    # Dense layer: 512 units with batch normalization, ReLU and dropout.
    dense1 = tf.layers.dense(inputs=flat, units=512)
    dense1 = tf.layers.batch_normalization(inputs=dense1, training=is_training)
    dense1 = tf.nn.relu(features=dense1)
    dense1 = tf.layers.dropout(
        inputs=dense1,
        rate=params.get('dropout_rate', 0.0),
        training=is_training)

    # Output layer: [batch_size, 3], linear (no activation) for regression.
    position = tf.layers.dense(inputs=dense1, units=3)

    predictions = {"position": position}
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Loss for TRAIN and EVAL modes.
    labels = tf.cast(labels, tf.float32)
    average_loss = tf.losses.mean_squared_error(labels, position)
    tf.summary.scalar('average_loss', average_loss)

    if is_training:
        summary_op = tf.summary.merge_all()
        # The optimizer minimizes the mean loss; the loss reported in the
        # EstimatorSpec is batch_size * average_loss.
        batch_size = tf.shape(labels)[0]
        total_loss = tf.to_float(batch_size) * average_loss

        optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
        # Run batch-norm moving-average updates before each training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss=average_loss,
                global_step=tf.train.get_global_step())

        summary_hook = tf.train.SummarySaverHook(
            save_steps=5,
            output_dir=params['model_dir'],
            summary_op=summary_op)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_chief_hooks=[summary_hook])

    # EVAL mode: streaming RMS and MAE metrics.
    eval_metric_ops = {
        "rms": tf.metrics.root_mean_squared_error(labels, position),
        "mae": tf.metrics.mean_absolute_error(labels, position),
    }
    # NOTE(review): EVAL reports the mean loss while TRAIN reports
    # batch_size * mean, so the two "loss" values are on different scales.
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=average_loss, eval_metric_ops=eval_metric_ops)


