In [None]:
import os, time
import cv2
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
import random
from pathlib import Path
# Input directories: XML annotation files and the .jpg images that share
# their base names.
SEED_XML_DIR = '/data/darknet/python/xml/'
SEED_IMG_DIR = '/data/darknet/python/image/'
# Output directories: images are written under GENE_IMG_DIR and the
# rewritten XML trees under GENE_XML_DIR.
GENE_IMG_DIR = '../generated/image_fillter_obj/'
GENE_XML_DIR = '../generated/xml_fillter_obj/'
# RT:RightTop 
# LB:LeftBottom 
# bbox: [xmin, xmax, ymin, ymax]
def IOU(bbox_a, bbox_b):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both arguments are indexable as ``[xmin, xmax, ymin, ymax]``.

    Returns ``0`` when the overlap has non-positive width or height;
    otherwise ``intersection / (area_a + area_b - intersection)``.
    """
    def area(box):
        # Width times height of a single box.
        return (box[1] - box[0]) * (box[3] - box[2])

    overlap_w = min(bbox_a[1], bbox_b[1]) - max(bbox_a[0], bbox_b[0])
    overlap_h = min(bbox_a[3], bbox_b[3]) - max(bbox_a[2], bbox_b[2])
    if overlap_w <= 0 or overlap_h <= 0:
        return 0

    intersection = overlap_w * overlap_h
    union = area(bbox_a) + area(bbox_b) - intersection
    return intersection / union

def get_obj_from_xml(xml):
    """Parse an annotation file and return all of its ``object`` elements.

    Args:
        xml: path (or open file object) of the XML file to parse.

    Returns:
        A list of every ``object`` element found under the document root,
        in document order.
    """
    # Parse the argument itself (the previous code read an unrelated global
    # name). When given a path, ET.parse opens and closes the file on its
    # own, so no handle is left open.
    root = ET.parse(xml).getroot()
    return list(root.iter('object'))

def get_obj_from_image_file(file, bbox):
    """Read an image with OpenCV and return the region covered by ``bbox``.

    ``bbox`` is ``[xmin, xmax, ymin, ymax]``; each coordinate is truncated
    with ``int`` before slicing (rows ymin:ymax, columns xmin:xmax).
    """
    frame = cv2.imread(file)
    top, bottom = int(bbox[2]), int(bbox[3])
    left, right = int(bbox[0]), int(bbox[1])
    return frame[top:bottom, left:right]

def get_bboxes_from_etree(etree):
    """Collect the bounding box of every top-level ``object`` in a parsed tree.

    Args:
        etree: a parsed ElementTree whose root has ``object`` children, each
            holding a ``bndbox`` with ``xmin``/``xmax``/``ymin``/``ymax``.

    Returns:
        A list of ``[xmin, xmax, ymin, ymax]`` float lists, one per object,
        in document order.
    """
    # Use the tree that was passed in (the previous code read an unrelated
    # global name instead of the parameter).
    root = etree.getroot()
    bboxes = []
    for obj in root.findall('object'):
        xmlbox = obj.find('bndbox')
        bboxes.append(
            [float(xmlbox.find(tag).text)
             for tag in ('xmin', 'xmax', 'ymin', 'ymax')]
        )
    return bboxes
def past_to_background_from_image_file(file, bboxes, background_img_array, extend_spaces=0):
    """Copy the boxed regions of an image file onto a background and save it.

    The image at ``file`` is read with OpenCV. If its shape differs from
    ``background_img_array`` a message is printed and nothing is written.
    Otherwise every ``[xmin, xmax, ymin, ymax]`` region in ``bboxes`` is
    copied from the image into the same location of
    ``background_img_array`` (modified in place), and the result is written
    to ``GENE_IMG_DIR`` under the file's base name.

    ``extend_spaces`` is accepted but not used.
    """
    source = cv2.imread(file)
    if source.shape != background_img_array.shape:
        print('shape not match')
        return

    def region(box):
        # Row/column slices for one [xmin, xmax, ymin, ymax] box.
        return (slice(int(box[2]), int(box[3])),
                slice(int(box[0]), int(box[1])))

    # Cut out every patch first, then paste them in the same order.
    patches = [source[region(box)] for box in bboxes]
    for box, patch in zip(bboxes, patches):
        background_img_array[region(box)] = patch

    out_name = file.split('/')[-1]
    cv2.imwrite(GENE_IMG_DIR + out_name, background_img_array)
    return

# Walk the seed annotations in sorted filename order. For each one, drop
# objects that barely moved since the previous annotation and objects that
# are too small, then write the image and the pruned XML to the output
# directories.
seed_xml_names = os.listdir(SEED_XML_DIR)
seed_xml_names.sort()


classes_ = ['bus', 'car', 'truck', 'motorbike', 'bicycle', 'person']
temp_index_array = [900]
# Boxes of the accepted-class objects seen in the previously processed file.
last_frame_bboxes = []
for xml_name in seed_xml_names:
    # Debug filter (disabled): restrict processing to files whose first
    # eight characters, read as an int, appear in temp_index_array.
    #   head = xml_name[:8]
    #   if(not (int(head) in temp_index_array)):
    #       continue
    if not xml_name.endswith('.xml'):
        continue

    img_path = SEED_IMG_DIR + xml_name[:-3] + 'jpg'
    img_data = cv2.imread(img_path)
    # Passing the path lets ElementTree open and close the file itself, so
    # no file handle is left open for each processed annotation.
    tree = ET.parse(SEED_XML_DIR + xml_name)
    root = tree.getroot()

    this_frame_bboxes = []
    boxes = []
    objs = []
    for obj in root.findall('object'):
        difficult = obj.find('difficult').text
        cls_ = obj.find('name').text
        # Objects of other classes or flagged difficult are ignored here;
        # they are left in the XML tree untouched.
        if cls_ not in classes_ or int(difficult) == 1:
            continue
        xmlbox = obj.find('bndbox')
        b = [float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text)]
        # Recorded before any filtering so the next file is compared against
        # every accepted-class box of this one.
        this_frame_bboxes.append(b)

        # Drop objects that overlap a box from the previous file almost
        # completely (IOU above 0.85).
        if any(IOU(stable_bbox, b) > 0.85 for stable_bbox in last_frame_bboxes):
            root.remove(obj)
            continue

        # Filter out objects smaller than 16x16 (fit for SSD300).
        # NOTE(review): a box is removed only when BOTH sides are <= 16.
        bbox_width = b[1] - b[0]
        bbox_heigt = b[3] - b[2]
        if not (bbox_width > 16 or bbox_heigt > 16):
            print('remove:{0}:{1}'.format(bbox_width, bbox_heigt))
            root.remove(obj)
            continue
        print('add:{0}:{1}'.format(bbox_width, bbox_heigt))
        boxes.append(b)
        # Append the element itself; "objs += obj" would extend the list
        # with the element's children instead.
        objs.append(obj)
    last_frame_bboxes = this_frame_bboxes

    # Nothing kept for this file: write neither the image nor the XML.
    if not boxes:
        continue
    past_to_background_from_image_file(img_path, boxes, img_data)
    tree.write(GENE_XML_DIR + xml_name)