In [7]:
import os
import xml.etree.ElementTree as ET
import numpy as np
from matplotlib import pyplot as plt


def _read_bboxes(anno_file):
    """
    Read every object's bounding box from an annotation XML file.

    :param anno_file: path to the annotation file; each <object> must contain a
        <bndbox> with integer <xmin>, <ymin>, <xmax>, <ymax> children
    :return: list of (x, y, w, h) tuples, in file order
    """
    bboxes = []
    for obj in ET.parse(anno_file).findall('object'):
        bbox = obj.find('bndbox')
        x1 = int(bbox.find('xmin').text)
        y1 = int(bbox.find('ymin').text)
        x2 = int(bbox.find('xmax').text)
        y2 = int(bbox.find('ymax').text)
        bboxes.append((x1, y1, x2 - x1, y2 - y1))
    return bboxes


def extract_negatives(image_id, image_dir='datasets/JPEGImages',
                      anno_dir='datasets/Annotations',
                      out_dir='negatives_same_scale', max_tries=1000):
    """
    Extract negative (background) crops from an image.

    For each annotated object, sample one random window that does not overlap
    any annotated box and save it as ``{out_dir}/{image_id}_{i}_2.jpg``. Every
    window has the size of the *last* annotated box in the file. A crop is
    skipped if no free window is found within ``max_tries`` random placements.

    :param image_id: id of image to extract from
    :param image_dir: directory containing ``{image_id}.jpg``
    :param anno_dir: directory containing ``{image_id}.xml``
    :param out_dir: directory the crops are written to; created if missing
    :param max_tries: random placements attempted per crop before giving up
    :raises AssertionError: if the image or annotation file does not exist
    """
    image_file = os.path.join(image_dir, '{}.jpg'.format(image_id))
    anno_file = os.path.join(anno_dir, '{}.xml'.format(image_id))
    assert os.path.exists(image_file), '{} not found.'.format(image_file)
    assert os.path.exists(anno_file), '{} not found.'.format(anno_file)

    bboxes = _read_bboxes(anno_file)
    if not bboxes:
        return
    # lets just use the last w and h for every negative
    _, _, w, h = bboxes[-1]

    image = np.asarray(plt.imread(image_file))
    img_h, img_w = image.shape[0], image.shape[1]
    if w <= 0 or h <= 0 or w > img_w or h > img_h:
        return  # no non-empty window of this size fits inside the image

    os.makedirs(out_dir, exist_ok=True)
    for i in range(len(bboxes)):
        # bounded retries: an image densely covered by boxes would otherwise loop forever
        for _ in range(max_tries):
            # randint's upper bound is exclusive; +1 permits windows touching the right/bottom edge
            x = np.random.randint(0, img_w - w + 1)
            y = np.random.randint(0, img_h - h + 1)
            box = (x, y, w, h)
            if not any(is_intersect(box, other) for other in bboxes):
                plt.imsave('{}/{}_{}_2.jpg'.format(out_dir, image_id, i), image[y:y + h, x:x + w])
                break
    

def is_intersect(bbox1, bbox2):
    """
    Test whether two axis-aligned boxes overlap with positive area.

    :param bbox1: box as (x, y, w, h), with (x, y) its minimum corner
    :param bbox2: box as (x, y, w, h)
    :return: True if the boxes share a region of non-zero area, else False.
        Boxes that only touch along an edge or at a corner do not intersect.
    """
    x1, y1, w1, h1 = bbox1
    x2, y2, w2, h2 = bbox2
    overlap_w = min(x1 + w1, x2 + w2) - max(x1, x2)
    overlap_h = min(y1 + h1, y2 + h2) - max(y1, y2)
    # Require each extent to be positive separately: multiplying them would let
    # two negative extents produce a spurious positive "area".
    # bool() also normalises numpy scalar comparisons to a Python bool.
    return bool(overlap_w > 0 and overlap_h > 0)

In [8]:
# Fix the RNG so the sampled negative windows are reproducible across runs.
SEED = 0
np.random.seed(SEED)

# Extract negatives for every image id listed (one per line) in the training split.
filename = 'datasets/ImageSets/train.txt'
with open(filename) as f:
    for line in f:
        img_id = line.strip()
        if not img_id:
            continue  # skip blank lines, e.g. a trailing empty line at end of file
        extract_negatives(img_id)