## Load necessary modules

In [2]:
# show images inline
%matplotlib inline

# automatically reload modules when they have changed
%load_ext autoreload
%autoreload 2

# import keras
import keras

# import keras_retinanet
from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.utils.visualization import draw_box, draw_caption
from keras_retinanet.utils.colors import label_color

# import miscellaneous modules
import matplotlib.pyplot as plt
import cv2
import os
import numpy as np
import time
import csv
# set tf backend to allow memory to grow, instead of claiming everything
import tensorflow as tf

def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)

# use this environment flag to change which GPU to use
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# set the modified tf session as backend in keras
keras.backend.tensorflow_backend.set_session(get_session())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Using TensorFlow backend.


## Load RetinaNet model

In [3]:

model_path = os.path.join('resnet50_csv_12_inference.h5')

# load retinanet model
model = models.load_model(model_path, backbone_name='resnet50')




# load label to names mapping for visualization purposes
labels_to_names = {0: 'Biker', 1: 'Car', 2: 'Bus', 3: 'Cart', 4: 'Skater', 5: 'Pedestrian'}
# labels_to_names = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}

tracking <tf.Variable 'Variable:0' shape=(12, 4) dtype=float32> anchors
tracking <tf.Variable 'Variable_1:0' shape=(12, 4) dtype=float32> anchors
tracking <tf.Variable 'Variable_2:0' shape=(12, 4) dtype=float32> anchors
tracking <tf.Variable 'Variable_3:0' shape=(12, 4) dtype=float32> anchors
tracking <tf.Variable 'Variable_4:0' shape=(12, 4) dtype=float32> anchors


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where




### 1. Load CSV

In [4]:
annotations_file = 'test_annotations.csv'
cwd = os.getcwd()

image_names = []

with open (os.path.join(cwd, annotations_file)) as csvfile:
    readCSV = csv.reader(csvfile, delimiter='\n')
    for row in readCSV:
        vals = row[0].split(',')
        this_filepath = vals[0]
        image_names.append(this_filepath)
        
## remove duplicates by taking a set
image_names = list(set(image_names))
print(len(image_names))

663


In [5]:
print(image_names[:2])

['imgs/test/bookstore_video1_11460.jpg', 'imgs/test/quad_video2_330.jpg']


In [6]:
score_threshold_retinanet = 0.4
acceptable_box_overlap = 0.3

### 2. Extract Ground Truth Annotations

In [7]:
def get_gt_annotations(filepath):
    gt_ann = []
    with open (os.path.join(cwd, annotations_file)) as csvfile:
        readCSV = csv.reader(csvfile, delimiter='\n')
        for row in readCSV:
            vals = row[0].split(',')
            this_filepath = vals[0]
            
            if this_filepath == filepath:
                record = []
                record.append(int(vals[1]))
                record.append(int(vals[2]))
                record.append(int(vals[3]))
                record.append(int(vals[4]))
                record.append(vals[5])
                
                gt_ann.append(record)
    return gt_ann    

### 3. IOU Calculations

In [8]:
def bb_intersection_over_union(boxA, boxB):
        # determine the (x, y)-coordinates of the intersection rectangle
        xA = max(int(boxA[0]), int(boxB[0]))
        yA = max(int(boxA[1]), int(boxB[1]))
        xB = min(int(boxA[2]), int(boxB[2]))
        yB = min(int(boxA[3]), int(boxB[3]))
 
        # compute the area of intersection rectangle
        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
 
        # compute the area of both the prediction and ground-truth
        # rectangles
        boxAArea = (int(boxA[2]) - int(boxA[0]) + 1) * (int(boxA[3]) - int(boxA[1]) + 1)
        boxBArea = (int(boxB[2]) - int(boxB[0]) + 1) * (int(boxB[3]) - int(boxB[1]) + 1)
 
        # compute the intersection over union by taking the intersection
        # area and dividing it by the sum of prediction + ground-truth
        # areas - the interesection area
        iou = interArea / float(boxAArea + boxBArea - interArea)
 
        # return the intersection over union value
        return iou

In [9]:
def highest_iou(predicted_box, gt_ann):
    
    highest_iou = 0
    best_match = []
    for each_gt in gt_ann:
        this_iou = bb_intersection_over_union(predicted_box[:-1], each_gt[:-1])
#         print(" For this box is: ", this_iou)
        if this_iou > highest_iou:
            highest_iou = this_iou
            best_match = each_gt
            
    return highest_iou, best_match
        

In [10]:
def accuracy(gt_ann, predicted_ann):
    true_positive = []
    class_mismatch = []
    false_positive = []
    false_negative = []
    
    predicted_ann_copy = predicted_ann.copy()
    gt_ann_copy = gt_ann.copy()
    
    for each_pred in predicted_ann_copy:
        # 1. Calculate the highest_iou with any gt_box
        best_iou, best_gt_match = highest_iou(each_pred, gt_ann_copy)
#         print("Best IOU is: ", each_pred, best_gt_match, best_iou)
        
        ## If this box has a match
        if best_iou >= acceptable_box_overlap:
            ## if class label matches
            if best_gt_match[-1] == each_pred[-1]:
                ## This is a true positive
                true_positive.append(each_pred)
                ## remove this from predicted ann and gt
#                 print("Predicted Ann before:", predicted_ann)
                predicted_ann.remove(each_pred)
#                 print("Predicted Ann after:", predicted_ann)
                if best_gt_match in gt_ann:
                    gt_ann.remove(best_gt_match)
            elif best_gt_match[-1] != each_pred[-1]:
                ## this is a class mismatch
                class_mismatch.append(each_pred)
                ## remove this box
                predicted_ann.remove(each_pred)
                if best_gt_match in gt_ann:
                    gt_ann.remove(best_gt_match)
        
    ## If IOU is less than 0.5, leave as is

    ## Any predicted box is now false positive
    for remain_pred in predicted_ann:
        false_positive.append(remain_pred)

    ## Any remaining gt box is false negative
    for remain_gt in gt_ann:
        if remain_gt[-1] in ['Biker', 'Pedestrian', 'Car', 'Bus']:
            false_negative.append(remain_gt)

    return true_positive, class_mismatch, false_positive, false_negative       

### 3. Draw Predicted and GT Annotations

In [11]:
def run_detection_image(filepath):
    image = read_image_bgr(filepath)

    # copy to draw on
    draw = image.copy()
    draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)

    # preprocess image for network
    image = preprocess_image(image)
    image, scale = resize_image(image)

    # process image
    start = time.time()
    boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
    print("processing time: ", time.time() - start)

    # correct for image scale
    boxes /= scale
    
    gt_ann = get_gt_annotations(filepath)
    num_gt_ann = len(gt_ann)
    
    predicted_ann = []

    # visualize detections
    for box, score, label in zip(boxes[0], scores[0], labels[0]):
        # scores are sorted so we can break
        if score < score_threshold_retinanet:
            break
        record = []
        record.append(int(box[0]))
        record.append(int(box[1]))
        record.append(int(box[2]))
        record.append(int(box[3]))
        record.append(labels_to_names[label])
        print("Record is: ", record)
        predicted_ann.append(record)
    
    
#     print("predicted ann are: ", predicted_ann)
    ## Calculate boxes
    true_positive, class_mismatch, false_positive, false_negative = accuracy(gt_ann, predicted_ann)
    print(len(true_positive), len(class_mismatch), len(false_positive), len(false_negative))
    
    font = cv2.FONT_HERSHEY_SIMPLEX
    ## Draw theses on the image
    ## Draw true positive in green
    if len(true_positive) > 0:
        for each_true in true_positive:
            cv2.rectangle(draw,(each_true[0],each_true[1]),(each_true[2],each_true[3]),(0,255,0),3) #green
            cv2.putText(draw, each_true[-1], (each_true[0]-2, each_true[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
        
    ## Draw class mismatch in light blue   
    if len(class_mismatch) > 0:
        for each_mismatch in class_mismatch:
            cv2.rectangle(draw,(each_mismatch[0],each_mismatch[1]),(each_mismatch[2],each_mismatch[3]),(255,255,0),3) #green
            cv2.putText(draw, each_mismatch[-1], (each_mismatch[0]-2, each_mismatch[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
        
    
     ## Draw false positive  in blue 
    if len(false_positive) > 0:
        for each_fp in false_positive:
            cv2.rectangle(draw,(each_fp[0],each_fp[1]),(each_fp[2],each_fp[3]),(255,0,0),3) #green
            cv2.putText(draw, each_fp[-1], (each_fp[0]-2, each_fp[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
    
    
     ## Draw false negative in red 
    if len(false_negative) > 0:
        for each_fn in false_negative:
            cv2.rectangle(draw,(each_fn[0],each_fn[1]),(each_fn[2],each_fn[3]),(0,0,255),3) #green
            cv2.putText(draw, each_fn[-1], (each_fn[0]-2, each_fn[1]-2),font, 0.5,
                        (0,0,0),1,cv2.LINE_AA) # text in black
            
            
    ## Add key to the image
    cv2.putText(draw, "True Positive", (1200, 20),font, 0.8, (0,255,0),1,cv2.LINE_AA) 
    cv2.putText(draw, "Class Mismatch", (1200, 50),font, 0.8, (255,255,0),1,cv2.LINE_AA)
    cv2.putText(draw, "False Positive", (1200, 80),font, 0.8, (255,0,0),1,cv2.LINE_AA)
    cv2.putText(draw, "False Negtaive", (1200, 110),font, 0.8, (0,0,2550),1,cv2.LINE_AA)
    
    ### Save this image
    
    file, ext = os.path.splitext(filepath)
    image_name = file.split('/')[-1] + ext
    output_path = os.path.join('examples/results_test/', image_name)
    
    draw_conv = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)
    cv2.imwrite(output_path, draw_conv)
    
    return len(true_positive), len(class_mismatch), len(false_positive), len(false_negative), num_gt_ann

    


In [12]:
total_true_positive = 0
total_class_mismatch = 0
total_false_positive = 0
total_false_negative = 0
total_gt = 0

In [13]:
cnt=1
for filepath in image_names:
    num_tp, num_cm , num_fp, num_fn, num_gt = run_detection_image(filepath)
    print("Ground Truth: ", num_gt)
    total_true_positive += num_tp
    total_class_mismatch += num_cm
    total_false_positive += num_fp
    total_false_negative += num_fn
    total_gt += num_gt
    cnt+=1
    if cnt==80:
        break
        


processing time:  15.802170753479004
Record is:  [1295, 721, 1327, 763, 'Pedestrian']
Record is:  [1295, 1032, 1344, 1069, 'Pedestrian']
Record is:  [1374, 203, 1409, 246, 'Pedestrian']
Record is:  [212, 489, 247, 530, 'Pedestrian']
Record is:  [696, 688, 730, 738, 'Pedestrian']
Record is:  [297, 471, 336, 534, 'Biker']
Record is:  [669, 345, 710, 384, 'Pedestrian']
Record is:  [282, 462, 306, 511, 'Pedestrian']
Record is:  [267, 168, 296, 202, 'Pedestrian']
Record is:  [743, 259, 796, 324, 'Pedestrian']
Record is:  [310, 429, 345, 481, 'Biker']
6 3 2 5
Ground Truth:  14
processing time:  3.221630573272705
Record is:  [1622, 703, 1666, 733, 'Pedestrian']
Record is:  [1775, 792, 1817, 830, 'Pedestrian']
Record is:  [1908, 621, 1936, 644, 'Pedestrian']
Record is:  [1726, 691, 1757, 715, 'Pedestrian']
Record is:  [1887, 659, 1916, 681, 'Pedestrian']
Record is:  [1804, 676, 1831, 705, 'Pedestrian']
Record is:  [1733, 705, 1763, 729, 'Pedestrian']
Record is:  [1767, 808, 1818, 840, 'Pedest

processing time:  3.8102996349334717
Record is:  [1278, 52, 1328, 94, 'Pedestrian']
Record is:  [1327, 146, 1375, 184, 'Pedestrian']
Record is:  [1288, 147, 1335, 193, 'Pedestrian']
Record is:  [1331, 112, 1374, 150, 'Pedestrian']
Record is:  [1280, 84, 1331, 126, 'Pedestrian']
Record is:  [1285, 117, 1336, 157, 'Pedestrian']
Record is:  [1444, 102, 1523, 162, 'Pedestrian']
Record is:  [587, 363, 639, 411, 'Pedestrian']
Record is:  [1409, 142, 1461, 201, 'Pedestrian']
Record is:  [1466, 733, 1518, 777, 'Pedestrian']
Record is:  [1040, 837, 1099, 891, 'Pedestrian']
Record is:  [1034, 883, 1094, 930, 'Pedestrian']
12 0 0 0
Ground Truth:  12
processing time:  3.6627650260925293
Record is:  [706, 315, 736, 366, 'Pedestrian']
Record is:  [1372, 968, 1405, 995, 'Pedestrian']
Record is:  [591, 1024, 642, 1074, 'Biker']
Record is:  [269, 277, 307, 333, 'Biker']
Record is:  [1371, 977, 1403, 1003, 'Pedestrian']
3 1 1 3
Ground Truth:  7
processing time:  4.1487274169921875
Record is:  [1325, 140

processing time:  3.779629945755005
Record is:  [1163, 980, 1230, 1077, 'Biker']
Record is:  [169, 887, 229, 954, 'Biker']
Record is:  [158, 1154, 227, 1227, 'Pedestrian']
Record is:  [344, 985, 412, 1081, 'Biker']
Record is:  [60, 832, 109, 908, 'Biker']
Record is:  [114, 823, 162, 896, 'Biker']
5 0 1 0
Ground Truth:  5
processing time:  4.47433614730835
Record is:  [1322, 138, 1378, 178, 'Pedestrian']
Record is:  [1335, 108, 1385, 150, 'Pedestrian']
Record is:  [1277, 141, 1337, 185, 'Pedestrian']
Record is:  [1289, 113, 1337, 152, 'Pedestrian']
Record is:  [1270, 84, 1337, 122, 'Pedestrian']
Record is:  [1437, 107, 1499, 163, 'Pedestrian']
Record is:  [1286, 37, 1342, 96, 'Pedestrian']
Record is:  [1551, 734, 1660, 806, 'Pedestrian']
Record is:  [834, 782, 906, 863, 'Pedestrian']
Record is:  [438, 928, 529, 1003, 'Biker']
Record is:  [1406, 143, 1476, 203, 'Pedestrian']
Record is:  [1806, 962, 1894, 1022, 'Pedestrian']
Record is:  [1206, 66, 1265, 117, 'Pedestrian']
Record is:  [137

processing time:  4.1274778842926025
Record is:  [1330, 111, 1372, 149, 'Pedestrian']
Record is:  [1288, 148, 1331, 185, 'Pedestrian']
Record is:  [1413, 145, 1464, 194, 'Pedestrian']
Record is:  [1291, 114, 1326, 150, 'Pedestrian']
Record is:  [1328, 138, 1366, 172, 'Pedestrian']
Record is:  [1443, 113, 1498, 163, 'Pedestrian']
Record is:  [1275, 59, 1327, 97, 'Pedestrian']
Record is:  [1054, 678, 1121, 753, 'Pedestrian']
Record is:  [584, 364, 640, 407, 'Pedestrian']
Record is:  [360, 309, 412, 363, 'Pedestrian']
Record is:  [1652, 876, 1747, 949, 'Biker']
Record is:  [1451, 738, 1507, 775, 'Pedestrian']
Record is:  [1541, 882, 1600, 925, 'Pedestrian']
12 0 1 0
Ground Truth:  12
processing time:  4.187922716140747
Record is:  [1287, 147, 1336, 192, 'Pedestrian']
Record is:  [1328, 147, 1375, 184, 'Pedestrian']
Record is:  [1333, 112, 1374, 148, 'Pedestrian']
Record is:  [1278, 52, 1328, 93, 'Pedestrian']
Record is:  [1286, 115, 1335, 155, 'Pedestrian']
Record is:  [1441, 104, 1520, 1

processing time:  3.822028160095215
Record is:  [768, 1274, 861, 1396, 'Biker']
Record is:  [212, 932, 297, 1050, 'Biker']
Record is:  [1106, 1132, 1188, 1220, 'Skater']
Record is:  [1049, 1191, 1124, 1298, 'Biker']
Record is:  [1000, 1250, 1082, 1347, 'Biker']
Record is:  [817, 1306, 864, 1393, 'Biker']
Record is:  [1317, 1115, 1381, 1211, 'Biker']
Record is:  [378, 515, 435, 612, 'Pedestrian']
Record is:  [245, 973, 282, 1048, 'Biker']
Record is:  [212, 933, 296, 1051, 'Pedestrian']
Record is:  [281, 1542, 324, 1605, 'Pedestrian']
Record is:  [1012, 1446, 1091, 1577, 'Biker']
Record is:  [202, 968, 243, 1039, 'Pedestrian']
Record is:  [1141, 1152, 1183, 1213, 'Biker']
Record is:  [746, 1281, 820, 1382, 'Biker']
Record is:  [220, 962, 270, 1049, 'Biker']
Record is:  [1106, 1132, 1188, 1220, 'Biker']
14 2 1 2
Ground Truth:  13
processing time:  3.8678016662597656
Record is:  [1330, 110, 1372, 149, 'Pedestrian']
Record is:  [1286, 148, 1331, 187, 'Pedestrian']
Record is:  [1330, 139, 13

processing time:  3.9903781414031982
Record is:  [1288, 150, 1332, 189, 'Pedestrian']
Record is:  [1327, 110, 1371, 148, 'Pedestrian']
Record is:  [1325, 140, 1371, 175, 'Pedestrian']
Record is:  [1290, 115, 1328, 150, 'Pedestrian']
Record is:  [1437, 113, 1506, 162, 'Pedestrian']
Record is:  [1269, 52, 1331, 92, 'Pedestrian']
Record is:  [1412, 149, 1463, 198, 'Pedestrian']
Record is:  [1015, 893, 1108, 952, 'Biker']
Record is:  [582, 351, 652, 408, 'Pedestrian']
Record is:  [1464, 737, 1516, 774, 'Pedestrian']
Record is:  [914, 882, 971, 935, 'Pedestrian']
Record is:  [1750, 877, 1838, 925, 'Pedestrian']
Record is:  [1287, 86, 1325, 123, 'Pedestrian']
Record is:  [1759, 916, 1844, 958, 'Pedestrian']
Record is:  [1249, 54, 1308, 93, 'Pedestrian']
14 0 1 1
Ground Truth:  14


In [15]:
print(total_true_positive, total_class_mismatch, total_false_positive, total_false_negative, total_gt)

656 66 85 174 845


gt_ann = get_gt_annotations(filepath)
print(gt_ann)
im = cv2.imread(filepath)
for each_ann in gt_ann:
    cv2.rectangle(im, (each_ann[0],each_ann[1]),(each_ann[2],each_ann[3]),(0,0,255),3)
cv2.imwrite('examples/results_test/pred_gt.jpg', im)

In [19]:
# 4763 687 2192 1067 5854

precision = total_true_positive/(total_true_positive+total_false_positive+total_class_mismatch)
recall = total_true_positive/(total_true_positive+ total_false_negative)
f1_score = 2*(precision * recall)/(precision + recall)
Accuracy=(total_true_positive+total_class_mismatch)/(total_true_positive+total_false_positive+total_class_mismatch+total_false_negative)
print("Accuracy:",Accuracy)

Accuracy: 0.7359836901121305


In [20]:
file1 = open(r"accuracy.txt", "w+")

In [22]:
file1.write(str(Accuracy))

18