In [4]:
import os,datetime
import cv2
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
import random
from pathlib import Path

# Date stamp (YYYYMMDD) embedded in the names of this run's directories.
time_str = datetime.datetime.now().strftime("%Y%m%d")

# Directories listed / read by the driver loop further down.
xml_dir = '/data/darknet/python/' + time_str + '_xmls_replaced_ordinary_obj_with_bkg/'
image_dir = '/data/darknet/python/' + time_str + '_images_replaced_ordinary_obj_with_bkg/'

# Directories the driver loop writes its generated samples into.
generated_image_dir = '/data/darknet/python/' + time_str + '_image_expand_sample/'
generated_xml_dir = '/data/darknet/python/' + time_str + '_xml_expand_sample/'

# Create all four directories up front; `mkdir -p` is a no-op for ones
# that already exist.
for _run_dir in (xml_dir, image_dir, generated_image_dir, generated_xml_dir):
    os.system('mkdir -p ' + _run_dir)

# Upper-case aliases used by the functions and the driver loop below.
SEED_XML_DIR, SEED_IMG_DIR = xml_dir, image_dir
GENE_IMG_DIR, GENE_XML_DIR = generated_image_dir, generated_xml_dir
BKGD_IMG_DIR = '/data/darknet/python/20190603_background/'

# RT: right-top corner
# LB: left-bottom corner
# bbox: [xmin, xmax, ymin, ymax]
def IOU(bbox_a, bbox_b):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both boxes are indexed as ``[xmin, xmax, ymin, ymax]``.

    Returns:
        ``0`` when the overlap has no positive width or height (boxes that
        merely touch count as non-overlapping); otherwise
        ``intersection_area / (area_a + area_b - intersection_area)``.
    """
    # Extent of the overlap along each axis; non-positive means "no overlap".
    overlap_w = min(bbox_a[1], bbox_b[1]) - max(bbox_a[0], bbox_b[0])
    overlap_h = min(bbox_a[3], bbox_b[3]) - max(bbox_a[2], bbox_b[2])
    if overlap_w <= 0 or overlap_h <= 0:
        return 0

    area_a = (bbox_a[1] - bbox_a[0]) * (bbox_a[3] - bbox_a[2])
    area_b = (bbox_b[1] - bbox_b[0]) * (bbox_b[3] - bbox_b[2])
    intersection = overlap_w * overlap_h
    union = area_a + area_b - intersection
    return intersection / union

def get_obj_from_xml(xml_name):
    """Return every ``<object>`` element found in an annotation XML file.

    Args:
        xml_name: Path of the XML file to parse.

    Returns:
        A list of ``Element`` instances, in document order, for every
        ``object`` tag found anywhere under the root.
    """
    # Passing the path lets ElementTree open *and close* the file itself
    # (the previous version leaked the handle returned by open()), and read
    # it as bytes so the XML's own encoding declaration is honoured.
    root = ET.parse(xml_name).getroot()
    return list(root.iter('object'))

def get_obj_from_image_file(file, bbox):
    """Crop the region described by *bbox* out of an image file.

    Args:
        file: Path of the image; loaded with ``cv2.imread``.
        bbox: ``[xmin, xmax, ymin, ymax]``; each value is truncated with
            ``int()`` before slicing.

    Returns:
        The slice ``img[ymin:ymax, xmin:xmax]`` of the loaded array.

    Raises:
        IOError: If ``cv2.imread`` could not load *file*.
    """
    img = cv2.imread(file)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the path rather than with an opaque TypeError below.
        raise IOError('could not read image: {}'.format(file))
    row_span = slice(int(bbox[2]), int(bbox[3]))
    col_span = slice(int(bbox[0]), int(bbox[1]))
    return img[row_span, col_span]

def get_bboxes_from_etree(etree):
    """Collect the bounding boxes of the ``<object>`` children of the root.

    Args:
        etree: An ``ElementTree`` whose root holds ``object`` elements, each
            with a ``bndbox`` child containing ``xmin``/``xmax``/``ymin``/
            ``ymax``.

    Returns:
        A list with one ``[xmin, xmax, ymin, ymax]`` list of floats per
        object, in document order.
    """
    coordinate_tags = ('xmin', 'xmax', 'ymin', 'ymax')
    collected = []
    # findall() only looks at direct children of the root.
    for obj_node in etree.getroot().findall('object'):
        box_node = obj_node.find('bndbox')
        collected.append([float(box_node.find(tag).text) for tag in coordinate_tags])
    return collected
def past_to_background_from_image_file(file, bboxes, background_img_array, extend_spaces=0):
    """Copy the boxed regions of an image file onto a background and save it.

    Args:
        file: Path of the source image; loaded with ``cv2.imread``.
        bboxes: Iterable of ``[xmin, xmax, ymin, ymax]`` boxes to copy.
        background_img_array: Array of the same shape as the source image.
            It is modified in place.
        extend_spaces: Accepted but not used by this function.

    Returns:
        None. When the shapes differ, prints ``shape not match`` and writes
        nothing; otherwise writes the result into ``GENE_IMG_DIR`` under the
        source file's base name.
    """
    source = cv2.imread(file)
    if source.shape != background_img_array.shape:
        print('shape not match')
        return

    def _region(box):
        # (row slice, column slice) selecting the box in an image array.
        return (slice(int(box[2]), int(box[3])), slice(int(box[0]), int(box[1])))

    # Take every crop first, then paste them in the same order.
    crops = [source[_region(box)] for box in bboxes]
    for box, crop in zip(bboxes, crops):
        background_img_array[_region(box)] = crop

    cv2.imwrite(GENE_IMG_DIR + file.split('/')[-1], background_img_array)
    return
def generate_new_xmlobj(xmlobj_old, new_position, new_size):
    """Build a new ``<object>`` element for an object placed at a new spot.

    Args:
        xmlobj_old: Existing ``object`` element; its ``name`` and
            ``difficult`` texts are copied.
        new_position: ``(x, y)`` of the box's top-left corner.
        new_size: ``(height, width)`` of the box.

    Returns:
        A fresh ``object`` element with ``name``, ``difficult`` and a
        ``bndbox`` holding ``xmin``, ``ymin``, ``xmax``, ``ymax`` (in that
        order).
    """
    fresh = ET.Element('object')
    for copied_tag in ('name', 'difficult'):
        ET.SubElement(fresh, copied_tag).text = xmlobj_old.find(copied_tag).text

    left, top = new_position[0], new_position[1]
    height, width = new_size[0], new_size[1]

    box = ET.SubElement(fresh, 'bndbox')
    corners = (
        ('xmin', left),
        ('ymin', top),
        ('xmax', left + width),
        ('ymax', top + height),
    )
    for corner_tag, value in corners:
        ET.SubElement(box, corner_tag).text = str(value)
    return fresh
    
def past_and_insert(img_obj, img_array, new_position, obj_element, etree):
    """Record an object in the annotation tree and paste it onto the image.

    Args:
        img_obj: Image array of the object patch.
        img_array: Image array to paste onto.
        new_position: ``(x, y)`` top-left corner for the patch.
        obj_element: Source ``object`` element (name/difficult are copied).
        etree: Annotation tree that receives the new ``object`` element.

    Returns:
        ``(pil_image, etree)``: the composited PIL image and the updated tree.
    """
    patch_height_width = img_obj.shape[:2]
    # Annotation first, then pixels.
    annotation = generate_new_xmlobj(obj_element, new_position, patch_height_width)
    updated_tree = insert_to_xml(annotation, etree)
    composited = past_obj_to_background(img_obj, img_array, new_position)
    return composited, updated_tree
    

def insert_to_xml(xml_obj, xml_etree):
    """Append *xml_obj* as the last child of the tree's root.

    The tree is modified in place and also returned.
    """
    xml_etree.getroot().append(xml_obj)
    return xml_etree
    
def past_obj_to_background(img_obj, img_array, position = (200, 200)):
    """Paste an object patch onto an image and return the result.

    Args:
        img_obj: Array of the patch; converted with ``Image.fromarray``.
        img_array: Array of the destination image; converted with
            ``Image.fromarray``.
        position: Position handed to ``Image.paste``.

    Returns:
        The PIL image built from *img_array* with the patch pasted in.
    """
    canvas = Image.fromarray(img_array)
    canvas.paste(Image.fromarray(img_obj), position)
    return canvas
    
def get_cls_from_xmlobj(obj_element):
    """Return the text of the element's ``name`` child (the class label)."""
    name_node = obj_element.find('name')
    return name_node.text

def get_bbox_from_xmlobj(obj_element):
    """Return the element's box as floats ``[xmin, xmax, ymin, ymax]``.

    The values are read from the ``bndbox`` child of *obj_element*.
    """
    box_node = obj_element.find('bndbox')
    return [float(box_node.find(tag).text) for tag in ('xmin', 'xmax', 'ymin', 'ymax')]
def generate_new_position(img_size, img_obj_size):
    """Pick a random top-left corner at which the object fits in the image.

    Args:
        img_size: Image size, indexed as ``(height, width)``.
        img_obj_size: Object size, indexed as ``(height, width)``.

    Returns:
        ``(x, y)`` tuple of ints with ``0 <= x <= width - obj_width`` and
        ``0 <= y <= height - obj_height``.

    Raises:
        ValueError: If the object is larger than the image in either
            dimension.
    """
    max_x = int(img_size[1] - img_obj_size[1])
    max_y = int(img_size[0] - img_obj_size[0])
    if max_x < 0 or max_y < 0:
        raise ValueError('object of size {} does not fit in image of size {}'.format(
            tuple(img_obj_size), tuple(img_size)))
    # randint is inclusive on both ends, so the last valid offset is
    # reachable and an object exactly as large as the image yields 0
    # (the previous arange-based sampling excluded it and crashed on an
    # empty population in that case).
    return (random.randint(0, max_x), random.randint(0, max_y))
def generate_new_bbox(img_size, img_obj_size):
    """Pick a random box of the object's size that lies inside the image.

    Args:
        img_size: Image size, indexed as ``(height, width)``.
        img_obj_size: Object size, indexed as ``(height, width)``.

    Returns:
        ``[xmin, xmax, ymin, ymax]`` with ``xmax - xmin`` equal to the
        object width and ``ymax - ymin`` equal to the object height.

    Raises:
        ValueError: If the object is larger than the image in either
            dimension.
    """
    max_x = int(img_size[1] - img_obj_size[1])
    max_y = int(img_size[0] - img_obj_size[0])
    if max_x < 0 or max_y < 0:
        raise ValueError('object of size {} does not fit in image of size {}'.format(
            tuple(img_obj_size), tuple(img_size)))
    # randint is inclusive on both ends, so the last valid offset is
    # reachable and an object exactly as large as the image yields 0
    # (the previous arange-based sampling excluded it and crashed on an
    # empty population in that case).
    random_x = random.randint(0, max_x)
    random_y = random.randint(0, max_y)
    return [random_x, random_x + img_obj_size[1], random_y, random_y + img_obj_size[0]]
    
    
def inset_obj_to_an_image_and_xml(img_obj, img, obj_element, etree, keep_position=True):
    """Paste an object patch into *img* and record it in *etree*.

    When ``keep_position`` equals ``True`` the patch is placed at the box
    stored in *obj_element*. Otherwise a random box is drawn and redrawn
    until ``check_bbox`` accepts it against the boxes already in *etree*;
    after more than 50 rejected draws the function gives up.

    Returns:
        ``(pil_image, etree)`` from ``past_and_insert`` on success, or
        ``(False, False)`` when no acceptable random box was found.
    """
    # Explicit comparison with True is kept on purpose: truthy values other
    # than True (e.g. 2) take the random-placement branch.
    if keep_position == True:
        target_bbox = get_bbox_from_xmlobj(obj_element)
    else:
        target_bbox = generate_new_bbox(img.shape[:2], img_obj.shape[:2])
        occupied = get_bboxes_from_etree(etree)
        rejected = 0
        while True:
            if check_bbox(target_bbox, occupied):
                break
            print('new_bbox not suitable, retry...')
            rejected += 1
            if rejected > 50:
                return False, False
            target_bbox = generate_new_bbox(img.shape[:2], img_obj.shape[:2])
        print('new_bbox succussful')

    top_left = (int(target_bbox[0]), int(target_bbox[2]))
    return past_and_insert(img_obj, img, top_left, obj_element, etree)
def check_bbox(new_bbox, bboxes):
    """Tell whether *new_bbox* is clear of every box in *bboxes*.

    Returns:
        ``False`` as soon as one existing box has an IoU above 0.005 with
        *new_bbox*; ``True`` otherwise (including for an empty *bboxes*).
    """
    return not any(IOU(new_bbox, existing) > 0.005 for existing in bboxes)

def rotate_img(img, thealta):
    """Rotate an image about its centre, padding first so nothing is cut off.

    The four corners are rotated to find the extent of the rotated image.
    Along each axis where that extent exceeds the original size, the image
    is padded (constant border) by half the difference on both sides before
    rotating; along each axis where it is smaller, the rotated result is
    cropped by half the difference on both sides.

    Args:
        img: Image array; ``img.shape[:2]`` is read as ``(height, width)``.
        thealta: Rotation angle in degrees, as passed to
            ``cv2.getRotationMatrix2D``.

    Returns:
        The rotated (and padded or cropped) image array.
    """
    (h_, w_) = img.shape[:2]
    point_list = [(0, 0), (0, h_), (w_, h_), (w_, 0)]
    center = (w_ // 2, h_ // 2)

    # Rotate the corners to measure the bounding box of the rotated image.
    radians = np.pi * thealta / 180
    cos_t, sin_t = np.cos(radians), np.sin(radians)
    roted_point_list = [((px - center[0]) * cos_t - (py - center[1]) * sin_t + center[0],
                         (px - center[0]) * sin_t + (py - center[1]) * cos_t + center[1])
                        for (px, py) in point_list]

    roted_xs = [roted_point[0] for roted_point in roted_point_list]
    roted_ys = [roted_point[1] for roted_point in roted_point_list]
    (xmin, xmax, ymin, ymax) = (np.min(roted_xs), np.max(roted_xs), np.min(roted_ys), np.max(roted_ys))
    print('xmin: {0}, xmax: {1}, ymin: {2}, ymax: {3}'.format(xmin, xmax, ymin, ymax))

    roted_h, roted_w = ymax - ymin, xmax - xmin
    # Arguments now follow the labels (they were passed in swapped order).
    print('roted_h: {0}, roted_w: {1}'.format(roted_h, roted_w))

    # Per-side growth (positive) or shrinkage (negative) along each axis.
    top_bottom, left_right = int((roted_h - h_)/2), int((roted_w - w_)/2)
    print('top_bottom:{0}, left_right_:{1}'.format(top_bottom, left_right))

    # Only pad where the rotated image is larger than the original.
    pad_v, pad_h = max(top_bottom, 0), max(left_right, 0)
    dst = cv2.copyMakeBorder(img, pad_v, pad_v, pad_h, pad_h, cv2.BORDER_CONSTANT)

    (h, w) = dst.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, thealta, 1.0)
    rotated_img = cv2.warpAffine(dst, M, (w, h))

    # Crop where the rotated image is smaller than the original. Slicing
    # only the leading axes keeps this working for 2-D as well as 3-D arrays.
    if top_bottom < 0:
        top_bottom = abs(top_bottom)
        rotated_img = rotated_img[top_bottom:-top_bottom]
    if left_right < 0:
        left_right = abs(left_right)
        rotated_img = rotated_img[:, left_right:-left_right]
    return rotated_img

def generate_etree(etree_old):
    """Create a new annotation tree carrying only the old tree's header.

    The ``folder``, ``filename``, ``source``, ``size`` and ``segmented``
    children of the old root are appended, in that order, to a new
    ``annotation`` root. The elements are shared with the old tree, not
    copied; other children (such as ``object``) are left out.

    Returns:
        A new ``ElementTree`` wrapping the new root.
    """
    source_root = etree_old.getroot()
    fresh_root = ET.Element('annotation')
    for header_tag in ('folder', 'filename', 'source', 'size', 'segmented'):
        fresh_root.append(source_root.find(header_tag))
    return ET.ElementTree(fresh_root)

# ---------------------------------------------------------------------------
# Driver: for every seed annotation, copy its bus/car/truck objects onto a
# background image at their original coordinates and save the new image/XML
# pair into GENE_IMG_DIR / GENE_XML_DIR.
# ---------------------------------------------------------------------------
seed_xml_names = os.listdir(SEED_XML_DIR)
seed_xml_names.sort()
bkgd_img_names = os.listdir(BKGD_IMG_DIR)
bkgd_img_names.sort()
#classes = ['bus', 'car', 'truck', 'motorbike', 'bicycle', 'person']
# Only objects whose class is listed here are pasted.
classes = ['bus', 'car', 'truck']
# Single pass over the seed set (range(1)).
for i in range(1):
    for xml_name in seed_xml_names:
        # The image path is the XML file name with its last three characters
        # replaced by 'jpg'; xml_name is then rebound to the full XML path.
        img_name = SEED_IMG_DIR + xml_name[:-3] + 'jpg'
        xml_name = SEED_XML_DIR + xml_name
        #print(img_name)
        #print(xml_name[-15:])
        
        tree = ET.parse(xml_name)
        # NOTE(review): `img` is not used again in this loop; the crops below
        # re-read the file through get_obj_from_image_file(img_name, ...).
        img = cv2.imread(img_name)
        objs = get_obj_from_xml(xml_name)

        # Start the output annotation from the seed's header elements only.
        new_tree = generate_etree(tree)
        #print(BKGD_IMG_DIR + img_name[-15:-4] + '.jpg')
        
        # The 11 characters before the 4-character extension are parsed as an
        # integer id that selects the background image (sorted index 2 or 1).
        # Presumably the file names end in a numeric frame id -- TODO confirm.
        # NOTE(review): an id of exactly 6005600 matches neither branch and is
        # skipped, as is anything >= 6005900 -- confirm the boundary is intended.
        if ((int(xml_name[-15:-4])) < 6005600):
            new_img = cv2.imread(BKGD_IMG_DIR + bkgd_img_names[2])
        elif (6005600 < (int(xml_name[-15:-4])) < 6005900):       
            new_img = cv2.imread(BKGD_IMG_DIR + bkgd_img_names[1])
        else:
            continue
        print(new_img.shape)
        flag_found = False
        num_obj = len(objs)
        array_num = np.arange(num_obj)
        # Nothing to paste when the annotation has no objects.
        if(len(array_num) == 0):
            continue
        print(array_num)
        '''
        sparse_ratio = 1
        random_num_obj = int((random.sample(list(array_num), 1)[0])/sparse_ratio)
        random_id_obj =  random.sample(list(array_num), random_num_obj)
        random_id_obj.sort()
        '''
        # All objects are used, in order (the random subsampling above is
        # disabled inside a string literal).
        random_id_obj = list(array_num)
        print(random_id_obj)
        
        # NOTE: this loop variable reuses the name `i` of the outer loop.
        for i in random_id_obj:
            obj_element = objs[i]
            cls = get_cls_from_xmlobj(obj_element)
            if(cls not in classes):
                continue
            flag_found = True
            bbox = get_bbox_from_xmlobj(obj_element)           
            img_obj = get_obj_from_image_file(img_name, bbox)      
            # keep_position is left at its default (True), so the crop is
            # pasted at the coordinates it had in the seed image.
            new_pil_img, new_xml_etree = inset_obj_to_an_image_and_xml(img_obj, new_img, obj_element, new_tree)
            # (False, False) is what inset_obj_to_an_image_and_xml returns when
            # it gives up on finding a position.
            if((new_pil_img == False) or (new_xml_etree == False)):
                continue
            new_tree = new_xml_etree
            # Back to an array so the next object is pasted onto the result.
            new_img = np.array(new_pil_img)
        ## save xml and img
        # Skip saving when no object of a wanted class was found.
        if(flag_found == False):
            continue
        
        # A microsecond-resolution timestamp prefixes the output file names.
        # Note that this rebinds the module-level `time_str`.
        time_mark = datetime.datetime.now()
        time_str = time_mark.strftime("%Y%m%d%H%M%S%f_")
        # Reverse the channel axis before handing the array to PIL
        # (cv2.imread yields BGR order; PIL saves the array as RGB).
        new_pil_img = Image.fromarray(new_img[:,:,(2,1,0)])
        new_pil_img.save(GENE_IMG_DIR + time_str + img_name.split('/')[-1])
        new_tree.write(GENE_XML_DIR + time_str + xml_name.split('/')[-1])
        #break

(1080, 1920, 3)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
(1080, 1920, 3)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
(1080, 1920, 3)
[ 0  1  2  3  4  5  6  7  8  9 10]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
(1080, 1920, 3)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
(1080, 1920, 3)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
(1080, 1920, 3)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
(1080, 1920, 3)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
(1080, 1920, 3)
[0 1 2 3 4]
[0, 1, 2, 3, 4]
(1080, 1920, 3)
[0 1 2 3 4 5]
[0, 1, 2, 3, 4, 5]
(108