In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import glob

from ssd_model import SSD300, SSD512
from ssd_utils import PriorUtil
from utils.model import load_weights

### Data

In [None]:
# MS COCO validation split
from data_coco import GTUtility
gt_util = GTUtility('data/COCO/', validation=True)
# The model cell builds the network with gt_util.num_classes and loads
# Pascal VOC weights (ssd*_voc_weights_fixed.hdf5), and the plots label
# detections with gt_util.classes. Map the COCO annotations onto the VOC
# label set so the class count and class names match those weights.
# Comment this line out only when switching to COCO-trained weights.
gt_util = gt_util.convert_to_voc()

### Model

In [None]:
# SSD300 (alternative to the SSD512 configuration; uncomment these lines and
# comment out the SSD512 cell to use it)
# model = SSD300(num_classes=gt_util.num_classes)
# weights_path = './models/ssd300_voc_weights_fixed.hdf5'; confidence_threshold = 0.35

In [None]:
# SSD512 configuration.
# confidence_threshold is the detection cut-off used when decoding the
# predictions of the sampled batch further down.
# Alternative weights trained on COCO (same threshold):
#   weights_path = './models/ssd512_coco_weights_fixed.hdf5'
weights_path = './models/ssd512_voc_weights_fixed.hdf5'
confidence_threshold = 0.7

# the number of output classes is taken from the ground-truth utility
model = SSD512(num_classes=gt_util.num_classes)

In [None]:
# Load the trained weights from weights_path into the freshly built model,
# then create the prior-box helper that is used below to inspect the prior
# maps (prior_util.prior_maps) and to decode raw predictions (prior_util.decode).
load_weights(model, weights_path)
prior_util = PriorUtil(model)

### Predict

In [None]:
import random

# Draw the sampling seed explicitly and print it, so a batch worth looking at
# again can be reproduced by hard-coding the printed value here.
batch_seed = random.randint(1, 10000)
print('batch seed: %d' % batch_seed)

_, inputs, images, data = gt_util.sample_random_batch(batch_size=2, input_size=model.image_size, seed=batch_seed)

# plot ground truth for every sampled image
for img, gt in zip(images, data):
    plt.figure(figsize=[8]*2)
    plt.imshow(img)
    gt_util.plot_gt(gt)
    plt.show()

In [None]:
# Optional debug view: draw the locations and a few boxes of every prior map
# on top of the first image of the batch. Disabled by default.
PLOT_PRIOR_BOXES = False

if PLOT_PRIOR_BOXES:
    for m in prior_util.prior_maps:
        plt.figure(figsize=[8]*2)
        plt.imshow(images[0])
        m.plot_locations()
        # presumably indices of individual prior boxes to draw - TODO confirm
        m.plot_boxes([0, 10, 100])
        plt.show()

In [None]:
# Run the network on the sampled batch (one image per step, with a progress
# bar); the raw per-image outputs are decoded into boxes in the next cell.
preds = model.predict(inputs, batch_size=1, verbose=1)

In [None]:
# SAVE_RESULTS: write each annotated prediction as result_NNN.jpg into the
# directory of the weights file (stale result_* files from a previous run are
# deleted first). SHOW_GT: also draw the ground-truth boxes.
SAVE_RESULTS = False
SHOW_GT = True

checkdir = os.path.dirname(weights_path)

if SAVE_RESULTS:
    for fl in glob.glob(os.path.join(checkdir, 'result_*')):
        os.remove(fl)

# Iterate over the whole batch rather than a hard-coded count, so changing the
# batch size in the sampling cell neither breaks nor truncates this loop.
for i in range(len(preds)):
    plt.figure(figsize=[8]*2)
    plt.imshow(images[i])
    res = prior_util.decode(preds[i], confidence_threshold=confidence_threshold, fast_nms=True)
    prior_util.plot_results(res, classes=gt_util.classes, show_labels=True,
                            gt_data=data[i] if SHOW_GT else None)
    plt.axis('off')
    if SAVE_RESULTS:
        plt.savefig(os.path.join(checkdir, 'result_%03d.jpg' % i))
    plt.show()

### Real-world images

In [None]:
import cv2
from ssd_data import preprocess

inputs = []
images = []

# glob returns files in arbitrary order; sort so the image order (and the
# paths printed when plotting) is stable between runs
img_paths = sorted(glob.glob('./data/images/*.jpg'))
if not img_paths:
    raise FileNotFoundError('no *.jpg files found in ./data/images')

h, w = model.image_size
for img_path in img_paths:
    img = cv2.imread(img_path)
    # cv2.imread does not raise on an unreadable file, it returns None
    if img is None:
        raise IOError('could not read image %s' % img_path)
    inputs.append(preprocess(img, model.image_size))
    # Display copy: resized to the network input size, BGR -> RGB, scaled
    # to [0, 1]. The third positional parameter of cv2.resize is `dst`, so
    # the interpolation flag has to be passed by keyword.
    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR).astype('float32')
    img = img[:, :, (2, 1, 0)]  # BGR to RGB
    img /= 255
    images.append(img)

inputs = np.asarray(inputs)

preds = model.predict(inputs, batch_size=1, verbose=1)

In [None]:
# Draw the decoded detections for every real-world image, printing its path
# first. Despite its name, link_threshold is passed to decode() as the
# detection confidence threshold (lower than the one used for the sampled batch).
link_threshold = 0.3
for i, img in enumerate(images):
    print(img_paths[i])
    plt.figure(figsize=[8]*2, frameon=True)
    plt.imshow(img)
    res = prior_util.decode(preds[i], confidence_threshold=link_threshold)
    prior_util.plot_results(res, classes=gt_util.classes)
    plt.axis('off')
    plt.show()