In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pylab
import skimage.io as io

from pytorchcocotools.coco import COCO

# Default figure size (width, height) for every plot in this notebook.
# `plt.rcParams` is the same settings object that `pylab.rcParams` exposes;
# going through pyplot avoids relying on the discouraged pylab namespace.
plt.rcParams["figure.figsize"] = (8.0, 10.0)

In [None]:
# Dataset location and split; every annotation path below is derived from these.
data_dir, data_type = "../data", "val2017"

# Instance annotation file for the chosen split.
ann_file = f"{data_dir}/annotations/instances_{data_type}.json"

In [None]:
# Initialize the COCO API for instance annotations.
# `ann_file` is the instances annotation path built from data_dir/data_type.
coco = COCO(ann_file)

In [None]:
# Display COCO categories and supercategories.
cats = coco.loadCats(coco.getCatIds())

# Category names, in the order loadCats returned them.
nms = [cat["name"] for cat in cats]
cat_list = " ".join(nms)
print(f"COCO categories: \n{cat_list}\n")

# Supercategories are de-duplicated through a set. Sorting them keeps the
# printed order stable across kernel restarts (iteration order of a set of
# strings is not reproducible between interpreter runs). A separate name is
# used so `nms` is not rebound from a list to a different kind of object.
sup_nms = sorted({cat["supercategory"] for cat in cats})
sup_cat = " ".join(sup_nms)
print(f"COCO supercategories: \n{sup_cat}")

In [None]:
# Look up the ids of the categories of interest. `cat_ids` is also used by
# the annotation cells further down.
cat_ids = coco.getCatIds(catNms=["person", "dog", "skateboard"])

# Image ids returned for that category filter. Kept under its own name so the
# query result stays available for exploration instead of being overwritten
# immediately by the pinned lookup below.
cat_img_ids = coco.getImgIds(catIds=cat_ids)

# The demo is pinned to one known image id, so the "random" pick below is
# drawn from whatever this lookup returns (not from `cat_img_ids`).
# To sample from the category query instead, use `img_ids = cat_img_ids`.
img_ids = coco.getImgIds(imgIds=[324158])
if len(img_ids) == 0:
    # np.random.randint(0, 0) would raise a far less helpful ValueError.
    raise ValueError("no image ids returned for the pinned demo image 324158")
img = coco.loadImgs(img_ids[np.random.randint(0, len(img_ids))])[0]

In [None]:
# Load the selected image through the URL stored in its `coco_url` field,
# then display it without axes.
# To read a local copy instead of using the URL:
# image = io.imread(f"{data_dir}/images/{data_type}/{img['file_name']}")
image = io.imread(img["coco_url"])
plt.imshow(image)
plt.axis("off")
plt.show()

In [None]:
# Load and display instance annotations.
# Fetch the annotations for the selected image, restricted to `cat_ids`.
ann_ids = coco.getAnnIds(imgIds=img["id"], catIds=cat_ids, iscrowd=None)
anns = coco.loadAnns(ann_ids)

# Draw the image before calling showAnns (presumably showAnns draws onto the
# current axes; confirm against the COCO implementation in use).
plt.imshow(image)
plt.axis("off")
coco.showAnns(anns)

In [None]:
# Initialize the COCO API for person keypoints annotations.
# Note: this rebinds `ann_file` to the keypoints annotation path.
ann_file = f"{data_dir}/annotations/person_keypoints_{data_type}.json"
coco_kps = COCO(ann_file)

In [None]:
# Load and display keypoints annotations.
# Fetch the keypoint annotations for the selected image, using the same
# `cat_ids` filter as the instance-annotation cell.
ann_ids = coco_kps.getAnnIds(imgIds=img["id"], catIds=cat_ids, iscrowd=None)
anns = coco_kps.loadAnns(ann_ids)

# Draw the image, hide the axes, then overlay the annotations.
plt.imshow(image)
plt.axis("off")
ax = plt.gca()  # handle to the current axes; not used further in this cell
coco_kps.showAnns(anns)

In [None]:
# Initialize the COCO API for caption annotations.
# Note: this rebinds `ann_file` to the captions annotation path.
ann_file = f"{data_dir}/annotations/captions_{data_type}.json"
coco_caps = COCO(ann_file)

In [None]:
# Load and display caption annotations.
# No category filter here: annotation ids are looked up by image id only.
ann_ids = coco_caps.getAnnIds(imgIds=img["id"])
anns = coco_caps.loadAnns(ann_ids)
coco_caps.showAnns(anns)
# Show the selected image after the captions, with the axes hidden.
plt.imshow(image)
plt.axis("off")
plt.show()