In [None]:
from __future__ import print_function
import numpy as np
import tensorflow as tf
from load_pvoc_data import load_data, TRAIN_LENGTH

In [None]:
# Training hyper-parameters:
#   BATCH_SIZE       - examples per batch (input pipelines drop partial batches)
#   EPOCHS           - passes over the training portion of the "train" split
#   VALIDATION_SPLIT - fraction of the "train" split held out for validation
BATCH_SIZE, EPOCHS, VALIDATION_SPLIT = 32, 4, 0.3

In [None]:
def preprocessing(img, lbl):
    """Per-example preprocessing hook; currently the identity.

    Returns ``(img, lbl)`` unchanged. A multi-hot label encoding over 20 classes
    (tf.one_hot -> reduce_sum over the object axis -> 0/1 threshold) was sketched
    here but is disabled.

    NOTE(review): this function is not applied by the input functions in this
    notebook.
    """
    sample = (img, lbl)
    return sample

In [None]:
def train_input_fn():
    """Build the training and validation pipelines from the "train" split.

    The first ``int(VALIDATION_SPLIT * TRAIN_LENGTH * 8)`` examples yielded by
    ``load_data("train")`` are held out for validation. The remaining examples
    are shuffled, batched and repeated indefinitely for training.

    Returns:
        (train_dataset, val_dataset): both batched to BATCH_SIZE with partial
        final batches dropped; only train_dataset repeats.
    """
    full_dataset = tf.data.Dataset.from_generator(
        lambda: load_data("train"),
        (tf.uint8, tf.int32),
        (tf.TensorShape([None, None, 3]), tf.TensorShape([None]))
    )
    # NOTE(review): the `* 8` factor is also used when computing max_steps; its
    # meaning is not visible here -- confirm against load_pvoc_data.TRAIN_LENGTH.
    val_length = int(VALIDATION_SPLIT * TRAIN_LENGTH * 8)

    # Split BEFORE shuffling. shuffle() reshuffles on every iteration, the two
    # datasets are iterated independently, and this function is called
    # separately for training and for validation. take/skip on a shuffled stream
    # would therefore put overlapping examples into both sets. Splitting the
    # unshuffled stream keeps them disjoint, assuming load_data yields examples
    # in a deterministic order.
    # NOTE(review): batching requires every element of a batch to share one
    # shape; the declared shapes allow variable image sizes / label lengths.
    # Presumably load_data yields fixed-size arrays -- otherwise padded_batch
    # is needed. TODO confirm.
    val_dataset = full_dataset.take(val_length).apply(
        tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))
    train_dataset = (full_dataset.skip(val_length)
                     .shuffle(10000)
                     .apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))
                     .repeat())

    return train_dataset, val_dataset

In [None]:
def test_input_fn():
    """Input pipeline for the "test" split.

    Wraps ``load_data("test")`` in a tf.data.Dataset of (uint8 image, int32
    label vector) pairs. It is batched to BATCH_SIZE with the final partial
    batch dropped, and is not shuffled or repeated.
    """
    element_types = (tf.uint8, tf.int32)
    element_shapes = (tf.TensorShape([None, None, 3]), tf.TensorShape([None]))
    dataset = tf.data.Dataset.from_generator(
        lambda: load_data("test"), element_types, element_shapes)
    batcher = tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE)
    return dataset.apply(batcher)

In [None]:
def conv_layer(inputs, filters=32, kernel_size=3, strides=1, activation=tf.nn.leaky_relu, batch_normalize=True,
               trainable=True, training=False):
    """2-D convolution, optionally followed by batch normalization and an activation.

    Args:
        inputs: input tensor for tf.layers.conv2d (default channels-last layout).
        filters: number of output channels.
        kernel_size: kernel height/width.
        strides: convolution stride; padding is always 'same'.
        activation: callable applied last; None leaves the output linear.
        batch_normalize: if True, apply tf.layers.batch_normalization after the conv.
        trainable: whether the conv and batch-norm variables are trainable.
        training: forwarded to tf.layers.batch_normalization. False (that API's
            own default) normalizes with the moving statistics. Pass True while
            training to use batch statistics; per the TF docs, the moving-average
            update ops then have to be run together with the train op.

    Returns:
        The output tensor.
    """
    x = tf.layers.conv2d(inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
                         padding='same', trainable=trainable)
    if batch_normalize:
        x = tf.layers.batch_normalization(x, training=training, trainable=trainable)
    if activation is not None:
        x = activation(x)
    return x

In [None]:
def residual_block(inputs, filters, trainable=False):
    """Darknet-53 residual unit with an identity shortcut.

    Applies a 1x1 conv to ``filters`` channels, then a 3x3 conv to
    ``filters * 2`` channels, and adds the result to ``inputs``. The shortcut
    therefore requires ``inputs`` to have ``filters * 2`` channels.
    """
    x = conv_layer(inputs=inputs, filters=filters, kernel_size=1, trainable=trainable)
    # The 3x3 conv must consume the 1x1 bottleneck output (x), not the block
    # input; otherwise the 1x1 conv is computed and then discarded.
    x = conv_layer(inputs=x, filters=(filters * 2), trainable=trainable)
    return x + inputs

In [None]:
def darknet_block(inputs, filters, repetitions, trainable=False):
    """Downsampling stage of Darknet-53.

    A stride-2 3x3 conv to ``filters`` channels, followed by ``repetitions``
    residual units whose bottleneck width is ``filters // 2``. Layers default
    to non-trainable.
    """
    x = conv_layer(inputs=inputs, filters=filters, strides=2, trainable=trainable)
    for _ in range(repetitions):
        # Integer division: `/` would give a float channel count under Python 3
        # (only print_function is imported from __future__).
        x = residual_block(x, filters // 2, trainable=trainable)
    return x

In [None]:
def yolo_layer(inputs, anchors):
    """Decode one YOLO detection map into normalized box predictions.

    ``inputs`` is expected to be (batch, grid_h, grid_w, len(anchors) * attrs),
    with channels grouped per anchor as
    [tx, ty, tw, th, objectness, class logits...]. In this notebook the heads
    have 75 channels and 3 anchors, so attrs = 25.

    For each anchor (cx, cy = cell column/row index):
        bx = (sigmoid(tx) + cx) / grid_w
        by = (sigmoid(ty) + cy) / grid_h
        bw = exp(tw) * anchor_w / grid_w
        bh = exp(th) * anchor_h / grid_h
        objectness = sigmoid(objectness)
    Class logits are passed through unchanged.

    TF tensors do not support slice assignment, so a new tensor with the same
    shape as ``inputs`` is assembled with stack/concat and returned.
    """
    grid_h = int(inputs.shape[1])
    grid_w = int(inputs.shape[2])
    num_anchors = len(anchors)
    attrs = int(inputs.shape[3]) // num_anchors

    # Float cell offsets shaped [1, H, W, 1] so they broadcast over batch and
    # anchors (adding int32 indices to float32 activations would fail).
    x_indices, y_indices = tf.meshgrid(tf.range(grid_w), tf.range(grid_h))
    x_offsets = tf.cast(x_indices, inputs.dtype)[tf.newaxis, :, :, tf.newaxis]
    y_offsets = tf.cast(y_indices, inputs.dtype)[tf.newaxis, :, :, tf.newaxis]
    anchor_sizes = tf.constant(anchors, dtype=inputs.dtype)  # [num_anchors, 2]

    preds = tf.reshape(inputs, [-1, grid_h, grid_w, num_anchors, attrs])
    bx = (tf.sigmoid(preds[..., 0]) + x_offsets) / grid_w
    by = (tf.sigmoid(preds[..., 1]) + y_offsets) / grid_h
    # NOTE(review): the anchors look like pixel sizes (e.g. 116x90). Dividing by
    # the grid size rather than the input image size keeps the original
    # normalization -- confirm the intended units.
    bw = tf.exp(preds[..., 2]) * anchor_sizes[:, 0] / grid_w
    bh = tf.exp(preds[..., 3]) * anchor_sizes[:, 1] / grid_h
    objectness = tf.sigmoid(preds[..., 4])

    boxes = tf.stack([bx, by, bw, bh, objectness], axis=-1)
    decoded = tf.concat([boxes, preds[..., 5:]], axis=-1)
    return tf.reshape(decoded, [-1, grid_h, grid_w, num_anchors * attrs])

In [None]:
def non_max_suppr(*args):
    """Merge decoded YOLO outputs from every scale and run per-image NMS.

    Accepts the decoded maps either as separate arguments,
    ``non_max_suppr(o_1, o_2, o_3)``, or as a single list/tuple,
    ``non_max_suppr([o_1, o_2, o_3])`` (the form darknet_model uses).
    Each map is assumed to be (batch, grid_h, grid_w, 3 * attrs) as produced
    by yolo_layer, with per-anchor rows [bx, by, bw, bh, objectness, classes...].

    Centre-format boxes are converted to [y_min, x_min, y_max, x_max] corners.
    That is the format tf.image.non_max_suppression and
    tf.image.draw_bounding_boxes expect. Per image, at most ``max_boxes`` boxes
    are kept, scored by objectness; boxes scoring below ``score_threshold``
    are discarded.

    Returns:
        Float tensor (batch, max_boxes, attrs). Images with fewer surviving
        boxes are zero-padded so tf.map_fn can stack the per-image results.
    """
    num_anchors = 3
    max_boxes = 6
    score_threshold = 0.5

    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        args = tuple(args[0])

    per_scale = []
    for scale in args:
        grid_h, grid_w, channels = (int(d) for d in scale.shape[1:])
        attrs = channels // num_anchors
        # [B, H, W, A * attrs] -> [B, H * W * A, attrs]: one row per anchor box.
        boxes = tf.reshape(scale, [-1, grid_h * grid_w * num_anchors, attrs])
        bx, by = boxes[..., 0], boxes[..., 1]
        half_w, half_h = boxes[..., 2] / 2, boxes[..., 3] / 2
        # Compute all corners from the untouched centre values.
        corners = tf.stack([by - half_h, bx - half_w, by + half_h, bx + half_w], axis=-1)
        per_scale.append(tf.concat([corners, boxes[..., 4:]], axis=-1))
    all_boxes = tf.concat(per_scale, axis=1)

    def _select(boxes):
        # NMS on one image, padded to a fixed row count for map_fn.
        keep = tf.image.non_max_suppression(
            boxes[:, :4],
            boxes[:, 4],
            max_boxes,
            score_threshold=score_threshold
        )
        selected = tf.gather(boxes, keep)
        padding = max_boxes - tf.shape(selected)[0]
        return tf.pad(selected, [[0, padding], [0, 0]])

    return tf.map_fn(_select, all_boxes, infer_shape=False)

In [None]:
def _darknet53_backbone(images):
    """Darknet-53 feature extractor.

    The stem conv and the first four stages are built with trainable=False;
    the last stage is trainable.

    Returns:
        (route_36, route_61, top): outputs of the 256-, 512- and 1024-filter
        stages. The first two are later concatenated into the upsampled
        detection branches.
    """
    x = conv_layer(inputs=images, filters=32, trainable=False)
    x = darknet_block(x, 64, 1)
    x = darknet_block(x, 128, 2)
    route_36 = darknet_block(x, 256, 8)
    route_61 = darknet_block(route_36, 512, 8)
    top = darknet_block(route_61, 1024, 4, trainable=True)
    return route_36, route_61, top


def _conv_set(x, filters):
    """Five alternating convolutions: 1x1 (filters), 3x3 (2 * filters), 1x1, 3x3, 1x1."""
    for _ in range(2):
        x = conv_layer(x, filters=filters, kernel_size=1)
        x = conv_layer(x, filters=filters * 2)
    return conv_layer(x, filters=filters, kernel_size=1)


def _upsample_route(x, filters, route):
    """1x1 conv to `filters`, 2x spatial resize, then channel-concat with `route`."""
    x = conv_layer(x, filters=filters, kernel_size=1)
    x = tf.image.resize_images(x, (int(x.shape[1]) * 2, int(x.shape[2]) * 2))
    return tf.concat([x, route], axis=-1)


def _detection_head(x, filters, anchors):
    """3x3 conv at `filters`, linear 1x1 conv to 75 channels, decoded by yolo_layer."""
    x = conv_layer(x, filters=filters)
    x = conv_layer(x, filters=75, kernel_size=1, activation=None, batch_normalize=False)
    return yolo_layer(x, anchors=anchors)


def _yolo_loss(outputs, labels):
    """Training loss for the three decoded YOLO outputs -- not implemented yet.

    Raises:
        NotImplementedError: always, so TRAIN/EVAL fail with an explicit message.
    """
    # NOTE(review): the input functions declare labels as an int32 vector of
    # shape [None] per image -- presumably class ids only; a box-regression
    # loss would also need box targets. TODO confirm against load_data.
    raise NotImplementedError(
        "darknet_model: the YOLO loss is not implemented; only PREDICT mode is supported.")


def darknet_model(features, labels, mode):
    """Estimator model_fn: Darknet-53 backbone with a three-scale YOLO head.

    Args:
        features: batch of images; cast to float32 and standardized per image.
        labels: label tensor from the input function (unused in PREDICT mode).
        mode: a tf.estimator.ModeKeys value.

    Returns:
        tf.estimator.EstimatorSpec. In PREDICT mode the predictions are
        'images' (inputs with the kept boxes drawn) and 'labels' (argmax over
        the class columns of each kept box).

    Raises:
        NotImplementedError: in TRAIN/EVAL mode, because the loss is not
        implemented yet.
    """
    features = tf.cast(features, dtype=tf.float32)
    normalized = tf.map_fn(tf.image.per_image_standardization, features,
                           infer_shape=False)

    # Feature extractor: Darknet53
    route_36, route_61, x = _darknet53_backbone(normalized)

    # YOLO model. Layers are created in the same order as before, so
    # auto-generated variable names are unchanged.
    # Scale 1: the most-downsampled map, largest anchors.
    l_79 = _conv_set(x, 512)
    o_1 = _detection_head(l_79, 1024, anchors=[(116, 90), (156, 198), (373, 326)])

    # Scale 2: the head branches from l_91 (the last conv of the set), as at
    # scales 1 and 3.
    x = _upsample_route(l_79, 256, route_61)
    l_91 = _conv_set(x, 256)
    o_2 = _detection_head(l_91, 512, anchors=[(30, 61), (62, 45), (59, 119)])

    # Scale 3: least-downsampled map, smallest anchors.
    x = _upsample_route(l_91, 128, route_36)
    x = _conv_set(x, 128)
    o_3 = _detection_head(x, 256, anchors=[(10, 13), (16, 30), (33, 23)])

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Non-maximum suppression to remove overlapping boxes
        output = non_max_suppr([o_1, o_2, o_3])
        predictions = {
            'images': tf.image.draw_bounding_boxes(features, output[:, :, :4]),
            'labels': tf.argmax(output[:, :, 5:], axis=-1)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = _yolo_loss([o_1, o_2, o_3], labels)
    tf.summary.scalar('loss', loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.contrib.estimator.TowerOptimizer(tf.train.AdamOptimizer(1e-4))
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

In [None]:
# Variables to warm-start from the Darknet checkpoint (passed to
# tf.estimator.WarmStartSettings as vars_to_warm_start).
# NOTE(review): the list is empty, so presumably nothing is restored from
# '/tmp/tmpdark/'. Yet darknet_model builds the stem and first four backbone
# stages with trainable=False, so those layers would stay at their random
# initialization. Confirm against the WarmStartSettings docs and list the
# backbone variable scopes here.
vars_warm = []

In [None]:
# Warm-start settings: initialize the variables selected by `vars_warm` from
# the Darknet checkpoint directory.
warm_start = tf.estimator.WarmStartSettings(
    vars_to_warm_start=vars_warm,
    ckpt_to_initialize_from='/tmp/tmpdark/',
)

In [None]:
# Estimator around darknet_model, replicated across devices with
# replicate_model_fn and warm-started via `warm_start`. It checkpoints every
# 150 steps and writes summaries / step-rate logs every 10 steps.
model = tf.estimator.Estimator(
    model_fn=tf.contrib.estimator.replicate_model_fn(darknet_model),
    model_dir='/tmp/tmpdarkyolo',
    config=tf.estimator.RunConfig(
        save_checkpoints_steps=150,
        save_summary_steps=10,
        log_step_count_steps=10,
    ),
    warm_start_from=warm_start,
)

In [None]:
# Periodic validation with early stopping, converted by
# replace_monitors_with_hooks so it can be passed to model.train(hooks=...).
# It evaluates on the validation dataset (the second element returned by
# train_input_fn) every 100 steps, with early_stopping_rounds=10.
# NOTE(review): ValidationMonitor is a (deprecated) tf.contrib.learn API and
# may call evaluate() with contrib-style arguments that a core
# tf.estimator.Estimator does not accept. Confirm it works here, or switch to
# tf.estimator.train_and_evaluate / tf.contrib.estimator.stop_if_no_decrease_hook.
validation_hook = tf.contrib.learn.monitors.replace_monitors_with_hooks(
    [tf.contrib.learn.monitors.ValidationMonitor(
        input_fn=lambda:train_input_fn()[1], every_n_steps=100, early_stopping_rounds=10
    )],
    model
)[0]

In [None]:
# Number of training steps for EPOCHS passes over the training portion.
# The held-out count uses the same expression as train_input_fn,
# int(VALIDATION_SPLIT * TRAIN_LENGTH * 8). Partial batches are dropped, so an
# epoch is a whole number of batches.
# NOTE(review): the `* 8` factor mirrors train_input_fn; its meaning is not
# visible here.
num_train_examples = TRAIN_LENGTH * 8 - int(VALIDATION_SPLIT * TRAIN_LENGTH * 8)
max_steps = int(num_train_examples // BATCH_SIZE) * EPOCHS
model.train(input_fn=lambda: train_input_fn()[0], hooks=[validation_hook],
            max_steps=max_steps)

In [None]:
# Evaluate on the test split and print the resulting metrics.
test_metrics = model.evaluate(input_fn=test_input_fn)
print(test_metrics)