In [None]:
import logging
import os
from pprint import pprint
import shutil
from tqdm import tqdm

import tensorflow as tf

from ssd.common.callbacks import CallbackBuilder
from ssd.common.distribute import get_strategy
from ssd.common.config import load_config
from ssd.common.viz_utils import draw_boxes_cv2, imshow
from ssd.data.dataset_builder import DatasetBuilder
from ssd.losses.multibox_loss import MultiBoxLoss
from ssd.models.ssd_model import SSDModel

# Reuse TensorFlow's own logger so notebook messages and TF messages share one
# stream, and make sure INFO-level records are emitted.
logger = tf.get_logger()
logger.setLevel(logging.INFO)

# Record the TensorFlow version in the log for reproducibility.
tf_version_message = 'version: {}'.format(tf.__version__)
logger.info(tf_version_message)

In [None]:
# Optional helper scripts, left disabled: they take the image size / number of
# feature maps and the scale range as command-line arguments. Un-comment to run.
# !python ssd/scripts/calculate_feature_shapes.py --image_height 640 --image_width 640 --num_feature_maps 7
# !python ssd/scripts/calculate_scales.py -n 6 --s_first 0.1 --smin 0.2 --smax 0.9
# Run check_matching.py (via the IPython shell escape) against the same config
# file that the training cells load.
# NOTE(review): presumably a sanity check of box matching -- confirm in the script.
!python check_matching.py ssd/cfg/coco_base.yaml

In [None]:
# Load the experiment configuration used by every cell below.
config = load_config('ssd/cfg/coco_base.yaml')

# Keras dtype policy name applied globally.
dtype = 'float32'

# `tf.keras.mixed_precision.experimental` was deprecated in TF 2.4 and removed
# in later releases. Prefer the stable API when it exists and fall back to the
# experimental one so the notebook keeps working on older TF 2.x installs.
mixed_precision = tf.keras.mixed_precision
if hasattr(mixed_precision, 'set_global_policy'):
    policy = mixed_precision.Policy(dtype)
    mixed_precision.set_global_policy(policy)
else:
    policy = mixed_precision.experimental.Policy(dtype)
    mixed_precision.experimental.set_policy(policy)

logger.info('Compute dtype: {}'.format(policy.compute_dtype))
logger.info('Variable dtype: {}'.format(policy.variable_dtype))

# Distribution strategy selected from the config; its replica count drives the
# optional scaling below.
strategy = get_strategy(config)

epochs = config['epochs']

# Learning-rate values: always stored back as a plain list, multiplied by the
# number of replicas only when `scale_lr` is enabled.
lr_values = list(config['lr_values'])
if config['scale_lr']:
    lr_values = [lr * strategy.num_replicas_in_sync for lr in lr_values]
config['lr_values'] = lr_values

# Batch size: multiplied by the number of replicas only when
# `scale_batch_size` is enabled, then written back into the config.
batch_size = config['batch_size']
if config['scale_batch_size']:
    batch_size = batch_size * strategy.num_replicas_in_sync
config['batch_size'] = batch_size

# Whole batches per epoch (any remainder is dropped by the floor division).
train_steps = config['train_images'] // batch_size
val_steps = config['val_images'] // batch_size

# Bare expression: displays the final, adjusted config as the cell output.
config

In [None]:
# Optionally delete everything under `model_dir` so this run starts clean.
if config['clear_previous_runs']:
    try:
        shutil.rmtree(config['model_dir'])
        logger.info('Cleared existing model files\n')
    except FileNotFoundError:
        # Nothing to delete (e.g. first run) -- not an error, just report it.
        logger.warning('model_dir not found!')
        
# Build the input pipelines, loss, optimizer, callbacks and model inside the
# distribution strategy's scope.
with strategy.scope():
    train_dataset = DatasetBuilder('train', config)
    val_dataset = DatasetBuilder('val', config)

    loss_fn = MultiBoxLoss(config)

    # Step-wise constant learning rate taken from the (possibly scaled) config.
    lr_sched = tf.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=config['lr_boundaries'],
        values=config['lr_values'])
    optimizer = tf.optimizers.SGD(
        learning_rate=lr_sched,
        momentum=config['optimizer_momentum'],
        nesterov=config['nestrov'])

    # NOTE(review): '_COCO_' looks like a run/dataset tag for the callbacks --
    # confirm against CallbackBuilder.
    callbacks_list = CallbackBuilder('_COCO_', config).get_callbacks()

    model = SSDModel(config)
    model.compile(loss_fn=loss_fn, optimizer=optimizer)

    # Optionally continue from the most recent checkpoint under
    # <model_dir>/checkpoints; if none exists, training starts from scratch.
    if config['resume_training']:
        checkpoint_dir = os.path.join(config['model_dir'], 'checkpoints')
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint is None:
            logger.warning('No weights found, training from scratch')
        else:
            logger.info('Loading weights from {}'.format(latest_checkpoint))
            model.load_weights(latest_checkpoint)

In [None]:
# Train for the configured number of epochs, validating after each one.
# Step counts are passed explicitly for both the training and validation data.
train_data = train_dataset.dataset
val_data = val_dataset.dataset

fit_options = dict(
    epochs=epochs,
    steps_per_epoch=train_steps,
    validation_data=val_data,
    validation_steps=val_steps,
    callbacks=callbacks_list,
)
model.fit(train_data, **fit_options)

# Persist the weights as they are at the end of training under
# <model_dir>/final_weights, using the 'ssd_weights' file prefix.
with strategy.scope():
    save_path = os.path.join(config['model_dir'], 'final_weights', 'ssd_weights')
    logger.info('Saving final weights in {}'.format(save_path))
    model.save_weights(save_path)

In [None]:
# Restore the most recent checkpoint found under <model_dir>/best_weights.
with strategy.scope():
    best_weights_dir = os.path.join(config['model_dir'], 'best_weights')
    best_weights = tf.train.latest_checkpoint(best_weights_dir)
    if best_weights is None:
        # latest_checkpoint returns None when the directory holds no
        # checkpoint; passing None on to load_weights would fail with an
        # unhelpful error, so report it and keep the weights already in memory
        # (same handling as the resume-training branch).
        logger.warning('No best weights found in {}, keeping current weights'.format(best_weights_dir))
    else:
        model.load_weights(best_weights)
        logger.info('Loaded weights from {}'.format(best_weights))

In [None]:
# Per-channel offsets added back to each image before display.
# NOTE(review): presumably the channel means subtracted during preprocessing,
# with the channel reversal below converting BGR <-> RGB -- confirm against
# the dataset pipeline.
channel_offsets = tf.constant([103.939, 116.779, 123.68])
max_images = 20

# Run detection on (at most) the first `max_images` images of one validation
# batch and show each image with its predicted boxes and class names.
for images, _ in val_dataset.dataset.take(1):
    for i in range(min(images.shape[0], max_images)):
        model_input = images[i]
        # The model is given a batch of one, hence the added leading axis.
        detections = model.get_detections(model_input[None, ...])

        # Undo the offset and reverse the channel order for visualisation.
        restored = (model_input + channel_offsets)[:, :, ::-1]
        categories = [config['classes'][class_id]
                      for class_id in detections['cls_ids'].numpy()]

        image = draw_boxes_cv2(restored, detections['boxes'], categories)
        imshow(image)