In [1]:
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import mmcv
import json
import glob
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageColor, ImageFilter, ImageDraw
import xml.etree.ElementTree as ET
from tqdm import tqdm

In [2]:
class F1Score:
    def __init__(self, IoU_thresh=0.5, area_thresh=1000):
        self.true_positive = 0
        self.false_positive = 0
        self.truth_count = 0
        self._iou_thresh = IoU_thresh
        self._area_thresh = area_thresh
        self._iou_total = 0
        self._iou_count = 0

    def update_state(self, n_truth, IoUs, areas=None):
        if areas is not None:
            IoUs = [i for (i, a) in zip(IoUs, areas) if a > self._area_thresh]
        self._iou_total += sum(IoUs)
        self._iou_count += len(IoUs)
        n_tp = sum(1 for i in IoUs if i >= self._iou_thresh)
        n_fp = len(IoUs) - n_tp
        self.true_positive += n_tp
        self.false_positive += n_fp
        self.truth_count += n_truth

    @property
    def precision(self):
        return self.true_positive / max(1, self.true_positive + self.false_positive)

    @property
    def recall(self):
        return self.true_positive / max(1, self.truth_count)

    @property
    def f1(self):
        return (2 * self.precision * self.recall) / max(1, self.precision + self.recall)

    @property
    def iou(self):
        return self._iou_total / max(1, self._iou_count)

In [3]:
def read_sample(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    bboxes = []
    for object_ in root.iter('object'):
        ymin, xmin, ymax, xmax = None, None, None, None
        for box in object_.findall("bndbox"):
            ymin = int(box.find("ymin").text)
            xmin = int(box.find("xmin").text)
            ymax = int(box.find("ymax").text)
            xmax = int(box.find("xmax").text)

        bbox = [xmin, ymin, xmax, ymax] # PASCAL VOC
        bboxes.append(bbox)
    im_file = root.find("filename").text
    return im_file, bboxes

In [6]:
def load_model():
    config_file = 'Config/cascade_mask_rcnn_hrnetv2p_w32_20e.py'
    checkpoint_file = 'epoch_36.pth'
    return init_detector(config_file, checkpoint_file, device='cuda:0')

In [27]:
def area(bbox):
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def detection_bbox_match(true_bbox, pred_bbox, label):
    pred_bbox = [x[:4] for x, y in zip(pred_bbox, label) if y == 0]
    true_matched = set()
    ious = list()
    for box_p in pred_bbox:
        best_iou = 0
        best_match = None
        for t, box_t in enumerate(true_bbox):
            if t in true_matched:
                continue
            I_xmin = max(box_p[0], box_t[0])
            I_xmax = min(box_p[2], box_t[2])
            I_ymin = max(box_p[1], box_t[1])
            I_ymax = min(box_p[3], box_t[3])
            
            if I_xmax <= I_xmin or I_ymax <= I_ymin:
                continue
            I = (I_xmax - I_xmin) * (I_ymax - I_ymin)
            
            U = area(box_p) + area(box_t) -  I 
            this_iou = I / U
            if this_iou > best_iou:
                best_iou = this_iou
                best_match = t
        if best_match is not None:
            ious.append(best_iou)
            true_matched.add(best_match)
    return ious

In [49]:
color_map = {
    0: 'red', 
    1: 'green',   # cell
    2: 'blue',  # table (bordered)
}

def show_results(bbox, labels, im_file, out_file):
    im = Image.open(im_file)
    draw = ImageDraw.Draw(im)
    for box, label in zip(bbox, labels):
        draw.rectangle(box[:4], outline=color_map[label], width=5)
    im.save(out_file)

In [12]:
model = load_model()

In [44]:
DATA_PATH = "/home/ubuntu/detection"
OUT_PATH = "/home/ubuntu/cascade_out"

test_set = list()
with open(f"{DATA_PATH}/test_filelist.txt") as fp:
    for line in fp:
        test_set.append(line.strip())

In [51]:
metrics = [F1Score(0.5), F1Score(0.75), F1Score(0.95)]
for xml_file in tqdm(test_set):
    page_id, true_bboxes = read_sample(f"{DATA_PATH}/{xml_file}")
    im_file = f"{DATA_PATH}/images/{page_id}"
    # Run Inference
    result = inference_detector(model, im_file)

    bbox_result, _ = result
    labels = [np.full(bbox.shape[0], i, dtype=np.int32) 
            for i, bbox in enumerate(bbox_result)]
    labels = np.concatenate(labels).tolist()
    bboxes = np.vstack(bbox_result).tolist()
    IoUs = detection_bbox_match(true_bboxes, bboxes, labels)
    for m in metrics:
        m.update_state(n_truth=len(true_bboxes), IoUs=IoUs)
    show_results(bboxes, labels, im_file, f"{OUT_PATH}/detection/{page_id}")

  "See the documentation of nn.Upsample for details.".format(mode))
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3809/3809 [1:00:16<00:00,  1.05it/s]


In [52]:
# for m in metrics:
#     print(m.precision, m.recall, m.f1)

# # 0.9807787087235091 0.1587554846429996 0.2732765723702279
# # 0.9615574174470183 0.1556441962504986 0.2679209008514144
# # 0.6604238541153278 0.10690067810131632 0.14119951567842667

0.9807787087235091 0.1587554846429996 0.2732765723702279
0.9615574174470183 0.1556441962504986 0.2679209008514144
0.6604238541153278 0.10690067810131632 0.14119951567842667
