## Load necessary modules

In [1]:
# show images inline
%matplotlib inline

# automatically reload modules when they have changed
%load_ext autoreload
%autoreload 2

# import keras
import keras

# import keras_retinanet
from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.utils.visualization import draw_box, draw_caption
from keras_retinanet.utils.colors import label_color

# import miscellaneous modules
import matplotlib.pyplot as plt
import cv2
import os
import numpy as np
import time
import csv
# set tf backend to allow memory to grow, instead of claiming everything
import tensorflow as tf

def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)

# use this environment flag to change which GPU to use
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# set the modified tf session as backend in keras
keras.backend.tensorflow_backend.set_session(get_session())

Using TensorFlow backend.


## Load RetinaNet model

In [2]:
# adjust this to point to your downloaded/trained model
# models can be downloaded here: https://github.com/fizyr/keras-retinanet/releases
model_path = os.path.join('snapshots', 'version8_resplit_test_train', 'resnet50_csv_12_inference.h5')

# load retinanet model
model = models.load_model(model_path, backbone_name='resnet50')

# if the model is not converted to an inference model, use the line below
# see: https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model
#model = models.convert_model(model)

#print(model.summary())

# load label to names mapping for visualization purposes
labels_to_names = {0: 'Biker', 1: 'Car', 2: 'Bus', 3: 'Cart', 4: 'Skater', 5: 'Pedestrian'}
# labels_to_names = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}



### 1. Load CSV

In [3]:
annotations_file = 'test_annotations.csv'
cwd = os.getcwd()

image_names = []

with open (os.path.join(cwd, annotations_file)) as csvfile:
    readCSV = csv.reader(csvfile, delimiter='\n')
    for row in readCSV:
        vals = row[0].split(',')
        this_filepath = vals[0]
        image_names.append(this_filepath)
        
## remove duplicates by taking a set
image_names = list(set(image_names))
print(len(image_names))

663


In [4]:
print(image_names[:2])

['imgs/test/little_video1_7890.jpg', 'imgs/test/bookstore_video1_12390.jpg']


In [5]:
score_threshold_retinanet = 0.4
acceptable_box_overlap = 0.3

### 2. Extract Ground Truth Annotations

In [6]:
def get_gt_annotations(filepath):
    gt_ann = []
    with open (os.path.join(cwd, annotations_file)) as csvfile:
        readCSV = csv.reader(csvfile, delimiter='\n')
        for row in readCSV:
            vals = row[0].split(',')
            this_filepath = vals[0]
            
            if this_filepath == filepath:
                record = []
                record.append(int(vals[1]))
                record.append(int(vals[2]))
                record.append(int(vals[3]))
                record.append(int(vals[4]))
                record.append(vals[5])
                
                gt_ann.append(record)
    return gt_ann
    

### 3. IOU Calculations

In [7]:
def bb_intersection_over_union(boxA, boxB):
	# determine the (x, y)-coordinates of the intersection rectangle
	xA = max(int(boxA[0]), int(boxB[0]))
	yA = max(int(boxA[1]), int(boxB[1]))
	xB = min(int(boxA[2]), int(boxB[2]))
	yB = min(int(boxA[3]), int(boxB[3]))
 
	# compute the area of intersection rectangle
	interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
 
	# compute the area of both the prediction and ground-truth
	# rectangles
	boxAArea = (int(boxA[2]) - int(boxA[0]) + 1) * (int(boxA[3]) - int(boxA[1]) + 1)
	boxBArea = (int(boxB[2]) - int(boxB[0]) + 1) * (int(boxB[3]) - int(boxB[1]) + 1)
 
	# compute the intersection over union by taking the intersection
	# area and dividing it by the sum of prediction + ground-truth
	# areas - the interesection area
	iou = interArea / float(boxAArea + boxBArea - interArea)
 
	# return the intersection over union value
	return iou

In [8]:
def highest_iou(predicted_box, gt_ann):
    
    highest_iou = 0
    best_match = []
    for each_gt in gt_ann:
        this_iou = bb_intersection_over_union(predicted_box[:-1], each_gt[:-1])
#         print(" For this box is: ", this_iou)
        if this_iou > highest_iou:
            highest_iou = this_iou
            best_match = each_gt
            
    return highest_iou, best_match
        

In [9]:
def accuracy(gt_ann, predicted_ann):
    true_positive = []
    class_mismatch = []
    false_positive = []
    false_negative = []
    
    predicted_ann_copy = predicted_ann.copy()
    gt_ann_copy = gt_ann.copy()
    
    for each_pred in predicted_ann_copy:
        # 1. Calculate the highest_iou with any gt_box
        best_iou, best_gt_match = highest_iou(each_pred, gt_ann_copy)
#         print("Best IOU is: ", each_pred, best_gt_match, best_iou)
        
        ## If this box has a match
        if best_iou >= acceptable_box_overlap:
            ## if class label matches
            if best_gt_match[-1] == each_pred[-1]:
                ## This is a true positive
                true_positive.append(each_pred)
                ## remove this from predicted ann and gt
#                 print("Predicted Ann before:", predicted_ann)
                predicted_ann.remove(each_pred)
#                 print("Predicted Ann after:", predicted_ann)
                if best_gt_match in gt_ann:
                    gt_ann.remove(best_gt_match)
            elif best_gt_match[-1] != each_pred[-1]:
                ## this is a class mismatch
                class_mismatch.append(each_pred)
                ## remove this box
                predicted_ann.remove(each_pred)
                if best_gt_match in gt_ann:
                    gt_ann.remove(best_gt_match)
        
    ## If IOU is less than 0.5, leave as is

    ## Any predicted box is now false positive
    for remain_pred in predicted_ann:
        false_positive.append(remain_pred)

    ## Any remaining gt box is false negative
    for remain_gt in gt_ann:
        if remain_gt[-1] in ['Biker', 'Pedestrian', 'Car', 'Bus']:
            false_negative.append(remain_gt)

    return true_positive, class_mismatch, false_positive, false_negative       

### 3. Draw Predicted and GT Annotations

In [10]:
def run_detection_image(filepath):
    image = read_image_bgr(filepath)

    # copy to draw on
    draw = image.copy()
    draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)

    # preprocess image for network
    image = preprocess_image(image)
    image, scale = resize_image(image)

    # process image
    start = time.time()
    boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
    print("processing time: ", time.time() - start)

    # correct for image scale
    boxes /= scale
    
    gt_ann = get_gt_annotations(filepath)
    num_gt_ann = len(gt_ann)
    
    predicted_ann = []

    # visualize detections
    for box, score, label in zip(boxes[0], scores[0], labels[0]):
        # scores are sorted so we can break
        if score < score_threshold_retinanet:
            break
        record = []
        record.append(int(box[0]))
        record.append(int(box[1]))
        record.append(int(box[2]))
        record.append(int(box[3]))
        record.append(labels_to_names[label])
        print("Record is: ", record)
        predicted_ann.append(record)
    
    
#     print("predicted ann are: ", predicted_ann)
    ## Calculate boxes
    true_positive, class_mismatch, false_positive, false_negative = accuracy(gt_ann, predicted_ann)
    print(len(true_positive), len(class_mismatch), len(false_positive), len(false_negative))
    
    font = cv2.FONT_HERSHEY_SIMPLEX
    ## Draw theses on the image
    ## Draw true positive in green
    if len(true_positive) > 0:
        for each_true in true_positive:
            cv2.rectangle(draw,(each_true[0],each_true[1]),(each_true[2],each_true[3]),(0,255,0),3) #green
            cv2.putText(draw, each_true[-1], (each_true[0]-2, each_true[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
        
    ## Draw class mismatch in light blue   
    if len(class_mismatch) > 0:
        for each_mismatch in class_mismatch:
            cv2.rectangle(draw,(each_mismatch[0],each_mismatch[1]),(each_mismatch[2],each_mismatch[3]),(255,255,0),3) #green
            cv2.putText(draw, each_mismatch[-1], (each_mismatch[0]-2, each_mismatch[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
        
    
     ## Draw false positive  in blue 
    if len(false_positive) > 0:
        for each_fp in false_positive:
            cv2.rectangle(draw,(each_fp[0],each_fp[1]),(each_fp[2],each_fp[3]),(255,0,0),3) #green
            cv2.putText(draw, each_fp[-1], (each_fp[0]-2, each_fp[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
    
    
     ## Draw false negative in red 
    if len(false_negative) > 0:
        for each_fn in false_negative:
            cv2.rectangle(draw,(each_fn[0],each_fn[1]),(each_fn[2],each_fn[3]),(0,0,255),3) #green
            cv2.putText(draw, each_fn[-1], (each_fn[0]-2, each_fn[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
            
            
    ## Add key to the image
    cv2.putText(draw, "True Positive", (1200, 20),font, 0.8, (0,255,0),1,cv2.LINE_AA) 
    cv2.putText(draw, "Class Mismatch", (1200, 50),font, 0.8, (255,255,0),1,cv2.LINE_AA)
    cv2.putText(draw, "False Positive", (1200, 80),font, 0.8, (255,0,0),1,cv2.LINE_AA)
    cv2.putText(draw, "False Negtaive", (1200, 110),font, 0.8, (0,0,2550),1,cv2.LINE_AA)
    
    ### Save this image
    
    file, ext = os.path.splitext(filepath)
    image_name = file.split('/')[-1] + ext
    output_path = os.path.join('examples/results_test/', image_name)
    
    draw_conv = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)
    cv2.imwrite(output_path, draw_conv)
    
    return len(true_positive), len(class_mismatch), len(false_positive), len(false_negative), num_gt_ann

#         color = label_color(label)

#         b = box.astype(int)
#         draw_box(draw, b, color=color)

#         caption = "{} {:.3f}".format(labels_to_names[label], score)
#         draw_caption(draw, b, caption)

#     plt.figure(figsize=(15, 15))
#     plt.axis('off')
#     plt.imshow(draw)
#     plt.show()
    


In [11]:
total_true_positive = 0
total_class_mismatch = 0
total_false_positive = 0
total_false_negative = 0
total_gt = 0

In [12]:
for filepath in image_names:
    num_tp, num_cm , num_fp, num_fn, num_gt = run_detection_image(filepath)
    print("Ground Truth: ", num_gt)
    total_true_positive += num_tp
    total_class_mismatch += num_cm
    total_false_positive += num_fp
    total_false_negative += num_fn
    total_gt += num_gt

processing time:  1.293830394744873
Record is:  [804, 829, 860, 914, 'Biker']
Record is:  [1057, 1720, 1125, 1817, 'Biker']
Record is:  [1041, 1706, 1115, 1814, 'Pedestrian']
Record is:  [974, 872, 1039, 962, 'Biker']
Record is:  [981, 893, 1014, 959, 'Biker']
Record is:  [929, 1767, 1001, 1851, 'Biker']
Record is:  [72, 877, 124, 948, 'Biker']
4 1 2 1
Ground Truth:  5
processing time:  0.31607723236083984
Record is:  [1375, 202, 1409, 246, 'Pedestrian']
Record is:  [992, 561, 1035, 597, 'Pedestrian']
Record is:  [642, 347, 684, 398, 'Pedestrian']
Record is:  [740, 266, 794, 316, 'Pedestrian']
Record is:  [306, 792, 349, 846, 'Biker']
Record is:  [1309, 945, 1356, 992, 'Biker']
Record is:  [1232, 689, 1277, 739, 'Pedestrian']
Record is:  [294, 869, 323, 916, 'Pedestrian']
Record is:  [693, 688, 725, 743, 'Pedestrian']
5 1 3 8
Ground Truth:  14
processing time:  0.05529141426086426
Record is:  [384, 859, 454, 950, 'Biker']
Record is:  [901, 827, 962, 918, 'Biker']
2 0 0 0
Ground Truth: 

Ground Truth:  3
processing time:  0.05207991600036621
Record is:  [1377, 199, 1411, 249, 'Pedestrian']
Record is:  [1166, 333, 1210, 385, 'Pedestrian']
Record is:  [965, 738, 1022, 790, 'Pedestrian']
Record is:  [702, 336, 744, 390, 'Pedestrian']
Record is:  [310, 799, 346, 848, 'Biker']
Record is:  [315, 103, 352, 155, 'Biker']
Record is:  [748, 274, 793, 317, 'Pedestrian']
Record is:  [1337, 927, 1389, 988, 'Biker']
Record is:  [53, 383, 89, 427, 'Pedestrian']
Record is:  [966, 739, 1024, 789, 'Biker']
8 2 0 6
Ground Truth:  15
processing time:  0.05224347114562988
Record is:  [1345, 965, 1377, 1000, 'Pedestrian']
Record is:  [1338, 922, 1382, 960, 'Biker']
Record is:  [700, 562, 741, 612, 'Pedestrian']
Record is:  [400, 753, 435, 792, 'Pedestrian']
Record is:  [1314, 947, 1350, 986, 'Biker']
Record is:  [706, 348, 751, 395, 'Pedestrian']
Record is:  [979, 927, 1028, 974, 'Pedestrian']
Record is:  [1373, 204, 1404, 246, 'Pedestrian']
Record is:  [691, 694, 728, 739, 'Pedestrian']
Re

Ground Truth:  3
processing time:  0.05251336097717285
Record is:  [694, 335, 736, 387, 'Pedestrian']
Record is:  [1009, 915, 1047, 957, 'Pedestrian']
Record is:  [273, 575, 303, 623, 'Pedestrian']
Record is:  [10, 951, 38, 998, 'Pedestrian']
Record is:  [681, 974, 726, 1033, 'Biker']
Record is:  [204, 698, 235, 753, 'Pedestrian']
Record is:  [98, 1019, 153, 1072, 'Pedestrian']
Record is:  [104, 818, 142, 872, 'Pedestrian']
5 2 1 8
Ground Truth:  15
processing time:  0.05601024627685547
Record is:  [279, 680, 336, 769, 'Pedestrian']
Record is:  [1216, 823, 1270, 918, 'Pedestrian']
Record is:  [289, 760, 344, 844, 'Pedestrian']
3 0 0 0
Ground Truth:  3
processing time:  0.05608534812927246
Record is:  [936, 854, 997, 923, 'Biker']
Record is:  [438, 984, 510, 1082, 'Biker']
Record is:  [0, 1143, 30, 1223, 'Pedestrian']
Record is:  [715, 849, 768, 920, 'Biker']
4 0 0 0
Ground Truth:  4
processing time:  0.05590200424194336
Record is:  [1047, 843, 1097, 922, 'Biker']
Record is:  [85, 979, 

processing time:  0.06256246566772461
Record is:  [1320, 144, 1378, 185, 'Pedestrian']
Record is:  [1273, 150, 1329, 192, 'Pedestrian']
Record is:  [1432, 119, 1491, 162, 'Pedestrian']
Record is:  [1329, 108, 1374, 153, 'Pedestrian']
Record is:  [1407, 145, 1462, 191, 'Pedestrian']
Record is:  [1265, 47, 1335, 101, 'Pedestrian']
Record is:  [400, 351, 488, 417, 'Pedestrian']
Record is:  [1274, 115, 1320, 160, 'Pedestrian']
Record is:  [578, 339, 675, 409, 'Pedestrian']
Record is:  [578, 794, 640, 838, 'Biker']
Record is:  [1355, 936, 1472, 1010, 'Biker']
Record is:  [1274, 131, 1325, 178, 'Pedestrian']
12 0 0 0
Ground Truth:  11
processing time:  0.06735086441040039
Record is:  [1171, 809, 1239, 934, 'Biker']
Record is:  [504, 837, 599, 935, 'Biker']
Record is:  [85, 827, 167, 946, 'Biker']
3 0 0 0
Ground Truth:  3
processing time:  0.06448650360107422
Record is:  [1376, 204, 1413, 249, 'Pedestrian']
Record is:  [661, 327, 700, 377, 'Pedestrian']
Record is:  [1137, 902, 1173, 943, 'Ped

processing time:  0.05893588066101074
Record is:  [1289, 149, 1328, 188, 'Pedestrian']
Record is:  [1326, 148, 1380, 185, 'Pedestrian']
Record is:  [1427, 114, 1494, 165, 'Pedestrian']
Record is:  [1333, 111, 1377, 149, 'Pedestrian']
Record is:  [1282, 83, 1331, 124, 'Pedestrian']
Record is:  [1286, 116, 1339, 156, 'Pedestrian']
Record is:  [1407, 139, 1474, 197, 'Pedestrian']
Record is:  [371, 770, 441, 816, 'Pedestrian']
Record is:  [1270, 51, 1329, 94, 'Pedestrian']
Record is:  [580, 352, 648, 418, 'Pedestrian']
Record is:  [382, 821, 445, 868, 'Pedestrian']
Record is:  [1557, 893, 1642, 980, 'Biker']
Record is:  [1239, 814, 1323, 862, 'Pedestrian']
Record is:  [417, 367, 479, 416, 'Pedestrian']
Record is:  [926, 819, 1016, 882, 'Biker']
Record is:  [1459, 738, 1510, 776, 'Pedestrian']
Record is:  [1554, 894, 1642, 986, 'Pedestrian']
15 1 1 0
Ground Truth:  15
processing time:  0.05468392372131348
Record is:  [687, 348, 735, 400, 'Pedestrian']
Record is:  [1155, 580, 1190, 622, 'Ped

Ground Truth:  14
processing time:  0.0583796501159668
Record is:  [557, 911, 627, 985, 'Biker']
Record is:  [301, 1160, 356, 1249, 'Pedestrian']
Record is:  [559, 729, 614, 808, 'Pedestrian']
Record is:  [99, 1063, 150, 1144, 'Biker']
Record is:  [576, 708, 657, 810, 'Pedestrian']
5 0 0 1
Ground Truth:  5
processing time:  0.059877634048461914
Record is:  [689, 1330, 771, 1429, 'Biker']
Record is:  [7, 1411, 79, 1524, 'Pedestrian']
Record is:  [1072, 1191, 1141, 1290, 'Biker']
Record is:  [117, 1101, 198, 1216, 'Biker']
Record is:  [330, 1576, 374, 1657, 'Pedestrian']
Record is:  [633, 1784, 705, 1881, 'Biker']
Record is:  [1018, 1477, 1083, 1585, 'Biker']
Record is:  [304, 1623, 352, 1709, 'Pedestrian']
Record is:  [131, 612, 197, 711, 'Pedestrian']
Record is:  [392, 587, 436, 680, 'Pedestrian']
Record is:  [663, 136, 719, 259, 'Biker']
Record is:  [200, 954, 257, 1055, 'Pedestrian']
12 0 0 3
Ground Truth:  15
processing time:  0.0943000316619873
Record is:  [1909, 620, 1935, 644, 'P

processing time:  0.3062257766723633
Record is:  [1153, 1122, 1201, 1193, 'Pedestrian']
Record is:  [776, 254, 824, 347, 'Biker']
Record is:  [487, 471, 529, 570, 'Biker']
Record is:  [546, 1817, 589, 1902, 'Biker']
Record is:  [609, 1179, 654, 1268, 'Biker']
Record is:  [384, 665, 439, 755, 'Biker']
Record is:  [769, 183, 806, 255, 'Biker']
Record is:  [743, 47, 791, 134, 'Biker']
Record is:  [772, 1025, 829, 1094, 'Biker']
Record is:  [613, 1467, 647, 1536, 'Biker']
Record is:  [795, 889, 843, 987, 'Biker']
Record is:  [776, 1065, 814, 1138, 'Biker']
Record is:  [522, 566, 575, 662, 'Biker']
Record is:  [684, 1113, 740, 1184, 'Biker']
Record is:  [744, 470, 790, 548, 'Biker']
Record is:  [744, 470, 790, 550, 'Cart']
Record is:  [702, 1840, 736, 1909, 'Pedestrian']
Record is:  [720, 168, 757, 273, 'Biker']
Record is:  [315, 1054, 356, 1109, 'Biker']
Record is:  [660, 1827, 697, 1911, 'Pedestrian']
Record is:  [717, 694, 754, 770, 'Cart']
Record is:  [500, 756, 548, 827, 'Biker']
Recor

KeyboardInterrupt: 

In [98]:
print(total_true_positive, total_class_mismatch, total_false_positive, total_false_negative, total_gt)

5198 551 812 1531 6867


gt_ann = get_gt_annotations(filepath)
print(gt_ann)
im = cv2.imread(filepath)
for each_ann in gt_ann:
    cv2.rectangle(im, (each_ann[0],each_ann[1]),(each_ann[2],each_ann[3]),(0,0,255),3)
cv2.imwrite('examples/results_test/pred_gt.jpg', im)

predicted_ann = run_detection_image(filepath)
print(predicted_ann)

true_positive, class_mismatch, false_positive, false_negative = accuracy(gt_ann, predicted_ann)

In [99]:
# 4763 687 2192 1067 5854

precision = total_true_positive/(total_true_positive+total_false_positive+total_class_mismatch)
recall = total_true_positive/(total_true_positive+ total_false_negative)
f1_score = 2*(precision * recall)/(precision + recall)

print(precision, recall, f1_score)

0.7922572778539857 0.7724773368999851 0.782242287434161


In [100]:
# 0.8605186114596404 0.6314255025318398 0.7283830427471458 at 0.5
# 0.6608789366542929 0.8296752856283824 0.7357195227621142 at 0.3
#0.7766865239091933 0.7453646477132262 0.7607033036347867 at 0.4