In [2]:
import json
import os

import torchvision

# Root of the COCO 2017 copy. Defaults to the shared HPI location; set the
# COCO_ROOT environment variable to run this notebook against another copy.
COCO_ROOT = os.environ.get('COCO_ROOT', '/hpi/fs00/share/fg/rabl/strassenburg/datasets/coco')

train_root_path = os.path.join(COCO_ROOT, 'train2017')
train_annotations = os.path.join(COCO_ROOT, 'annotations', 'instances_train2017.json')

val_root_path = os.path.join(COCO_ROOT, 'val2017')
val_annotations = os.path.join(COCO_ROOT, 'annotations', 'instances_val2017.json')

In [3]:
# NOTE(review): no cell visible in this notebook draws a plot, so this magic is
# currently unused; recent IPython also renders matplotlib inline by default.
%matplotlib inline

In [4]:
def id_to_class_index(annotations_path):
    """Map COCO category ids to their human-readable class names.

    Args:
        annotations_path: Path to a COCO-format annotation JSON file
            (e.g. ``instances_val2017.json``).

    Returns:
        dict mapping each entry's ``'id'`` in the file's top-level
        ``'categories'`` list to that entry's ``'name'``. If an id appears
        more than once, the last entry wins.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default
    # locale encoding when decoding the file.
    with open(annotations_path, 'r', encoding='utf-8') as annotations_file:
        annotations = json.load(annotations_file)

    return {category['id']: category['name'] for category in annotations['categories']}


In [5]:
# Build the category id -> name lookup for each split (train first, then val).
train_cat_index, val_cat_index = (
    id_to_class_index(annotation_file)
    for annotation_file in (train_annotations, val_annotations)
)

In [6]:
# Sanity check: both splits should define the same category vocabulary.
print(len(train_cat_index), len(val_cat_index), sep='\n')

print(train_cat_index == val_cat_index)

80
80
True


In [7]:
# List every category name, in the order the annotation file defines them.
print('all categories', train_cat_index.values(), sep='\n')

all categories
dict_values(['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'])


In [8]:
# Build one torchvision COCO detection dataset per split (train first, then
# val); each constructor loads and indexes its annotation file.
train_coco_data, val_coco_data = (
    torchvision.datasets.CocoDetection(image_dir, annotation_file)
    for image_dir, annotation_file in (
        (train_root_path, train_annotations),
        (val_root_path, val_annotations),
    )
)

loading annotations into memory...
Done (t=11.15s)
creating index...
index created!
loading annotations into memory...
Done (t=0.40s)
creating index...
index created!


In [9]:
def category_ids(annotation):
    """Return the set of distinct ``'category_id'`` values in ``annotation``.

    ``annotation`` is any iterable of dicts carrying a ``'category_id'`` key,
    presumably the per-image annotation list (element ``[1]``) of a
    ``CocoDetection`` item. An empty iterable yields an empty set.
    """
    return {entry['category_id'] for entry in annotation}


def filter_number_of_categories(elements, num_categories, min_categories=0):
    """Keep the elements whose annotations span a bounded number of categories.

    Args:
        elements: Iterable of ``(image, annotations)`` pairs, e.g. a
            ``CocoDetection`` dataset or a list previously returned by this
            function.
        num_categories: Maximum number of distinct category ids (inclusive).
        min_categories: Minimum number of distinct category ids (inclusive).
            Defaults to 0, so elements with no annotations at all are kept
            (matching the original behaviour); pass 1 to exclude them.

    Returns:
        list of the matching elements, in input order. The whole element,
        including ``e[0]``, is retained.

    NOTE(review): iterating a torchvision ``CocoDetection`` presumably loads
    every image, not just its annotations, and the returned list keeps the
    loaded images of all matches alive; confirm memory use on the train split.
    """
    return [
        e for e in elements
        if min_categories <= len(category_ids(e[1])) <= num_categories
    ]



In [10]:
# Nested filtering: each step narrows the previous result, which is valid
# because "at most 2 categories" implies "at most 3". The "max" counts also
# include images without any annotations, since the filter only bounds from above.
print('val set')

val_three_cats = filter_number_of_categories(val_coco_data, 3)
print(f'num pictures with max 3 categories: {len(val_three_cats)}')

val_two_cats = filter_number_of_categories(val_three_cats, 2)
print(f'num pictures with max 2 categories: {len(val_two_cats)}')

val_one_cat = filter_number_of_categories(val_two_cats, 1)
print(f'num pictures with max 1 categories: {len(val_one_cat)}')




val set
num pictures with max 3 categories: 3575
num pictures with max 2 categories: 2595
num pictures with max 1 categories: 1073


In [11]:
# Same nested filtering as for the val set. The filter keeps images with AT
# MOST N distinct categories (images without annotations, if any, included),
# so the labels say "max N", matching the val-set cell.
print('train set')
train_three_cats = filter_number_of_categories(train_coco_data, 3)
print('num pictures with max 3 categories: {}'.format(len(train_three_cats)))
train_two_cats = filter_number_of_categories(train_three_cats, 2)
print('num pictures with max 2 categories: {}'.format(len(train_two_cats)))
train_one_cat = filter_number_of_categories(train_two_cats, 1)
print('num pictures with max 1 categories: {}'.format(len(train_one_cat)))



train set
num pictures with 3 categories: 85234
num pictures with 2 categories: 61065
num pictures with 1 categories: 25207
