In [None]:
!which python

In [None]:
import sys
import torch
import pathlib
import logging
import argparse
import numpy as np
from tqdm import tqdm

from utils.misc import str2bool, Timer
from utils import box_utils, measurements
from src.open_images import OpenImagesDataset
from src.network import create_network, create_network_predictor

# from eval import group_annotation_by_class, compute_average_precision_per_class

In [None]:
# Evaluation options, declared with argparse so the notebook mirrors a command-line script.
parser = argparse.ArgumentParser(description="Evaluation.")
parser.add_argument("--trained_model", type=str)
parser.add_argument("--dataset",       type=str, help="The root directory of the VOC dataset or Open Images dataset.")
parser.add_argument("--label_file",    type=str, help="The label file path.")
parser.add_argument("--use_cuda",      type=str2bool, default=True)
parser.add_argument("--nms_method",    type=str, default="hard")
parser.add_argument("--iou_threshold", type=float, default=0.5, help="The threshold of Intersection over Union.")
parser.add_argument("--eval_dir",      default="eval_results", type=str, help="The directory to store evaluation results.")
parser.add_argument('--width_mult',    default=1.0, type=float, help='Width Multiplier for Network')
parser.add_argument("--use_2007_metric", type=str2bool, default=True)

# An explicit empty argv keeps argparse from reading the kernel's own command line;
# every option therefore starts at its default (None where no default is declared).
args   = parser.parse_args([])
# Use the first CUDA device only when one is available and --use_cuda is left enabled.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")


In [None]:
# Paths for this evaluation run; these three options have no parser default.
vars(args).update(
    dataset='data/person_dog/',
    trained_model='models/Epoch-10-Loss-2.66624203515709.pth',
    label_file='models/labels.txt',
)


In [None]:

def group_annotation_by_class(dataset):
    """Group the ground-truth annotations of a dataset by class and by image.

    Args:
        dataset: object supporting ``len(dataset)`` and ``dataset.get_annotation(i)``.
            The latter must return ``(image_id, (gt_boxes, classes, is_difficult))``
            where ``gt_boxes`` is accepted by ``torch.from_numpy`` and the three
            sequences are aligned per box.

    Returns:
        A tuple ``(true_case_stat, all_gt_boxes, all_difficult_cases)``:

        * ``true_case_stat[class_index]`` - number of boxes of that class that are
          not flagged difficult (classes with only difficult boxes have no entry).
        * ``all_gt_boxes[class_index][image_id]`` - tensor stacking the boxes of that
          class in that image.
        * ``all_difficult_cases[class_index][image_id]`` - ``torch.uint8`` tensor of
          the difficult flags, aligned with the rows of the boxes tensor.
    """
    true_case_stat = {}
    all_gt_boxes = {}
    all_difficult_cases = {}
    for sample_index in range(len(dataset)):
        image_id, annotation = dataset.get_annotation(sample_index)
        gt_boxes, classes, is_difficult = annotation
        gt_boxes = torch.from_numpy(gt_boxes)
        # A separate index for the boxes so the sample index is not shadowed.
        for box_index, difficult in enumerate(is_difficult):
            class_index = int(classes[box_index])
            if not difficult:
                true_case_stat[class_index] = true_case_stat.get(class_index, 0) + 1

            boxes_by_image = all_gt_boxes.setdefault(class_index, {})
            boxes_by_image.setdefault(image_id, []).append(gt_boxes[box_index])
            flags_by_image = all_difficult_cases.setdefault(class_index, {})
            flags_by_image.setdefault(image_id, []).append(difficult)

    # Collapse the per-image lists into tensors (values are replaced, keys untouched).
    for boxes_by_image in all_gt_boxes.values():
        for image_id in boxes_by_image:
            boxes_by_image[image_id] = torch.stack(boxes_by_image[image_id])
    # The flags are converted here; previously this loop re-wrapped the boxes instead
    # and left the flags as plain lists.
    for flags_by_image in all_difficult_cases.values():
        for image_id in flags_by_image:
            flags_by_image[image_id] = torch.tensor(
                [int(flag) for flag in flags_by_image[image_id]], dtype=torch.uint8
            )
    return true_case_stat, all_gt_boxes, all_difficult_cases


def compute_average_precision_per_class(num_true_cases, gt_boxes, difficult_cases,
                                        prediction_file, iou_threshold, use_2007_metric):
    """Compute the average precision of one class from its detection file.

    Args:
        num_true_cases: denominator of recall (the caller passes the number of
            non-difficult ground-truth boxes of the class).
        gt_boxes: mapping ``image_id -> ground-truth boxes`` of the class; the value
            is handed to ``box_utils.iou_of`` together with one predicted box.
        difficult_cases: mapping ``image_id -> difficult flags``, indexable by the
            position of the ground-truth box.
        prediction_file: path of a tab-separated text file with one detection per
            line: ``image_id``, ``score``, then the box coordinates. The coordinates
            are stored 1-based and shifted back by one on load.
        iou_threshold: a detection matches a ground-truth box only when its best
            IoU is strictly greater than this value.
        use_2007_metric: when true use
            ``measurements.compute_voc2007_average_precision``, otherwise
            ``measurements.compute_average_precision``.

    Returns:
        The value returned by the selected ``measurements`` function for the
        cumulative precision and recall arrays.
    """
    image_ids = []
    boxes = []
    scores = []
    with open(prediction_file) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing on a missing field
            t = line.rstrip().split("\t")
            image_ids.append(t[0])
            scores.append(float(t[1]))
            box = torch.tensor([float(v) for v in t[2:]]).unsqueeze(0)
            box -= 1.0  # convert to python format where indexes start from 0
            boxes.append(box)

    # Visit the detections from the highest score to the lowest.
    scores = np.array(scores)
    sorted_indexes = np.argsort(-scores)
    boxes = [boxes[i] for i in sorted_indexes]
    image_ids = [image_ids[i] for i in sorted_indexes]

    true_positive = np.zeros(len(image_ids))
    false_positive = np.zeros(len(image_ids))
    matched = set()  # (image_id, ground-truth index) pairs already claimed
    for i, image_id in enumerate(image_ids):
        if image_id not in gt_boxes:
            false_positive[i] = 1
            continue

        ious = box_utils.iou_of(boxes[i], gt_boxes[image_id])
        max_iou = torch.max(ious).item()
        max_arg = torch.argmax(ious).item()
        if max_iou <= iou_threshold:
            false_positive[i] = 1
            continue
        # A match against a difficult box counts as neither a hit nor a miss.
        if difficult_cases[image_id][max_arg] == 0:
            if (image_id, max_arg) not in matched:
                true_positive[i] = 1
                matched.add((image_id, max_arg))
            else:
                # The ground-truth box was already claimed by a higher-scored detection.
                false_positive[i] = 1

    true_positive = true_positive.cumsum()
    false_positive = false_positive.cumsum()
    # Leading detections that only hit difficult boxes leave tp + fp at 0; guard the
    # denominator so precision is 0 there instead of NaN (which would poison the AP).
    precision = true_positive / np.maximum(true_positive + false_positive,
                                           np.finfo(np.float64).eps)
    recall = true_positive / num_true_cases
    if use_2007_metric:
        return measurements.compute_voc2007_average_precision(precision, recall)
    else:
        return measurements.compute_average_precision(precision, recall)


In [None]:
# Directory that receives the per-class detection files; parents are created too so a
# nested --eval_dir works.
eval_path   = pathlib.Path(args.eval_dir)
eval_path.mkdir(parents=True, exist_ok=True)
timer       = Timer()
# One class name per line. The file is closed as soon as the names are read.
with open(args.label_file) as label_fh:
    class_names = [name.strip() for name in label_fh]
dataset     = OpenImagesDataset(args.dataset, dataset_type="test")
# Ground truth grouped by class; the name `all_gb_boxes` is kept because the
# AP-reporting cell reads it.
true_case_stat, all_gb_boxes, all_difficult_cases = group_annotation_by_class(dataset)
net         = create_network(len(class_names), width_mult=args.width_mult, is_test=True)

In [None]:
# Load the trained weights, move the network to DEVICE and build the predictor,
# timing the whole step under the "Load Model" key.
timer.start("Load Model")
net.load(args.trained_model)
net = net.to(DEVICE)
predictor = create_network_predictor(
    net,
    nms_method=args.nms_method,
    device=DEVICE,
)
print(f'It took {timer.end("Load Model")} seconds to load the model.')

In [None]:
# Run the predictor on every image and collect one row per detection:
# [image index, label, score, box coordinates shifted by +1].
results = []
for image_index in tqdm(range(len(dataset))):
    timer.start("Load Image")
    image = dataset.get_image(image_index)
    timer.start("Predict")
    boxes, labels, probs = predictor.predict(image)
    indexes = torch.full((labels.size(0), 1), float(image_index), dtype=torch.float32)
    detection_rows = torch.cat(
        [
            indexes.reshape(-1, 1),
            labels.reshape(-1, 1).float(),
            probs.reshape(-1, 1),
            boxes + 1.0,  # matlab's indexes start from 1
        ],
        dim=1,
    )
    results.append(detection_rows)
results = torch.cat(results)

# Write one tab-separated detection file per foreground class:
# image id, score, box coordinates.
for class_index, class_name in enumerate(class_names):
    if class_index == 0:
        continue  # ignore background
    prediction_path = eval_path / f"det_test_{class_name}.txt"
    class_rows = results[results[:, 1] == class_index, :]
    with open(prediction_path, "w") as f:
        for row_index in range(class_rows.size(0)):
            score_and_box = class_rows[row_index, 2:].numpy()
            image_id = dataset.ids[int(class_rows[row_index, 0])]
            fields = "\t".join(str(value) for value in score_and_box)
            f.write(image_id + "\t" + fields + "\n")

In [None]:
# Compute the average precision of every foreground class from the detection files
# written to eval_path, then report the mean over the evaluated classes.
aps = []
print("Average Precision Per-class:")
for class_index, class_name in enumerate(class_names):
    if class_index == 0:
        continue  # ignore background
    if class_index not in true_case_stat:
        # No non-difficult ground-truth box of this class in the test split: recall is
        # undefined, so skip the class instead of failing with a KeyError.
        print(f"{class_name}: skipped (no non-difficult ground-truth boxes)")
        continue
    prediction_path = eval_path / f"det_test_{class_name}.txt"
    ap = compute_average_precision_per_class(
        true_case_stat[class_index],
        all_gb_boxes[class_index],
        all_difficult_cases[class_index],
        prediction_path,
        args.iou_threshold,
        args.use_2007_metric
    )
    aps.append(ap)
    print(f"{class_name}: {ap}")

if aps:
    print(f"\nAverage Precision Across All Classes: {sum(aps)/len(aps)}")
else:
    # Avoid a ZeroDivisionError when no class could be evaluated.
    print("\nAverage Precision Across All Classes: n/a (no class was evaluated)")

In [None]:
# Average Precision Per-class:
# Dog: 0.7447672840095942
# Person: 0.5551671470337526

# Average Precision Across All Classes: 0.6499672155216734