In [1]:
import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# Load every captured simulator frame together with its labelled light state.
# File names are expected to look like `<prefix>_<state>_...`: the second
# underscore-separated field is the state code used by the extraction cell
# (0 -> red, 1 -> yellow, 2 -> green folders there).
path = "../../../data/traffic_light_sim_captured/"
# os.listdir returns entries in arbitrary order; sort so frame order (and hence
# the numbering of extracted crops) is reproducible across runs.
files = sorted(os.listdir(path))
states = []
images = []
skipped = []
for f in files:
    parts = f.split("_")
    if len(parts) < 2:
        # No state field in the name -> cannot be labelled.
        skipped.append(f)
        continue
    bgr = cv2.imread(os.path.join(path, f))
    if bgr is None:
        # cv2.imread returns None for anything it cannot decode (e.g. .DS_Store,
        # sub-directories); skip instead of crashing on None[..., ::-1].
        skipped.append(f)
        continue
    # Append to both lists together so states[i] always describes images[i].
    states.append(parts[1])
    # OpenCV decodes as BGR; reverse the channel axis to get RGB.
    images.append(bgr[..., ::-1])

if skipped:
    print("Skipped {} unreadable or unlabelled file(s): {}".format(len(skipped), skipped[:5]))

# Stack into one array; assumes all frames share the same resolution.
images = np.array(images)

In [3]:
# Load the frozen detector (an R-FCN ResNet-101 COCO export, per its directory
# name) into a dedicated tf.Graph and open a session bound to that graph.
# Uses the TF1 API (tf.gfile / tf.GraphDef / tf.Session).
model_file = "../../../data/traffic-light-models/rfcn_resnet101_coco_11_06_2017/frozen_inference_graph.pb"

# Read the serialized GraphDef and release the file handle before parsing.
with tf.gfile.GFile(model_file, 'rb') as model_fh:
    graph_bytes = model_fh.read()

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    od_graph_def.ParseFromString(graph_bytes)
    # name='' keeps the exported op names unprefixed, so tensors can later be
    # fetched as e.g. 'image_tensor:0'.
    tf.import_graph_def(od_graph_def, name='')

sess = tf.Session(graph=detection_graph)

In [4]:
# Handles to the frozen graph's input placeholder and its detection outputs;
# they are fed/fetched per frame by the extraction cell.
image_tensor, boxes, scores, classes, num_detections = (
    detection_graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'image_tensor:0',
        'detection_boxes:0',
        'detection_scores:0',
        'detection_classes:0',
        'num_detections:0',
    )
)

In [13]:
def prepare_for_class(image, boxes, size=(50, 50), min_aspect=1.5, max_aspect=3.5):
    """Cut a square patch around one detected box and resize it.

    Parameters
    ----------
    image : ndarray
        Batch indexed as ``image[0, row, col, channel]``; only the first image
        is used. In this notebook it is an RGB frame (channels were flipped
        from OpenCV's BGR order when loading).
    boxes : ndarray
        Shape ``(1, 4)`` laid out as ``[ymin, xmin, ymax, xmax]`` in normalized
        coordinates (they are multiplied by the image height/width).
    size : tuple of int, optional
        Output ``(width, height)`` passed to ``cv2.resize``. Default ``(50, 50)``.
    min_aspect, max_aspect : float, optional
        The box height (pixels) must lie strictly between ``min_aspect * width``
        and ``max_aspect * width``, otherwise the box is rejected -- presumably
        to keep only tall, vertically oriented traffic lights.

    Returns
    -------
    ndarray or None
        The resized square patch with its channel order reversed (RGB -> BGR,
        ready for ``cv2.imwrite``), or ``None`` if the box is rejected.
    """
    img_height, img_width = image.shape[1], image.shape[2]
    ymin, xmin, ymax, xmax = boxes[0]
    left, right = xmin * img_width, xmax * img_width
    top, bottom = ymin * img_height, ymax * img_height

    crop_height = int(bottom - top)
    crop_width = int(right - left)
    if not (min_aspect * crop_width < crop_height < max_aspect * crop_width):
        return None

    # Square window of side crop_height centred on the box, slid back inside
    # the frame when it would cross the left or right border. A single clamped
    # start keeps the window exactly crop_height wide even for odd heights
    # (centre +/- h//2 would be one column short).
    center = (int(left) + int(right)) // 2
    col_start = max(0, min(center - crop_height // 2, img_width - crop_height))
    cropped = image[0, int(top):int(bottom), col_start:col_start + crop_height, :]

    resized = cv2.resize(cropped, size, interpolation=cv2.INTER_CUBIC)
    # Reverse channels (RGB -> BGR) so cv2.imwrite stores the right colours.
    return resized[..., ::-1]

In [None]:
# Run the detector on every labelled frame, keep confident traffic-light boxes,
# and write each accepted crop into the folder for that frame's light state.
save_path = "../../../data/simulator_traffic_lights_extracted/"
red = os.path.join(save_path, "red/")
yellow = os.path.join(save_path, "yellow/")
green = os.path.join(save_path, "green/")
for directory in (save_path, red, yellow, green):
    if not os.path.isdir(directory):
        os.makedirs(directory)

# Numeric state from the file name -> output folder; other states are skipped.
state_dirs = {0: red, 1: yellow, 2: green}
SCORE_THRESHOLD = 0.8      # minimum detection score for a box to be kept
TRAFFIC_LIGHT_CLASS = 10   # COCO label-map id for "traffic light"

count = 0
for i in range(len(images)):
    try:
        target_dir = state_dirs.get(int(states[i]))
    except ValueError:
        # Non-numeric state field: treat like an unknown state.
        target_dir = None
    if target_dir is None:
        # Nothing from this frame would be saved, so skip the (slow) inference.
        continue

    b, s, c, n_det = sess.run([boxes, scores, classes, num_detections],
                              feed_dict={image_tensor: images[i:i+1]})
    tl_detections = [k for k in range(b.shape[1])
                     if s[0, k] > SCORE_THRESHOLD and c[0, k] == TRAFFIC_LIGHT_CLASS]
    cropped = (prepare_for_class(images[i:i+1], b[:, k, :]) for k in tl_detections)
    for im in cropped:
        # prepare_for_class returns None for boxes with an implausible aspect
        # ratio; cv2.imwrite cannot write None, so skip them.
        if im is None:
            continue
        file_path = os.path.join(target_dir, str(count) + ".png")
        count += 1
        cv2.imwrite(file_path, im)
