In [1]:
from utils import *
from datasets import PascalVOCDataset
from tqdm import tqdm
from pprint import PrettyPrinter

import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as td
import torchvision as tv
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt
import torchsummary
import traceback

import warnings
warnings.filterwarnings('ignore')

import torchvision.datasets.voc as voc

In [2]:
# Good formatting when printing the APs for each class and mAP
pp = PrettyPrinter()

In [3]:
# Parameters
data_folder = './'
keep_difficult = True  # difficult ground truth objects must always be considered in mAP calculation, because these objects DO exist!
batch_size = 64
workers = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = './BEST_checkpoint_ssd300.pth.tar'

In [4]:
# Load model checkpoint that is to be evaluated
checkpoint = torch.load(checkpoint)
model = checkpoint['model']
model = model.to(device)

In [5]:
def transform_voc(img, target, img_size=(800, 800)):
    """
    """

    transform = tv.transforms.Compose([
            tv.transforms.Resize(img_size), # Using default interpolation
            tv.transforms.ToTensor(),
            tv.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    x = transform(img)
    return x, target

In [6]:
import sys
if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

class PascalVOCDataset1(voc.VOCDetection):
    """
    """
    
    def __getclasses__(self):
        classes = ('background',
                   'aeroplane',
                   'bicycle',
                   'bird',
                   'boat',
                   'bottle',
                   'bus',
                   'car',
                   'cat',
                   'chair',
                   'cow',
                   'diningtable',
                   'dog',
                   'horse',
                   'motorbike',
                   'person',
                   'pottedplant',
                   'sheep',
                   'sofa',
                   'train',
                   'tvmonitor')

        return classes
    
    def get_lbl_data(self, lbl, yscale=1, xscale=1):
        """
        """

        lbl_arr = np.asarray(self.__getclasses__())
        
        classes = np.zeros((64, 1))
        boxes = np.zeros((64, 4))

        if(type(lbl['annotation']['object']) == dict):
            obj = lbl['annotation']['object']
            bbox = obj['bndbox']
            name = obj['name']
            lbl_class = np.where(lbl_arr == name)[0][0]
            lbl_box = int(bbox['xmin'])//xscale, int(bbox['ymin'])//yscale, int(bbox['xmax'])//xscale, int(bbox['ymax'])//yscale
            classes[0] = lbl_class
            boxes[0] = lbl_box
        else:
            for j, obj in enumerate(lbl['annotation']['object']):
                bbox = obj['bndbox']
                name = obj['name']
                lbl_class = np.where(lbl_arr == name)[0][0]
                lbl_box = int(bbox['xmin'])//xscale, int(bbox['ymin'])//yscale, int(bbox['xmax'])//xscale, int(bbox['ymax'])//yscale
                classes[j] = lbl_class
                boxes[j] = lbl_box

        classes = torch.Tensor(classes).type(torch.float32)
        boxes = torch.Tensor(boxes).type(torch.float32)
        targets = torch.cat((classes, boxes), axis=1)
#         targets = {'classes': classes, 'boxes': boxes}
        return targets
        
    def draw_box(self, image, box, color):
        """Draw 3-pixel width bounding boxes on the given image array.
        color: list of 3 int values for RGB.
        """


        y1, x1, y2, x2 = box
        image[0, y1:y1 + 2, x1:x2] = color[0]
        image[0, y2:y2 + 2, x1:x2] = color[0]
        image[0, y1:y2, x1:x1 + 2] = color[0]
        image[0, y1:y2, x2:x2 + 2] = color[0]

        image[1, y1:y1 + 2, x1:x2] = color[1]
        image[1, y2:y2 + 2, x1:x2] = color[1]
        image[1, y1:y2, x1:x1 + 2] = color[1]
        image[1, y1:y2, x2:x2 + 2] = color[1]

        image[2, y1:y1 + 2, x1:x2] = color[2]
        image[2, y2:y2 + 2, x1:x2] = color[2]
        image[2, y1:y2, x1:x1 + 2] = color[2]
        image[2, y1:y2, x2:x2 + 2] = color[2]
        return image
    
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        targets = self.parse_voc_xml(
            ET.parse(self.annotations[index]).getroot())

        yscale = img.size[1]/800
        xscale = img.size[0]/800
        
        if self.transforms is not None:
            img, targets = self.transforms(img, targets)
        
        targets = self.get_lbl_data(targets, yscale, xscale)
        
        return img, targets

    def number_of_classes(self):
        return len(self.__getclasses__())

In [7]:
# Switch to eval mode
model.eval()


# Load test data
#test_dataset = PascalVOCDataset1(root='VOC2012', year='2012', image_set='val', transforms=transform_voc, download=False)
test_dataset = PascalVOCDataset(data_folder,
                                split='train',
                                keep_difficult=keep_difficult)
#test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                          collate_fn=test_dataset.collate_fn, num_workers=workers, pin_memory=True)

In [8]:
def evaluate(test_loader, model):
    """
    Evaluate.

    :param test_loader: DataLoader for test data
    :param model: model
    """

    # Make sure it's in eval mode
    model.eval()

    # Lists to store detected and true boxes, labels, scores
    det_boxes = list()
    det_labels = list()
    det_scores = list()
    true_boxes = list()
    true_labels = list()
    true_difficulties = list()  # it is necessary to know which objects are 'difficult', see 'calculate_mAP' in utils.py

    with torch.no_grad():
        # Batches
        for i, (images, boxes, labels, difficulties) in enumerate(tqdm(test_loader, desc='Evaluating')):
            images = images.to(device)  # (N, 3, 300, 300)

            # Forward prop.
            predicted_locs, predicted_scores = model(images)

            # Detect objects in SSD output
            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(predicted_locs, predicted_scores,
                                                                                       min_score=0.01, max_overlap=0.45,
                                                                                       top_k=200)
            # Evaluation MUST be at min_score=0.01, max_overlap=0.45, top_k=200 for fair comparision with the paper's results and other repos

            # Store this batch's results for mAP calculation
            boxes = [b.to(device) for b in boxes]
            labels = [l.to(device) for l in labels]
            difficulties = [d.to(device) for d in difficulties]

            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)
            true_boxes.extend(boxes)
            true_labels.extend(labels)
            true_difficulties.extend(difficulties)

        # Calculate mAP
        APs, mAP = calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties)

    # Print AP for each class
    pp.pprint(APs)

    print('\nMean Average Precision (mAP): %.3f' % mAP)

In [None]:
evaluate(test_loader, model)

Evaluating:  86%|████████▋ | 224/259 [1:21:18<11:45, 20.16s/it]