In [None]:
from glob import glob
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from ssd.common.box_utils import rescale_boxes
from ssd.common.config import load_config
from ssd.common.viz_utils import visualize_detections
from ssd.models.ssd_model import SSDModel

# Use TensorFlow's logger at INFO level so the logger.info(...) calls in later cells are emitted.
logger = tf.get_logger()
logger.setLevel(logging.INFO)

# Record the TensorFlow version for reproducibility of the run.
tf_version = tf.__version__
logger.info('version: {}'.format(tf_version))

In [None]:
# Configuration file consumed by the model build and visualization cells below.
CONFIG_PATH = 'cfg/resnet50_2.yaml'
config = load_config(CONFIG_PATH)

In [None]:
# Instantiate the SSD network described by the loaded config (weights are loaded in the next cell).
logger.info('Building model')
model = SSDModel(config)

In [None]:
# NOTE: despite the variable name, this pins a specific saved epoch (741);
# it does not look up the newest checkpoint in the directory.
latest_checkpoint = 'model_files/coco/resnet50_2/checkpoints/_COCO__ssd_weights_epoch_741'

model.load_weights(latest_checkpoint)
logger.info(f'Initialized model weights from {latest_checkpoint}')

In [None]:
def pad_to_square_fn(image):
    """Zero-pad an image on the bottom and right so that it becomes square.

    The content stays anchored at the top-left corner (offset 0, 0), so pixel
    coordinates in the padded image are identical to those in the input.

    Args:
        image: image tensor whose first two dimensions are (height, width);
            in this notebook it is a 3-D (height, width, 3) image.

    Returns:
        A ``(padded_image, side)`` tuple, where ``side`` is the larger of the
        input height and width (a scalar tensor).
    """
    shape = tf.shape(image)
    side = tf.maximum(shape[0], shape[1])
    padded = tf.image.pad_to_bounding_box(
        image,
        offset_height=0,
        offset_width=0,
        target_height=side,
        target_width=side,
    )
    return padded, side

def get_detections(image_path, input_shape, mean_pixel=(103.939, 116.779, 123.68), ssd_model=None):
    """Run the SSD model on a single image file.

    The image is padded to a square (bottom/right), resized to ``input_shape``,
    converted RGB -> BGR and mean-subtracted before prediction.

    Args:
        image_path: path to an encoded image file readable by ``tf.io.read_file``.
        input_shape: ``[height, width]`` the network expects.
        mean_pixel: per-channel mean subtracted *after* the RGB -> BGR flip,
            so it is given in BGR order.
        ssd_model: model to run; defaults to the notebook-level ``model``.

    Returns:
        ``(image, detections)``: the decoded, unpadded float32 image and the
        dict returned by ``predict``, with ``detections['boxes']`` rescaled
        from ``input_shape`` to the padded square side.
    """
    if ssd_model is None:
        ssd_model = model  # notebook-global built in an earlier cell

    # expand_animations=False makes animated GIFs decode to a single 3-D frame;
    # otherwise decode_image returns a 4-D tensor and set_shape below raises.
    image = tf.image.decode_image(tf.io.read_file(image_path), channels=3,
                                  expand_animations=False)
    image.set_shape([None, None, 3])
    image = tf.cast(image, dtype=tf.float32)

    # Padding is bottom/right only, so the padded square shares the original
    # image's coordinate frame; rescaling to [side, side] is therefore enough.
    padded_image, side = pad_to_square_fn(image)
    input_image = tf.image.resize(padded_image, [input_shape[0], input_shape[1]])
    input_image = input_image[:, :, ::-1] - tf.constant(mean_pixel)  # RGB -> BGR, subtract mean
    input_image = tf.expand_dims(input_image, axis=0)  # add batch dimension

    detections = ssd_model.predict(input_image)
    detections['boxes'] = rescale_boxes(detections['boxes'],
                                        [input_shape[0], input_shape[1]],
                                        [side, side])
    return image, detections

In [None]:
# glob() returns files in arbitrary, filesystem-dependent order; sort so the
# list (and any seeded random sample drawn from it) is stable across runs.
images = sorted(glob('dataset_downloads/coco/val2017/*'))
print('Found {} images'.format(len(images)))

In [None]:
NUM_SAMPLES = 10               # how many validation images to visualize
SEED = 42                      # fixed so Restart & Run All picks the same images
RESULTS_DIR = 'assets/results'

input_shape = [config['image_height'], config['image_width']]

# savefig does not create missing directories.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Sample without replacement so no image is rendered twice, and cap the sample
# size so a directory with fewer than NUM_SAMPLES images does not raise.
rng = np.random.default_rng(SEED)
sample_indices = rng.choice(len(images), size=min(NUM_SAMPLES, len(images)), replace=False)

for idx in tqdm(sample_indices):
    image, detections = get_detections(images[idx], input_shape)
    # int() guards against class ids arriving as floats / numpy / tensor scalars
    # (exact dtype depends on the model's predict output — not visible here).
    classes = [config['classes'][int(_id)] for _id in detections['cls_ids']]
    ax = visualize_detections(image, detections['boxes'], classes, detections['scores'])
    ax.figure.savefig(os.path.join(RESULTS_DIR, '{}.png'.format(idx)), bbox_inches='tight')
    plt.close(ax.figure)  # close this exact figure to avoid accumulating open figures

In [None]:
# Delete the detection figures written to assets/results by the visualization cell.
# Pure-Python replacement for the IPython `rm` alias: runs without automagic,
# works on any OS, and is a no-op when the directory is missing or empty.
# Set CLEAN_RESULTS = False to keep the figures (e.g. on Restart & Run All).
CLEAN_RESULTS = True
if CLEAN_RESULTS:
    for result_path in glob('assets/results/*'):
        # Like `rm` without -r, leave sub-directories alone.
        if os.path.isfile(result_path) or os.path.islink(result_path):
            os.remove(result_path)