## Open Images Dataset Extractor

***Extract and download a subset of images for specific classes from the Open Images Dataset, together with their bounding box labels***

- Using Open Images Dataset version: `V6` - https://storage.googleapis.com/openimages/web/download.html
- If running on a SageMaker Notebook Instance, use the `conda_python3` kernel

In [None]:
# Fetch downloader.py from the openimages/dataset GitHub repository into the
# working directory; `extract` below runs it as `python downloader.py ...`
# to download the selected images.
!wget https://raw.githubusercontent.com/openimages/dataset/master/downloader.py

In [None]:
import os
import time
import random
import importlib

from matplotlib import pyplot as plt
import cv2

import utils

In [None]:
# Re-import the local `utils` module so edits made to it after the first
# import take effect without restarting the kernel.
importlib.reload(utils)

In [None]:
# Load the class description rows (one "<class id>,<class name>" entry per
# line). The file must already be present in the working directory.
# Blank lines are dropped so an empty or padded file does not produce bogus
# entries, and the encoding is pinned so the result does not depend on the
# machine's locale.
with open('class-descriptions-boxable.csv', encoding='utf-8') as fin:
    all_classes = [row for row in fin.read().splitlines() if row.strip()]
print(f'Loaded {len(all_classes)} class names')

In [None]:
# Human-readable names of the classes to extract; they must match the names
# in the class descriptions file exactly.
EXTRACT_CLASSES = ['Man', 'Woman']
# One random 3-channel colour (each channel 0-255) per class, used when
# drawing that class's boxes.
COLOURS = {c: [random.randint(0, 255) for _ in range(3)] for c in EXTRACT_CLASSES}
# Root directory under which per-split / per-class output is written.
DATA_ROOT = 'data'
IMAGE_EXT = '.jpg'

# Maximum number of images to extract per class for each split.
TRAIN_SAMPLE_COUNT = 10000
VALIDATION_SAMPLE_COUNT = TRAIN_SAMPLE_COUNT // 10

# Each split is (split name, annotations CSV URL, images per class) and is
# unpacked positionally into `extract`.
TRAIN_SPLIT = ('train', 'https://storage.googleapis.com/openimages/v6/oidv6-train-annotations-bbox.csv',
               TRAIN_SAMPLE_COUNT)
VALIDATION_SPLIT = ('validation', 'https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv',
                    VALIDATION_SAMPLE_COUNT)

# Map class id -> class name for the selected classes only. `partition`
# yields an empty name for rows without a comma, so malformed or blank rows
# are skipped instead of raising IndexError.
CLASS_LOOKUP = {
    class_id: class_name
    for class_id, _, class_name in (row.strip().partition(',') for row in all_classes)
    if class_name in EXTRACT_CLASSES
}

# A misspelled class would otherwise be dropped silently and simply produce
# no output later on.
missing_classes = [c for c in EXTRACT_CLASSES if c not in CLASS_LOOKUP.values()]
if missing_classes:
    print(f'WARNING: classes not found in class descriptions: {missing_classes}')

print(f'Selected classes: {EXTRACT_CLASSES}')

In [None]:
def extract(split, url, count_per_class, download_images=True, class_lookup=CLASS_LOOKUP, data_root=DATA_ROOT):
    print(f'--> Extract split [{split}]')
    annot = os.path.split(url)[-1]
    if not os.path.exists(annot):
        print('\t- download annotations file')
        !wget $url

    for cid in class_lookup.keys():
        cls = class_lookup[cid]
        print(f'- Processing class {cls}')
        extracted = {}
        to_download = []
        remaining_count = count_per_class

        print('\t- extracting')
        with open(annot) as fin:
            rec = fin.readline()
            while rec != '':
                rec = rec.strip().split(',')
                if len(rec) > 2 and cid == rec[2]:
                    img_id = rec[0]

                    # Get all the bounding boxes, even if we got the required image count
                    if remaining_count > 0:
                        if img_id not in extracted:
                            extracted[img_id] = []
                            to_download.append(img_id)
                            remaining_count -= 1
                        add = True
                    else:
                        add = img_id in extracted
                    if add:
                        # Source box is in (x1, x2, y1, y2), move around
                        extracted[img_id].append(' '.join([rec[4], rec[6], rec[5], rec[7]]))  

                rec = fin.readline()

        split_path = os.path.join(data_root, split)  
        class_path = os.path.join(split_path, cls)
        label_path = os.path.join(class_path, 'NormalisedLabel')
        os.makedirs(label_path, exist_ok=True)

        image_list_path = os.path.join(split_path, f'{cls}_images.txt')
        with open(image_list_path, 'w') as fout:
            fout.write(split + '/' + f'\n{split}/'.join(to_download))

        for img_id in extracted.keys():
            with open(os.path.join(label_path, f'{img_id}.txt'), 'w') as fout:
                boxes = [f'{cls} {b}' for b in extracted[img_id]]
                fout.write('\n'.join(boxes))

        if download_images:
            print('\t- download images')
            !python downloader.py $image_list_path --download_folder=$class_path --num_processes=5

        denorm_label_path = os.path.join(class_path, 'Label')
        print(f'\t- denormalise into {denorm_label_path}', end='') 
        if not os.path.exists(denorm_label_path):
            os.makedirs(denorm_label_path, exist_ok=True)

        count = 0
        for img in os.listdir(class_path):
            if img.endswith(IMAGE_EXT):
                
                img_id = os.path.splitext(img)[0]   
                labels = extracted[img_id]
                
                denorm_labels = []
                image = cv2.imread(os.path.join(class_path, img))
                for box in labels:
                    box = box.split(' ')
                    denorm_labels.append(' '.join([cls, *[str(b) for b in utils.denormalise_box(box, image)]]))

                with open(os.path.join(denorm_label_path, img.replace(IMAGE_EXT, '.txt')), 'w') as fout:
                    fout.write('\n'.join(denorm_labels))     
                
                count += 1
        print(f' - {count} images')
    print('Done')       

In [None]:
# Choose which split the following cells operate on
# (TRAIN_SPLIT or VALIDATION_SPLIT).
CURRENT_SPLIT = TRAIN_SPLIT

In [None]:
# Unpack (split name, annotations URL, images per class) into extract().
# download_images=False skips the downloader.py step, so only images already
# present in the class folders get denormalised labels.
extract(*CURRENT_SPLIT, download_images=False)

##### Show some samples

In [None]:
def show_samples(split, classes=EXTRACT_CLASSES, max_per_class=3):
    """Plot up to `max_per_class` annotated images for each class.

    Images are read from `<DATA_ROOT>/<split>/<class>` and their boxes from
    the matching `Label/<image id>.txt` file, where each line is
    `<class> x1 y1 x2 y2`. The figure has one column per class and one row
    per sample.

    Parameters:
        split: split name (sub-directory of DATA_ROOT).
        classes: class names to display; each needs an entry in COLOURS.
        max_per_class: number of rows, i.e. images shown per class.
    """
    # squeeze=False keeps `ax` two-dimensional even with a single class or a
    # single row, so `ax[row, col]` indexing always works.
    _, ax = plt.subplots(ncols=len(classes), nrows=max_per_class, figsize=(20, 20), squeeze=False)

    for cls_i, cls in enumerate(classes):
        print(f'--> Processing class {cls}')
        count = 0
        images_path = os.path.join(DATA_ROOT, split, cls)
        labels_path = os.path.join(images_path, 'Label')
        color = COLOURS[cls]

        for image_name in os.listdir(images_path):
            if not image_name.endswith(IMAGE_EXT):
                continue

            label_file = os.path.join(labels_path, os.path.splitext(image_name)[0] + '.txt')
            if not os.path.exists(label_file):
                # No labels were written for this image; nothing to draw.
                continue

            print(f'--> Processing {image_name}')
            image = cv2.imread(os.path.join(images_path, image_name))
            if image is None:
                # cv2.imread returns None for files it cannot decode.
                continue
            # OpenCV loads images as BGR while matplotlib displays RGB;
            # convert so colours are not channel-swapped on screen.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            with open(label_file) as fin:
                rows = [row for row in fin.read().splitlines() if row.strip()]

            for row in rows:
                # Split from the right so class names that contain spaces
                # stay in one piece.
                c, x1, y1, x2, y2 = row.rsplit(' ', 4)
                x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
                utils.draw_box(image, [x1, y1, x2, y2], c, color)

            ax[count, cls_i].imshow(image)

            count += 1
            if count >= max_per_class:
                break
    plt.show()

In [None]:
# CURRENT_SPLIT[0] is the split name; plot up to 10 annotated images per class.
show_samples(CURRENT_SPLIT[0], max_per_class=10)