The goal of this script is to go through all images and their corresponding XML annotation files in the dataset and build a list of all annotated objects, recording for each object its class, position, image filename and sub-index within the XML file.

In [1]:
import os
import tqdm
import cv2
import xmltodict
import io
from collections import namedtuple
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
%matplotlib inline
os.environ['CUDA_VISIBLE_DEVICES'] = ""

In [None]:
# Root folder of the dataset; the two sub-folders below are resolved against it.
DATA_PATH = '/raid/group-data/uc150429/AID_DATA_201803/'

# Folder with the bounding-box annotation files.
BBOX_PATH = os.path.join(DATA_PATH, '2d_bounding_box')
# Folder with the images.
IMG_PATH = os.path.join(DATA_PATH, 'original_image')

In [None]:
# Identifier = file name without its extension.
# Only files that really end in the expected extension are considered, and only
# that trailing extension is removed. (A plain str.replace would also strip
# '.xml' / '.png' from the middle of a name and would let unrelated files in
# the folder slip into the identifier sets.)
idx_bbox = {name[:-len('.xml')] for name in os.listdir(BBOX_PATH) if name.endswith('.xml')}
idx_img = {name[:-len('.png')] for name in os.listdir(IMG_PATH) if name.endswith('.png')}

In [None]:
# Keep only identifiers present in both sets, i.e. samples that have an
# annotation file as well as an image.
idxs = idx_bbox & idx_img
print(len(idxs))

In [None]:
# One record per annotated object: the image it belongs to, its position in the
# annotation file's bbox list, its class name and its (left, top, right, bottom) box.
# NOTE: the former `verbose=False` argument was dropped; it was already the
# default and the parameter no longer exists in Python 3.7+, where passing it
# raises a TypeError.
IndexEntry = namedtuple(
    'IndexEntry',
    ['img_path', 'sub_idx', 'classname', 'left', 'top', 'right', 'bottom'],
)
# Example instance (shown as the cell output in a notebook).
IndexEntry('a', 1, 'b', 1, 2, 3, 4)

Given a bounding box, this function increases its size in all directions (by half of its width/height on each side) while staying within valid coordinates.

In [None]:
def extend_ltwh(l, t, w, h, max_right=1919, max_bottom=1207):
    """Grow a box by half its width/height on every side, clamped to bounds.

    Note the asymmetry: the input is given as left/top/width/height, while the
    result is returned as left/top/right/bottom.

    Args:
        l: Left coordinate of the box.
        t: Top coordinate of the box.
        w: Width of the box. Must be an integer (it is halved with ``>> 1``).
        h: Height of the box. Must be an integer (it is halved with ``>> 1``).
        max_right: Upper clamp for the returned ``right`` value. Defaults to
            1919, the bound that used to be hard-coded (presumably the last
            column of a 1920 px wide image -- confirm against the dataset).
        max_bottom: Upper clamp for the returned ``bottom`` value. Defaults to
            1207, the bound that used to be hard-coded (presumably the last
            row of a 1208 px high image -- confirm against the dataset).

    Returns:
        Tuple ``(left, top, right, bottom)`` of the enlarged box; ``left`` and
        ``top`` are never below 0, ``right``/``bottom`` never above the clamps.
    """
    half_w = w >> 1
    half_h = h >> 1
    left = max(l - half_w, 0)
    top = max(t - half_h, 0)
    right = min(max_right, l + w + half_w)
    bottom = min(max_bottom, t + h + half_h)
    return left, top, right, bottom

Given an identifier idx, this function extracts the information about all objects in the corresponding files.

In [None]:
def get_index_entries_from_file(idx):
    """Build one IndexEntry per object annotated for the sample ``idx``.

    Reads ``<BBOX_PATH>/<idx>.xml``, enlarges every bounding box found under
    ``bboxes/bbox`` with ``extend_ltwh`` and pairs it with the path of the
    matching image ``<IMG_PATH>/<idx>.png`` (the image itself is not opened).

    Args:
        idx: File name stem shared by the annotation file and the image.

    Returns:
        List of IndexEntry tuples, in the order the boxes appear in the file;
        ``sub_idx`` is the position of the box in that order. Empty if the
        file contains no boxes.

    Raises:
        KeyError: If the parsed document has no ``bboxes`` root, or a box
            lacks one of the expected fields.
    """
    def intfloat(s):
        # Accepts integer as well as decimal strings; the fraction is truncated.
        return int(float(s))

    xml_path = os.path.join(BBOX_PATH, idx + '.xml')
    img_path = os.path.join(IMG_PATH, idx + '.png')
    # The context manager guarantees the handle is closed even if parsing fails.
    with io.open(xml_path, 'r', encoding='utf-8-sig') as fh:
        xmlo = xmltodict.parse(fh.read())

    root = xmlo['bboxes']
    # An element without children is expected to come back as None from
    # xmltodict, and a root without any <bbox> child has no 'bbox' key.
    bboxes = root.get('bbox') if root else None
    if bboxes is None:
        return []
    # xmltodict yields a list only when the tag occurs more than once; a single
    # <bbox> comes back as a bare mapping. Wrap it so that files with exactly
    # one object are indexed too instead of being silently skipped.
    if not isinstance(bboxes, list):
        bboxes = [bboxes]

    entries = []
    for subidx, bbox in enumerate(bboxes):
        left = intfloat(bbox['left'])
        top = intfloat(bbox['top'])
        right = intfloat(bbox['right'])
        bottom = intfloat(bbox['bottom'])
        # extend_ltwh takes width/height but returns left/top/right/bottom.
        left, top, right, bottom = extend_ltwh(left, top, right - left, bottom - top)
        entries.append(IndexEntry(img_path=img_path,
                                  sub_idx=subidx,
                                  classname=bbox['@class'],
                                  left=left,
                                  top=top,
                                  right=right,
                                  bottom=bottom))
    return entries

Plot an entry

In [None]:
def plot_index_entry(entry):
    """Display the image region described by ``entry`` in a new figure.

    Args:
        entry: Object with ``img_path``, ``top``, ``bottom``, ``left`` and
            ``right`` attributes (e.g. an IndexEntry).

    Raises:
        IOError: If the image at ``entry.img_path`` cannot be read.
    """
    img = cv2.imread(entry.img_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque TypeError when slicing below.
    if img is None:
        raise IOError('could not read image: {}'.format(entry.img_path))
    # Crop rows top:bottom and columns left:right; the last axis is reversed to
    # turn OpenCV's BGR channel order into the RGB order matplotlib expects.
    crop = img[entry.top:entry.bottom, entry.left:entry.right, ::-1]
    plt.figure()
    plt.imshow(crop)

Get all entries

In [None]:
# Gather the entries of every sample into one flat list, with a progress bar.
all_entries = []
count = 0  # stays 0 when there is nothing to iterate over
for count, idx in enumerate(tqdm(list(idxs)), start=1):
    all_entries.extend(get_index_entries_from_file(idx))
# Number of processed samples.
print(count)

Save list of entries to disk

In [None]:
import pickle

# Write the complete list of entries to disk as a pickle file.
with open('/raid/user-data/lscheucher/projects/bounding_box_classifier/full_object_index.pickle', 'wb') as index_file:
    pickle.dump(all_entries, index_file)

#TODO, must modify

In [None]:
import pickle

# Placeholder cell: for now it pickles the same list to the same
# full_object_index.pickle path as the save step earlier in this file.
with open('/raid/user-data/lscheucher/projects/bounding_box_classifier/full_object_index.pickle', 'wb') as out_file:
    pickle.dump(all_entries, out_file)