In [None]:
import sys
# Extend the import search path so the project-local packages imported in the
# next cell can be found (presumably models/, data/ and utils/ live in this
# directory -- confirm against the repository layout).
# The path is relative, so it resolves against the kernel's current working
# directory, not against the location of this notebook file.
sys.path.append('../tensorflow_fcos')

In [None]:
import tensorflow as tf
from models.fcos import FCOS
from data.bdd_dataset.dataset import dataset_fn
from utils.visualization import draw_boxes_cv2, imshow
from skimage.io import imread, imsave
import numpy as np
import os
from glob import glob

print('TensorFlow:', tf.__version__)

In [None]:
# Place all computation on a single device; the strategy object is handed to
# FCOS through the config dict below.
strategy = tf.distribute.OneDeviceStrategy(device='/gpu:0')

# Dataset root under the current user's home directory. expanduser('~') still
# honours $HOME when it is set, but does not raise KeyError when it is missing,
# and os.path.join avoids hand-built separators.
data_dir = os.path.join(os.path.expanduser('~'), 'datasets', 'BDD100k')

# Image height and width; reused later when resizing the input image and when
# drawing boxes.
H, W = 720, 1280

# Settings consumed by FCOS(config). The meaning of each key is defined in
# models.fcos, which is not visible from this notebook.
# NOTE(review): 'mode' is 'train' although this notebook only runs inference --
# confirm this is what FCOS expects for checkpoint restoration.
config = {
    'mode': 'train',
    'distribute_strategy': strategy,
    'image_height': H,
    'image_width': W,
    'num_classes': 10,
    'dataset_fn': dataset_fn,
    'data_dir': data_dir,
    'batch_size': 8,
    'epochs': 25,
    'learning_rate': 5e-4,
    'checkpoint_prefix': 'ckpt',
    'model_dir': '../model_files',
    'tensorboard_log_dir': './logs',
    'log_after': 20,
    'restore_parameters': True
}

In [None]:
# Build the FCOS wrapper from the config dict.
# NOTE(review): with 'restore_parameters': True this presumably restores
# weights from a checkpoint under 'model_dir' -- confirm in models.fcos.
fcos = FCOS(config)

In [None]:
# Display the wrapper's latest_checkpoint attribute as the cell output, as a
# quick check of which checkpoint (if any) it reports.
fcos.latest_checkpoint

In [None]:
# Collect the Cityscapes test images (<city>/<file> under leftImg8bit/test).
# The root is resolved under the current user's home directory rather than a
# machine-specific absolute path, matching how `data_dir` is built.
# glob() returns matches in arbitrary order, so the list is sorted to make
# `images[i]` refer to the same file on every run.
images = sorted(glob(os.path.join(
    os.path.expanduser('~'), 'datasets', 'cityscapes',
    'leftImg8bit', 'test', '*', '*')))

In [None]:
def ltrb2boxes(centers, ltbr):
    """Convert per-location edge distances into corner-format boxes.

    The first two columns of `ltbr` are subtracted from `centers` and the
    remaining columns are added to `centers`; both results are joined along
    the last axis.

    Args:
        centers: 2-D tensor of reference points, one row per location
            (assumed to have two columns so it lines up with each half of
            `ltbr` -- TODO confirm against the caller).
        ltbr: 2-D tensor of distances, one row per location
            (assumed to have four columns -- TODO confirm).

    Returns:
        Tensor holding `[centers - ltbr[:, :2], centers + ltbr[:, 2:]]`
        concatenated on the last axis.
    """
    leading_offsets = ltbr[:, :2]
    trailing_offsets = ltbr[:, 2:]
    min_corner = centers - leading_offsets
    max_corner = centers + trailing_offsets
    return tf.concat([min_corner, max_corner], axis=-1)

def decode_predictions(logits,
                       score_threshold=0.0, centers=None,
                       max_detections=300, iou_threshold=0.5):
    """Decode raw model outputs into NMS-filtered detections for one image.

    Only the first element of the batch is decoded.

    Args:
        logits: Indexable whose items 0, 1 and 2 are sequences of
            classification, centerness and regression tensors. Each sequence
            is concatenated along axis 1 (assumed to be one tensor per
            feature level, shaped (batch, locations, channels) -- TODO
            confirm against the model).
        score_threshold: Locations whose combined score (best class
            probability times centerness probability) is not strictly
            greater than this value are dropped before NMS.
        centers: Tensor of reference points, one row per location, in the
            same order as the concatenated predictions. Required, despite
            the `None` default that is kept for signature compatibility.
        max_detections: Maximum number of boxes kept by NMS.
        iou_threshold: Overlap above which NMS suppresses the lower-scoring
            box. The default matches `tf.image.non_max_suppression`.

    Returns:
        Tuple `(boxes, scores, ids)` of tensors aligned on their first axis:
        decoded boxes, combined scores, and argmax class indices of the
        detections that survive NMS.

    Raises:
        ValueError: If `centers` is None.
    """
    if centers is None:
        raise ValueError(
            '`centers` must be provided to decode box coordinates.')

    # Merge the per-item outputs of each head into one tensor per head.
    cls_logits = tf.concat(logits[0], axis=1)
    ctr_logits = tf.concat(logits[1], axis=1)
    reg_outputs = tf.concat(logits[2], axis=1)

    # Classification and centerness are turned into probabilities; the
    # regression output is used as-is as edge distances.
    cls_probs = tf.sigmoid(cls_logits)
    ctr_probs = tf.sigmoid(ctr_logits)

    # Best class per location for the first image in the batch.
    cls_scores = tf.reduce_max(cls_probs[0], axis=1)
    cls_ids = tf.argmax(cls_probs[0], axis=1)
    # Weight the class score by centerness (channel 0 of the centerness head).
    score_map = cls_scores * ctr_probs[0, :, 0]

    # Indices of locations that pass the score threshold.
    valid_indices = tf.where(score_map > score_threshold)[:, 0]

    valid_scores = tf.gather(score_map, valid_indices)
    valid_cls_ids = tf.gather(cls_ids, valid_indices)
    valid_centers = tf.gather(centers, valid_indices)
    valid_ltrb = tf.gather(reg_outputs[0], valid_indices)

    decoded_boxes = ltrb2boxes(valid_centers, valid_ltrb)

    # Class-agnostic NMS over all surviving locations.
    nms_indices = tf.image.non_max_suppression(
        decoded_boxes,
        valid_scores,
        max_output_size=max_detections,
        iou_threshold=iou_threshold)
    boxes = tf.gather(decoded_boxes, nms_indices)
    scores = tf.gather(valid_scores, nms_indices)
    ids = tf.gather(valid_cls_ids, nms_indices)
    return boxes, scores, ids

In [None]:
# Index of the image (within `images`) to run inference on.
i = 0

# Read and decode the file. channels=3 forces a three-channel result, so
# grayscale or RGBA files still line up with the 3-value offset subtracted
# below instead of failing to broadcast.
image = tf.io.read_file(images[i])
image = tf.image.decode_image(image, channels=3)
image = tf.image.resize(image, [H, W])

# Keep an un-normalised uint8 copy of the resized image for drawing boxes on.
image_disp = image.numpy().astype(np.uint8)

# Reverse the channel axis (RGB -> BGR) and subtract a per-channel offset.
# NOTE(review): the constants look like the ImageNet BGR means used by
# Caffe-style backbones -- confirm they match how the model was trained.
image = image[:, :, ::-1] - tf.constant([103.939, 116.779, 123.68])

# Add a leading batch axis: the model is called on a batch of one image.
image = image[None, ...]

In [None]:
# Forward pass with training=False, on the single-image batch prepared above.
logits = fcos.model(image, training=False)

# Decode the raw outputs, keeping locations whose combined score exceeds 0.01.
# The location centers come from the FCOS wrapper's `_centers` attribute.
boxes, scores, ids = decode_predictions(
    logits,
    centers=fcos._centers,
    score_threshold=0.01,
)

# Convert the boxes to a NumPy array before handing them to the drawing helper.
boxes = boxes.numpy()

# Draw the detections on the uint8 display copy and show the result.
ii = draw_boxes_cv2(image_disp, boxes, H, W)
imshow(ii)