In [7]:
import pandas as pd
import random
import urllib3
import os
import shutil
from PIL import Image

In [8]:
def download(url, path, chunk_size=48000):
    """Stream the resource at ``url`` into the local file ``path``.

    Parameters
    ----------
    url : str
        Address to fetch with an HTTP GET request.
    path : str
        Destination file; it is opened in binary write mode and overwritten.
    chunk_size : int, optional
        Number of bytes requested from the response per read.

    Raises
    ------
    urllib3.exceptions.HTTPError
        If the server answers with a status other than 200. Nothing is
        written to ``path`` in that case, so an error page is never saved
        under an image file name.
    """
    http = urllib3.PoolManager()
    r = http.request('GET', url, preload_content=False)

    try:
        # Check the status before touching the destination file.
        if r.status != 200:
            raise urllib3.exceptions.HTTPError(
                f'GET {url} returned HTTP status {r.status}'
            )

        with open(path, 'wb') as out:
            for data in r.stream(chunk_size):
                out.write(data)
    finally:
        # Release the connection even when reading or writing fails.
        r.release_conn()

In [9]:
# Each class name is kept next to the LabelName value it is matched against
# in the annotation tables, so the two lists below cannot get out of sync.
_label_pairs = (
    ('airplane', '/m/0cmf2'),
    ('car', '/m/0k4j'),
    ('horse', '/m/03k3r'),
)
labels = [class_name for class_name, _ in _label_pairs]
labels_ids = [label_id for _, label_id in _label_pairs]

# Number of images to sample per class for the test and train folders.
test_count = 10
train_count = 100

# Minimum box area (product of the box width and height columns) an
# annotation needs in order to be considered.
min_area = 0.2

In [10]:
# Test Images
# Every test image is saved together with its segmentation mask, resized to
# the size of the downloaded image.

# Single place to change the location of the Open Images metadata.
metadata_dir = "/Users/racoon/Desktop/open-images-v5-metadata"

segmentation = pd.read_csv(f"{metadata_dir}/test-annotations-object-segmentation.csv")
bbox = pd.read_csv(f"{metadata_dir}/test-annotations-bbox.csv")

test_subsets = [segmentation[segmentation['LabelName'] == labels_id] for labels_id in labels_ids]

# Count the "clean" boxes per image once. The previous version re-filtered
# the whole bbox table for every segmentation row.
clean_bbox = bbox[
    (bbox["Confidence"] == 1)
    & (bbox["IsDepiction"] == 0)
    & (bbox["IsGroupOf"] == 0)
    & (bbox["IsInside"] == 0)
    & (bbox["IsTruncated"] == 0)
]
clean_counts = clean_bbox["ImageID"].value_counts()
# Images that carry exactly one clean box.
single_box_ids = set(clean_counts.index[clean_counts == 1])

test_groups = []
for rows in test_subsets:
    print(len(rows.index))
    area = (rows["BoxXMax"] - rows["BoxXMin"]) * (rows["BoxYMax"] - rows["BoxYMin"])
    keep = rows[(area >= min_area) & rows["ImageID"].isin(single_box_ids)]
    # One mask per image: the mask file name is derived from the image id, so
    # a second row for the same image would only overwrite the first one and
    # leave the class with fewer distinct images than requested.
    keep = keep.drop_duplicates(subset="ImageID")
    test_groups.append(list(zip(keep["ImageID"], keep["MaskPath"])))

sampled_groups = []
for name, group in zip(labels, test_groups):
    # random.sample raises ValueError when asked for more items than exist.
    if len(group) < test_count:
        print(f'warning: only {len(group)} test candidates for {name}, requested {test_count}')
    sampled_groups.append(random.sample(group, min(test_count, len(group))))
test_groups = sampled_groups

for name, group in zip(labels, test_groups):
    out_dir = f'dataset_{train_count}/test/{name}'
    os.makedirs(out_dir, exist_ok=True)
    for image_id, mask_path in group:
        url = f'http://s3.amazonaws.com/open-images-dataset/test/{image_id}.jpg'
        image_path = f'{out_dir}/{image_id}.jpg'
        print(url, image_path)
        download(url, image_path)

        # Only the size of the downloaded image is needed; close it right away.
        with Image.open(image_path) as pil_image:
            image_size = pil_image.size

        # Masks are stored in folders named after the first character of the mask file name.
        with Image.open(f'{metadata_dir}/test-masks-{mask_path[0]}/{mask_path}') as pil_mask:
            pil_mask.resize(image_size).save(f'{out_dir}/{image_id}.mask.png')

852
850
860
http://s3.amazonaws.com/open-images-dataset/test/b6ac22d7db1769ee.jpg dataset_100/test/airplane/b6ac22d7db1769ee.jpg
http://s3.amazonaws.com/open-images-dataset/test/fbe835c5944f93e5.jpg dataset_100/test/airplane/fbe835c5944f93e5.jpg
http://s3.amazonaws.com/open-images-dataset/test/839ce813ca97084c.jpg dataset_100/test/airplane/839ce813ca97084c.jpg
http://s3.amazonaws.com/open-images-dataset/test/9dc879c35a26d2d3.jpg dataset_100/test/airplane/9dc879c35a26d2d3.jpg
http://s3.amazonaws.com/open-images-dataset/test/35b11a04c24db20c.jpg dataset_100/test/airplane/35b11a04c24db20c.jpg
http://s3.amazonaws.com/open-images-dataset/test/e95bc413d4b748ba.jpg dataset_100/test/airplane/e95bc413d4b748ba.jpg
http://s3.amazonaws.com/open-images-dataset/test/a48f1d15812036fa.jpg dataset_100/test/airplane/a48f1d15812036fa.jpg
http://s3.amazonaws.com/open-images-dataset/test/93b5bf58149adefd.jpg dataset_100/test/airplane/93b5bf58149adefd.jpg
http://s3.amazonaws.com/open-images-dataset/test/d54

In [11]:
# Train Images
# The training folders are filled from the validation annotations: one folder
# per label plus an 'other' folder with images that carry none of the labels.

# Single place to change the location of the Open Images metadata.
metadata_dir = "/Users/racoon/Desktop/open-images-v5-metadata"

validation = pd.read_csv(f"{metadata_dir}/validation-annotations-bbox.csv")

# Same box filter for every label, built once.
clean = (
    (validation["Confidence"] == 1)
    & (validation["IsDepiction"] == 0)
    & (validation["IsGroupOf"] == 0)
    & (validation["IsInside"] == 0)
    & (validation["IsTruncated"] == 0)
)
subsets = [validation[(validation['LabelName'] == labels_id) & clean] for labels_id in labels_ids]

label_image_ids = validation[validation['LabelName'].isin(labels_ids)]["ImageID"]
other_rows = validation[~validation['ImageID'].isin(label_image_ids)]

print(validation.shape, other_rows.shape, label_image_ids.shape)

candidates = []
for rows in subsets:
    area = (rows["XMax"] - rows["XMin"]) * (rows["YMax"] - rows["YMin"])
    large_ids = rows.loc[area >= min_area, "ImageID"]
    # An image with several large boxes of the same label must only be a
    # candidate once, otherwise it can be sampled (and downloaded) twice.
    candidates.append(large_ids.drop_duplicates().tolist())

# Every image id already selected as a candidate for one of the labels.
u = set().union(*candidates)

# Candidates for the 'other' class: distinct images without any of the labels.
# Sampling from distinct ids (instead of drawing single rows in a loop) rules
# out picking the same image more than once.
other_ids = [
    image_id
    for image_id in other_rows["ImageID"].drop_duplicates()
    if image_id not in u
]
candidates.append(other_ids)

members = []
for i, c in enumerate(candidates):
    name = labels[i] if i < len(subsets) else 'other'
    # random.sample raises ValueError when asked for more items than exist.
    if len(c) < train_count:
        print(f'warning: only {len(c)} train candidates for {name}, requested {train_count}')
    members.append(random.sample(c, min(train_count, len(c))))

for i, m in enumerate(members):
    name = labels[i] if i < len(subsets) else 'other'
    out_dir = f'dataset_{train_count}/train/{name}'
    os.makedirs(out_dir, exist_ok=True)
    for image_id in m:
        url = f'http://s3.amazonaws.com/open-images-dataset/validation/{image_id}.jpg'
        image_path = f'{out_dir}/{image_id}.jpg'
        print(url, image_path)
        download(url, image_path)

(303980, 13) (250342, 13) (11373,)
http://s3.amazonaws.com/open-images-dataset/validation/2f118b9b64e097ab.jpg dataset_100/train/airplane/2f118b9b64e097ab.jpg
http://s3.amazonaws.com/open-images-dataset/validation/cf4d0c782e7fa7b5.jpg dataset_100/train/airplane/cf4d0c782e7fa7b5.jpg
http://s3.amazonaws.com/open-images-dataset/validation/875e00e5a095b197.jpg dataset_100/train/airplane/875e00e5a095b197.jpg
http://s3.amazonaws.com/open-images-dataset/validation/ac7668dafcb1ac70.jpg dataset_100/train/airplane/ac7668dafcb1ac70.jpg
http://s3.amazonaws.com/open-images-dataset/validation/eda11dda88e91dd3.jpg dataset_100/train/airplane/eda11dda88e91dd3.jpg
http://s3.amazonaws.com/open-images-dataset/validation/046ba24119dc2170.jpg dataset_100/train/airplane/046ba24119dc2170.jpg
http://s3.amazonaws.com/open-images-dataset/validation/3b902f0d20721ca6.jpg dataset_100/train/airplane/3b902f0d20721ca6.jpg
http://s3.amazonaws.com/open-images-dataset/validation/533b566abbbf65cd.jpg dataset_100/train/air

In [5]:
from PIL import Image
import numpy as np
# Load one saved test mask, convert it to 1-bit mode and print the sum of
# the resulting array.
mask_file = "dataset_100/test/car/455c29cd8db5b225.mask.png"
pil_mask = Image.open(mask_file).convert('1')
mask = np.array(pil_mask)
print(mask.sum())

252650
