In [1]:
import xml.etree.ElementTree as ET
import glob
import os
import cv2
from tqdm import tqdm

In [5]:
# Root directory of the dataset split whose annotations are processed.
DATA_DIR = 'dataset/train'
# Sub-folder of DATA_DIR that is globbed for '*.xml' annotation files.
ANNOTATION_FOLDER_NAME = 'xml'
# Folder name substituted for ANNOTATION_FOLDER_NAME when an image path is
# derived from an annotation path.
IMAGE_FOLDER_NAME = 'images'

In [6]:
def get_image_path_from_xml_path(xml_path, annotation_folder_name, images_folder_name):
    """
    Derive the image path that belongs to an annotation (XML) path.

    The directory component closest to the file that is exactly equal to
    ``annotation_folder_name`` is swapped for ``images_folder_name`` and a
    trailing ``.xml`` extension is swapped for ``.jpg``.

    Args:
        xml_path: path of the annotation file.
        annotation_folder_name: name of the folder holding annotations.
        images_folder_name: name of the sibling folder holding images.

    Returns:
        The derived image path as a string. Separators of the input path are
        preserved as written; parts that do not match are left untouched.
    """
    import re  # local import: only this helper needs it

    # Split on whatever separators the platform accepts ('/' and, on Windows,
    # '\\') while keeping the separator runs so the path can be re-joined
    # exactly as it was written.
    separators = os.sep + (os.altsep or '')
    tokens = re.split('([{}]+)'.format(re.escape(separators)), xml_path)

    # tokens alternates component, separator, component, ...; the last token
    # is the file name, so directory components sit at len-3, len-5, ...
    # Only whole components are compared, and only the one nearest to the
    # file is replaced, so a parent directory that merely shares the name is
    # not rewritten.
    for index in range(len(tokens) - 3, -1, -2):
        if tokens[index] == annotation_folder_name:
            tokens[index] = images_folder_name
            break

    # Replace the extension only; a '.xml' occurring elsewhere in the file
    # name (e.g. 'scan.xml.v2.xml') must be left alone.
    stem, extension = os.path.splitext(tokens[-1])
    if extension == '.xml':
        tokens[-1] = stem + '.jpg'

    return ''.join(tokens)

def load_xml(xml_path):
    """
    Read the image size and the labelled boxes from an annotation file.

    Only direct children of the document root are inspected. An ``object``
    element without a ``name`` element, or whose ``name`` has no text, is
    skipped.

    Args:
        xml_path: path of the XML annotation file to parse.

    Returns:
        A tuple ``(size, objects, boxes)`` where ``size`` is
        ``(width, height)`` or ``None`` when the file has no ``size``
        element, ``objects`` is the list of object names and ``boxes`` is the
        list of ``(xmin, ymin, xmax, ymax)`` integer tuples, in the same
        order as ``objects``.
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()

    objects = []
    boxes = []
    size = None

    for element in root:
        if element.tag == 'size':
            width = int(element.find('width').text)
            height = int(element.find('height').text)
            size = (width, height)
        elif element.tag == 'object':
            # find() returns None when <name> is absent; check the node
            # before touching .text so a nameless object is skipped instead
            # of raising AttributeError.
            name_node = element.find('name')
            if name_node is None or name_node.text is None:
                continue

            bndbox = element.find('bndbox')
            box = tuple(
                int(bndbox.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )

            objects.append(name_node.text)
            boxes.append(box)

    return size, objects, boxes

def _clamp_node_text(node, lower, upper):
    """Clamp the integer text of *node* into [lower, upper]; return True if it was changed."""
    value = int(node.text)
    clamped = min(max(value, lower), upper)
    if clamped == value:
        return False
    node.text = str(clamped)
    return True


def corect(xml_path):
    """
    Check an annotation file against its image and correct it in place.

    The image path is derived from ``xml_path`` with
    ``get_image_path_from_xml_path`` using ``ANNOTATION_FOLDER_NAME`` and
    ``IMAGE_FOLDER_NAME``. The ``size`` element is overwritten with the real
    image width/height when they differ, and every bounding-box coordinate of
    a named ``object`` is clamped into the image (x into ``[0, width]``,
    y into ``[0, height]``). The file is rewritten only when something
    changed.

    If the image cannot be read, a message is printed and the annotation is
    left untouched.

    Args:
        xml_path: path of the XML annotation file to check.
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()

    image_path = get_image_path_from_xml_path(xml_path, ANNOTATION_FOLDER_NAME, IMAGE_FOLDER_NAME)
    # cv2.imread signals a missing or undecodable file by returning None
    # rather than raising, so test for it before reading .shape.
    image = cv2.imread(image_path)
    if image is None:
        print('cannot read image, skipped: ', image_path)
        return
    height, width = image.shape[:2]

    is_dirty = False
    for element in root:
        if element.tag == 'size':
            for tag, actual in (('width', width), ('height', height)):
                node = element.find(tag)
                if int(node.text) != actual:
                    node.text = str(actual)
                    is_dirty = True
        elif element.tag == 'object':
            # Objects without a usable <name> are left as they are.
            name_node = element.find('name')
            if name_node is None or name_node.text is None:
                continue

            bndbox = element.find('bndbox')
            # Clamp both ends of every coordinate: a box can start beyond the
            # right/bottom edge or end before the left/top edge as well.
            for tag, upper in (('xmin', width), ('ymin', height),
                               ('xmax', width), ('ymax', height)):
                if _clamp_node_text(bndbox.find(tag), 0, upper):
                    is_dirty = True

    if is_dirty:
        tree.write(xml_path)
        print('saved to: ', xml_path)

In [None]:
# Collect every annotation file of the dataset and correct each one in place.
annotation_pattern = os.path.join(DATA_DIR, ANNOTATION_FOLDER_NAME, '*.xml')
xmls = glob.glob(annotation_pattern)

for xml_file in tqdm(xmls):
    corect(xml_file)

In [None]:
# Re-scan the annotation files and print the path of any file that has no
# size element or still has a box reaching outside the declared size.
annotation_pattern = os.path.join(DATA_DIR, ANNOTATION_FOLDER_NAME, '*.xml')
xmls = glob.glob(annotation_pattern)

for xml_file in tqdm(xmls):
    size, _, boxes = load_xml(xml_file)
    if size is None:
        print('None size: ', xml_file)
        continue
    max_x, max_y = size
    for left, top, right, bottom in boxes:
        # One line is printed per violated side pair, as before.
        if min(left, top) < 0:
            print(xml_file)
        if right > max_x or bottom > max_y:
            print(xml_file)