In [None]:
from glob import glob
import logging
import os
from pprint import pprint
import shutil
from time import time
from tqdm import tqdm

import tensorflow as tf

from ssd.common.callbacks import CallbackBuilder
from ssd.common.distribute import get_strategy
from ssd.common.config import load_config
from ssd.common.viz_utils import draw_boxes_cv2, imshow, visualize_detections
from ssd.data.dataset_builder import DatasetBuilder
from ssd.losses.multibox_loss import MultiBoxLoss
from ssd.models.ssd_model import SSDModel

# Route notebook messages through TensorFlow's own logger at INFO verbosity.
logger = tf.get_logger()
logger.setLevel(logging.INFO)

# Record the TensorFlow version this run is using.
version_message = 'version: {}'.format(tf.__version__)
logger.info(version_message)

In [None]:
# !python3 ssd/scripts/calculate_feature_shapes.py --image_height 512 --image_width 512 --num_feature_maps 7
# !python3 ssd/scripts/calculate_scales.py -n 7 --s_first 0.04 --smin 0.1 --smax 0.9
# !python3 check_matching.py ssd/cfg/coco_resnet_50_512x512-16.yaml

In [None]:
config = load_config('ssd/cfg/coco_resnet_50_512x512-8.yaml')
# config['use_mixed_precision'] = False
# config['use_tpu'] = False
# Notebook-local overrides applied on top of the values loaded from the YAML file.
config['augment_val_dataset'] = False
config['resume_training'] = False

# Select the Keras dtype policy.
# `dtype` defaults to float32 so it is always bound: previously, enabling
# `use_mixed_precision` with neither `use_tpu` nor `use_gpu` set left `dtype`
# undefined (NameError on a fresh kernel, or a stale value on cell re-run).
dtype = 'float32'
if config['use_mixed_precision'] and config['use_tpu']:
    dtype = 'mixed_bfloat16'
# todo: implement loss scaling, then use 'mixed_float16' when
# config['use_mixed_precision'] and config['use_gpu'] are both set.
# Until then GPU runs stay on float32.

# NOTE(review): this is the `experimental` mixed-precision API; confirm it
# exists in the TensorFlow version this project is pinned to.
policy = tf.keras.mixed_precision.experimental.Policy(dtype)
tf.keras.mixed_precision.experimental.set_policy(policy)

logger.info('\nCompute dtype: {}'.format(policy.compute_dtype))
logger.info('Variable dtype: {}'.format(policy.variable_dtype))

# Distribution strategy built from the config; its replica count drives the
# optional scaling of learning rates and batch size below.
strategy = get_strategy(config)

epochs = config['epochs']

# Optionally multiply every learning-rate value by the number of replicas,
# then write the (possibly scaled) list back into the config.
lr_values = list(config['lr_values'])
if config['scale_lr']:
    lr_values = [lr * strategy.num_replicas_in_sync for lr in lr_values]
config['lr_values'] = lr_values

# Optionally multiply the batch size by the number of replicas as well.
batch_size = config['batch_size']
if config['scale_batch_size']:
    batch_size = batch_size * strategy.num_replicas_in_sync
config['batch_size'] = batch_size

# Whole batches per pass over each split (any remainder is dropped by `//`).
train_steps = config['train_images'] // batch_size
val_steps = config['val_images'] // batch_size

# Dump the final, post-scaling configuration for the record.
print('\n')
pprint(config, width=120, compact=True)

In [None]:
# Optionally wipe the previous run's model directory before starting.
# When `use_tpu` is set nothing is deleted and only a warning is logged.
should_clear = config['clear_previous_runs']
if should_clear and config['use_tpu']:
    logger.warning('Skipping GCS Bucket')
elif should_clear:
    model_dir = config['model_dir']
    try:
        shutil.rmtree(model_dir)
    except FileNotFoundError:
        # Nothing to delete; a missing directory is not an error here.
        logger.warning('model_dir not found!')
    else:
        logger.info('Cleared existing model files\n')

In [None]:
# Build the datasets, loss, optimizer, callbacks and model inside the
# distribution strategy scope.
with strategy.scope():
    train_dataset = DatasetBuilder('train', config)
    val_dataset = DatasetBuilder('val', config)

    loss_fn = MultiBoxLoss(config)

    # Piecewise-constant learning-rate schedule feeding an SGD optimizer.
    lr_sched = tf.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=config['lr_boundaries'],
        values=config['lr_values'])
    optimizer = tf.optimizers.SGD(
        learning_rate=lr_sched,
        momentum=config['optimizer_momentum'],
        nesterov=config['nestrov'])  # key is spelled 'nestrov' in the config
    callbacks_list = CallbackBuilder('_COCO_', config).get_callbacks()

    model = SSDModel(config)
    model.compile(loss_fn=loss_fn, optimizer=optimizer)

    # Optionally restore the most recent checkpoint under <model_dir>/checkpoints.
    if config['resume_training']:
        checkpoint_dir = os.path.join(config['model_dir'], 'checkpoints')
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint is None:
            logger.warning('No weights found, training from scratch')
        else:
            logger.info('Loading weights from {}'.format(latest_checkpoint))
            model.load_weights(latest_checkpoint)

In [None]:
# model.fit(train_dataset.dataset,
#           epochs=epochs,
#           steps_per_epoch=train_steps,
#           validation_data=val_dataset.dataset,
#           validation_steps=val_steps,
#           callbacks=callbacks_list)

# with strategy.scope():
#     save_path = os.path.join(config['model_dir'], 'final_weights', 'ssd_weights')
#     logger.info('Saving final weights in {}'.format(save_path))
#     model.save_weights(save_path)

In [None]:
# Restore the best weights saved under <model_dir>/best_weights for inference.
with strategy.scope():
    best_weights_dir = os.path.join(config['model_dir'], 'best_weights')
    latest_checkpoint = tf.train.latest_checkpoint(best_weights_dir)
    # `latest_checkpoint` returns None when the directory holds no checkpoint.
    # Fail here with a clear message instead of passing None to `load_weights`.
    if latest_checkpoint is None:
        raise FileNotFoundError(
            'No checkpoint found in {}'.format(best_weights_dir))
    logger.info('Loading weights from {}'.format(latest_checkpoint))
    model.load_weights(latest_checkpoint)

In [None]:
# Run the model on every image of the first 5 validation batches and plot
# the detections.
for images, _ in val_dataset.dataset.take(5):
    for i in tqdm(range(images.shape[0])):
        # The model input follows the active policy's compute dtype.
        model_input = tf.cast(images[i], dtype=policy.compute_dtype)
        detections = model.predict(model_input[None, ...])

        # Build the display image in float32. Adding the float32 offsets to a
        # tensor in a reduced-precision compute dtype (e.g. bfloat16 under the
        # 'mixed_bfloat16' policy) raises a dtype-mismatch error.
        # Adds the per-channel offsets and reverses the channel order;
        # presumably this undoes the dataset's preprocessing - TODO confirm.
        image = tf.cast(images[i], tf.float32) + tf.constant([103.939, 116.779, 123.68])
        image = image[:, :, ::-1]

        categories = [config['classes'][cls_id] for cls_id in detections['cls_ids']]
        ax = visualize_detections(image, detections['boxes'], categories, detections['scores'])

In [None]:
# Run the model on a single image file, time the call and draw the boxes.
image_path = 'assets/000_G96XV-640x400.jpg'

image_bytes = tf.io.read_file(image_path)
# Force 3 channels: without `channels`, a grayscale file would silently
# broadcast against the 3-value mean below and an RGBA file would fail.
image = tf.image.decode_image(image_bytes, channels=3)
image = tf.cast(image, tf.float32)

# image = random_pad_to_square(image)
image = tf.image.resize(image, [config['image_height'], config['image_width']])

# Reverse the channel order of the decoded image, then subtract the
# per-channel offsets (same constants the validation cell adds back).
image_preprocessed = image[:, :, ::-1] - tf.constant([103.939, 116.779, 123.68])

# Wall-clock time of one predict call on a batch of one image.
start_time = time()
detections = model.predict(image_preprocessed[None, ...])
end_time = time()
logger.info('Inference time: {:.3f}'.format(end_time - start_time))

categories = [config['classes'][cls_id] for cls_id in detections['cls_ids']]
image = draw_boxes_cv2(image, detections['boxes'], categories, thickness=1)
imshow(image / 255, (16, 16))