In [1]:
import os
import pickle
from random import shuffle

import cv2
import keras
from keras.applications.imagenet_utils import preprocess_input
from keras.backend.tensorflow_backend import set_session
from keras.models import Model
from keras.preprocessing import image
import matplotlib.pyplot as plt
import numpy as np
#from scipy.misc import imread
from imageio import imread
# NOTE(review): scipy.misc.imresize was removed in SciPy 1.3; this import
# requires an older SciPy (it appears unused in this notebook).
from scipy.misc import imresize
from skimage.transform import resize
import tensorflow as tf

from ssd import SSD300
from ssd_training import MultiboxLoss
from ssd_utils import BBoxUtility

%matplotlib inline
plt.rcParams['figure.figsize'] = (8, 8)
plt.rcParams['image.interpolation'] = 'nearest'
path_prefix = '../yolo2keras/manhole/PNGImages/'

Using TensorFlow backend.
  return f(*args, **kwds)


In [2]:
# some constants
# NOTE(review): presumably background + 1 object class (the earlier 21 matched
# VOC's 20 classes + background) — confirm against the training setup.
NUM_CLASSES = 2 #21 #4
input_shape = (300, 300, 3)
PRIORS_PATH = 'prior_boxes_ssd300.pkl'

# Use a context manager so the file handle is closed once the priors are read.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(PRIORS_PATH, 'rb') as priors_file:
    priors = pickle.load(priors_file)
bbox_util = BBoxUtility(NUM_CLASSES, priors)

In [3]:
# Trained checkpoint to evaluate; layers are matched to the SSD300 graph by name.
WEIGHTS_PATH = 'manhole/checkpoints-300-3C/weights.30-0.18.hdf5'

model = SSD300(input_shape, num_classes=NUM_CLASSES)
model.load_weights(WEIGHTS_PATH, by_name=True)

In [4]:

GT_PATH = 'manhole/manhole.pkl'
TRAIN_FRACTION = 0.8  # share of (sorted) keys used for training

# Ground-truth annotations; the keys are image file names relative to path_prefix.
# Read inside a context manager so the file handle is always closed.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(GT_PATH, 'rb') as gt_file:
    gt = pickle.load(gt_file)

# Deterministic split: sort the keys, take the first TRAIN_FRACTION for
# training and the remainder for validation.
keys = sorted(gt.keys())
num_train = int(round(TRAIN_FRACTION * len(keys)))
train_keys = keys[:num_train]
val_keys = keys[num_train:]
num_val = len(val_keys)

In [5]:
# Build two views of every validation image, in sorted key order:
#   inputs -> 300x300 arrays, stacked and preprocessed as a single batch;
#   images -> the file as read by imread, keyed by (batch position, file name)
#             so each prediction can be matched back to its source image.
sorted_val_keys = sorted(val_keys)
images = {}
inputs = []
for i, val_key in enumerate(sorted_val_keys):
    img_path = path_prefix + val_key
    img = image.img_to_array(image.load_img(img_path, target_size=(300, 300)))
    inputs.append(img.copy())
    images[(i, val_key)] = imread(img_path)
inputs = preprocess_input(np.array(inputs))

In [6]:
# Run the network one image at a time, then decode the raw SSD output
# into one detection entry per input image.
preds = model.predict(inputs, verbose=1, batch_size=1)
results = bbox_util.detection_out(preds)



In [7]:
# Draw every detection at or above the confidence threshold on its
# validation image and save one annotated PNG per image.
CONF_THRESHOLD = 0.6                     # minimum score for a box to be drawn
OUTPUT_DIR = './manhole/predict/'        # annotated images are written here
os.makedirs(OUTPUT_DIR, exist_ok=True)   # savefig fails if the directory is missing

# Palette indexed by class label. Built once (it is loop-invariant); keeps the
# original 4 hues but grows with NUM_CLASSES so larger label ids cannot
# raise IndexError.
colors = plt.cm.hsv(np.linspace(0, 1, max(4, NUM_CLASSES))).tolist()

for (index, val_key), img in images.items():
    detections = results[index]
    # Non-ndarray entries are reported and skipped (presumably images with no
    # detections — confirm against BBoxUtility.detection_out).
    if not isinstance(detections, np.ndarray):
        print("#{},{} ".format(index, val_key))
        continue

    # Each row is read as (label, confidence, xmin, ymin, xmax, ymax); the box
    # coordinates are scaled by the image size below, i.e. treated as fractions.
    confident = detections[detections[:, 1] >= CONF_THRESHOLD]

    fig, ax = plt.subplots()
    ax.imshow(img / 255., cmap='gray')  # cmap only affects single-channel images
    height, width = img.shape[:2]

    for label_val, score, x0, y0, x1, y1 in confident[:, :6]:
        xmin = int(round(x0 * width))
        ymin = int(round(y0 * height))
        xmax = int(round(x1 * width))
        ymax = int(round(y1 * height))
        label = int(label_val)
        color = colors[label]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin + 1, ymax - ymin + 1,
                                   fill=False, edgecolor=color, linewidth=2))
        ax.text(xmin, ymin, '{:0.2f}, {}'.format(score, label),
                bbox={'facecolor': color, 'alpha': 0.5})

    fig.savefig(os.path.join(OUTPUT_DIR, 'predict_' + val_key))
    plt.close(fig)  # release each figure so they do not accumulate across the loop

#45,Img_0865.png 
