In [1]:
import os
import xml.etree.ElementTree as ET
import numpy as np
from matplotlib import pyplot as plt


def extract_pos(image_id):
    """
    Extract positive templates and random negative crops from one image.

    Positives: each annotated object named 'waldo', 'wenda' or 'wizard' is
    cropped and saved to ``cascade_<name>/p`` as ``<name><image_id>_<k>.jpg``.
    Objects with any other name are reported and skipped.

    Negatives: random crops whose width and height are a fraction ``scale`` of
    the image size, for scale = 0.7, 0.6, ..., 0.2. At each scale up to 5
    random placements are tried; the first crop that avoids all boxes of at
    least one character is saved (as ``neg<image_id>_<j>.jpg``) to
    ``cascade_<name>/n`` for every character whose boxes it does not
    intersect, and the loop moves on to the next scale.

    Crop positions come from the global ``np.random`` state; seed it for
    reproducible output.

    :param image_id: id of image to extract from; reads
        ``datasets/JPEGImages/<image_id>.jpg`` and
        ``datasets/Annotations/<image_id>.xml``
    """
    image_dir = 'datasets/JPEGImages'
    anno_dir = 'datasets/Annotations'
    image_file = os.path.join(image_dir, '{}.jpg'.format(image_id))
    anno_file = os.path.join(anno_dir, '{}.xml'.format(image_id))
    assert os.path.exists(image_file), '{} not found.'.format(image_file)
    assert os.path.exists(anno_file), '{} not found.'.format(anno_file)

    names = ('waldo', 'wenda', 'wizard')
    pos_paths = {name: 'cascade_{}/p'.format(name) for name in names}
    neg_paths = {name: 'cascade_{}/n'.format(name) for name in names}
    # plt.imsave does not create directories, so make sure they exist up front.
    for path in list(pos_paths.values()) + list(neg_paths.values()):
        os.makedirs(path, exist_ok=True)

    image = np.asarray(plt.imread(image_file))
    bboxes = _save_positives(image, image_id, anno_file, pos_paths)
    _save_negatives(image, image_id, bboxes, neg_paths)


def _save_positives(image, image_id, anno_file, pos_paths):
    """
    Crop and save every annotated object whose name is a key of ``pos_paths``.

    :param image: image array indexed as image[row, col, ...]
    :param image_id: id used to build the output file names
    :param anno_file: path to the XML annotation file
    :param pos_paths: dict mapping character name -> output directory
    :return: dict mapping each name in ``pos_paths`` to its list of
        (x, y, w, h) boxes
    """
    bboxes = {name: [] for name in pos_paths}
    for obj in ET.parse(anno_file).findall('object'):
        name = obj.find('name').text
        bndbox = obj.find('bndbox')
        x1, y1, x2, y2 = (int(bndbox.find(tag).text)
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        if name not in bboxes:
            print('name {} is invalid'.format(name))
            continue
        # Per-character running index: <name><id>_0, <name><id>_1, ...
        img_name = '{}{}_{}.jpg'.format(name, image_id, len(bboxes[name]))
        bboxes[name].append((x1, y1, x2 - x1, y2 - y1))
        # cmap only affects 2-D (grayscale) input; RGB arrays ignore it.
        plt.imsave(os.path.join(pos_paths[name], img_name),
                   image[y1:y2, x1:x2], cmap='gray')
    return bboxes


def _save_negatives(image, image_id, bboxes, neg_paths,
                    max_scale=0.7, min_scale=0.2, scale_step=0.1, max_iter=5):
    """
    Save random background crops as negatives, trying each scale in turn.

    Scales run from ``max_scale`` down to ``min_scale`` (inclusive) in steps of
    ``scale_step``. They are derived from an integer step index rather than by
    repeatedly subtracting ``scale_step``, which accumulates float error and
    can overshoot ``min_scale``.

    :param image: image array indexed as image[row, col, ...]
    :param image_id: id used to build the output file names
    :param bboxes: dict mapping character name -> list of (x, y, w, h) boxes
    :param neg_paths: dict mapping character name -> output directory
    :param max_iter: random placements tried per scale before giving up on it
    """
    img_h, img_w = image.shape[:2]
    n_steps = int(round((max_scale - min_scale) / scale_step))
    j = 0  # running index of saved negatives, shared by all characters
    for k in range(n_steps + 1):
        scale = round(max_scale - k * scale_step, 10)
        w = int(img_w * scale)
        h = int(img_h * scale)
        for _ in range(max_iter):
            # randint's upper bound is exclusive; +1 lets the crop touch the
            # right/bottom edge of the image.
            x = np.random.randint(0, img_w - w + 1)
            y = np.random.randint(0, img_h - h + 1)
            box = (x, y, w, h)
            # Characters for which this crop overlaps none of their boxes.
            clear = [name for name, boxes in bboxes.items()
                     if not any(is_intersect(box, other) for other in boxes)]
            if clear:
                img_name = '{}{}_{}.jpg'.format('neg', image_id, j)
                crop = image[y:y + h, x:x + w]
                for name in clear:
                    plt.imsave(os.path.join(neg_paths[name], img_name),
                               crop, cmap='gray')
                j += 1
                break
                    

def is_intersect(bbox1, bbox2):
    """
    Check whether two axis-aligned boxes overlap with positive area.

    Boxes are ``(x, y, w, h)`` tuples: top-left corner plus width and height.
    Boxes that merely touch along an edge or at a corner, and degenerate boxes
    with zero or negative width/height, are treated as not intersecting.

    :param bbox1: first box as (x, y, w, h)
    :param bbox2: second box as (x, y, w, h)
    :return: True if the overlap area is strictly positive, else False
    """
    x1, y1, w1, h1 = bbox1
    x2, y2, w2, h2 = bbox2
    # Extent of the overlap rectangle along each axis; it is <= 0 when the
    # boxes are separated (or only touching) along that axis.
    overlap_w = min(x1 + w1, x2 + w2) - max(x1, x2)
    overlap_h = min(y1 + h1, y2 + h2) - max(y1, y2)
    return overlap_w > 0 and overlap_h > 0

In [None]:
# Extract positive templates and random negatives for every image id listed
# in the train and val split files (one id per line).
np.random.seed(0)  # make the random negative crops reproducible across runs
for split in ('train', 'val'):
    split_file = os.path.join('datasets/ImageSets', '{}.txt'.format(split))
    with open(split_file) as f:
        for line in f:
            img_id = line.strip()
            if img_id:  # skip blank lines, e.g. a trailing empty line at EOF
                extract_pos(img_id)