Convert Imagenet Dataset to Shards
==============================

Imagenet data is laid out as a directory containing annotations in the `ILSVRC2012_devkit_*` subdirectories and individual image files in the `train` and `val` subdirectories.

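The simplest conversion would be a single command such as `tar cf train.tar train`, but that would leave it to every reader of the archive to work out which annotations belong to which image.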

WebDataset makes it possible to associate annotations directly with each image, making reading and processing the data much simpler. This means that we need to put the code that associates annotations with images into the conversion script (i.e., here). The code below is fairly inelegant, but it gets the job done.

The output is a tar file containing data like this:

    -r--r--r-- bigdata/bigdata   2 2019-10-30 22:17 n02096437_3246.cls
    -r--r--r-- bigdata/bigdata 16426 2019-10-30 22:17 n02096437_3246.jpg
    -r--r--r-- bigdata/bigdata    62 2019-10-30 22:17 n02096437_3246.json
    -r--r--r-- bigdata/bigdata     3 2019-10-30 22:17 n03240683_4321.cls
    -r--r--r-- bigdata/bigdata 124353 2019-10-30 22:17 n03240683_4321.jpg
    -r--r--r-- bigdata/bigdata    402 2019-10-30 22:17 n03240683_4321.json
    -r--r--r-- bigdata/bigdata      3 2019-10-30 22:17 n02091032_4199.cls
    -r--r--r-- bigdata/bigdata 133184 2019-10-30 22:17 n02091032_4199.jpg
    -r--r--r-- bigdata/bigdata     42 2019-10-30 22:17 n02091032_4199.json
    
After this conversion, the data can be accessed directly like this:

    for sample in WebDataset("imagenet-000000.tar"):
        image, annotations, cls = sample["jpg"], sample["json"], sample["cls"]
        ...
        
Or:

    for image, cls in WebDataset("imagenet-000000.tar", extensions="jpg cls"):
        ...

In [1]:
%cd /mdata/imagenet-raw

/mdata/imagenet-raw


In [2]:
import os, sys, glob, os.path, sqlite3
import random as pyr
import re
import PIL.Image
import numpy as np
import io
import xmltodict
import warnings
import simplejson
import itertools as itt
import random

def readfile(path, mode="rb"):
    """Return the entire contents of the file at `path`.

    The file is opened with `mode` (binary by default), so the result is
    `bytes` for binary modes and `str` for text modes.
    """
    with open(path, mode) as handle:
        contents = handle.read()
    return contents
def writefile(path, data):
    """Write `data` to `path`, replacing any existing file.

    Text mode is used when `data` is a `str`; any other value is written in
    binary mode.
    """
    if isinstance(data, str):
        mode = "w"
    else:
        mode = "wb"
    with open(path, mode) as out:
        out.write(data)
def pilreads(data):
    """Decode an encoded image held in `data` (bytes) into a NumPy array via PIL."""
    image = PIL.Image.open(io.BytesIO(data))
    return np.array(image)

In [3]:
# Gather every training image path in sorted order, then report the image
# count next to the number of XML annotation files found in the same tree.
jpegs = glob.glob("train/*/*.JPEG")
jpegs.sort()
print(len(jpegs), len(glob.glob("train/*/*.xml")))

1281167 544546


In [4]:
import scipy.io
# Load the devkit metadata file and keep only its "synsets" table, which the
# following statements turn into wnid lookup dictionaries.
meta = scipy.io.loadmat("ILSVRC2012_devkit_t12/data/meta.mat")["synsets"]
def scalar(x):
    """Unwrap nested single-element containers down to the innermost value.

    Repeatedly replaces `x` with `x[0]` (at most 10 times) and returns the
    result. Unwrapping stops early when `x` is a `str` (indexing a string
    would just yield its first character) or when `x` cannot be indexed any
    further. This is used to pull plain values out of the nested arrays
    found in the rows of the loaded ``meta.mat`` table.

    Args:
        x: Any object; typically a nested array or list.

    Returns:
        The innermost value reached. Non-indexable and empty inputs are
        returned unchanged.
    """
    for _ in range(10):
        if isinstance(x, str):
            break
        try:
            x = x[0]
        except (IndexError, KeyError, TypeError):
            # Only "cannot index any further" ends the unwrapping:
            # non-subscriptable values, empty sequences, scalars, or mappings
            # without key 0. Anything else (e.g. KeyboardInterrupt) is
            # deliberately left to propagate instead of being swallowed.
            break
    return x
# Build two lookup tables keyed by the wnid string stored in field 1 of each
# synset row: field 0 becomes the integer class id, field 2 the class name.
wnid2id = dict(
    (scalar(row[0][1]), int(scalar(row[0][0]))) for row in meta
)
wnid2cname = dict(
    (scalar(row[0][1]), str(scalar(row[0][2]))) for row in meta
)
# Show the first five entries of each table as a quick sanity check.
print(list(itt.islice(wnid2id.items(), 5)))
print(list(itt.islice(wnid2cname.items(), 5)))

[('n02119789', 1), ('n02100735', 2), ('n02110185', 3), ('n02096294', 4), ('n02102040', 5)]
[('n02119789', 'kit fox, Vulpes macrotis'), ('n02100735', 'English setter'), ('n02110185', 'Siberian husky'), ('n02096294', 'Australian terrier'), ('n02102040', 'English springer, English springer spaniel')]


In [5]:
# Which split is being converted; read by pathinfo() on every call.
# Must be "train" or "val".
mode = "train"


def pathinfo(path):
    """Parse an image path into its directory name and image number.

    The module-level `mode` selects the expected file naming:

    * ``"train"``: ``<split>/<wnid>/<wnid>_<number>.JPEG``
    * ``"val"``:   ``<split>/<wnid>/ILSVRC2012_val_<number>.JPEG``

    Args:
        path: Relative image path such as
            ``train/n01440764/n01440764_10040.JPEG``.

    Returns:
        A ``(wnid, number)`` tuple, where ``wnid`` is the name of the
        directory containing the image and ``number`` is the numeric part of
        the file name as an ``int``.

    Raises:
        ValueError: If `mode` is neither ``"train"`` nor ``"val"``, or if
            `path` does not follow the naming scheme for the current mode.
    """
    if mode == "val":
        pattern = r"^[a-z]*/([^/]+)/ILSVRC2012_val_(\d+)\.JPEG"
    elif mode == "train":
        # The \1 backreference requires the file name to start with the
        # name of its directory.
        pattern = r"^[a-z]*/([^/]+)/\1_(\d+)\.JPEG"
    else:
        raise ValueError("unknown mode %r; expected 'train' or 'val'" % (mode,))
    match = re.search(pattern, path)
    if match is None:
        raise ValueError(
            "path %r does not look like a %s image path" % (path, mode)
        )
    return match.group(1), int(match.group(2))
# Sanity check: print one training path, then display what pathinfo() parses
# from it (the directory name and the numeric part of the file name).
print(jpegs[3])
pathinfo(jpegs[3])

train/n01440764/n01440764_10040.JPEG


('n01440764', 10040)

In [6]:
def pathkey(path):
    """Return the sample key for an image path.

    The key is the file name with its directory part and a trailing
    ``.JPEG`` extension removed, e.g.
    ``train/n01440764/n01440764_10040.JPEG`` -> ``n01440764_10040``.
    Names without a literal ``.JPEG`` suffix are returned unchanged.
    """
    basename = re.sub(r'.*/', '', path)
    # The dot is escaped so that only a literal ".JPEG" suffix is stripped;
    # an unescaped "." would also eat any other character before "JPEG".
    return re.sub(r'\.JPEG$', '', basename)

# Sanity check: display the key derived from one training path.
pathkey(jpegs[3])

'n01440764_10040'

In [7]:
def pathcls(path):
    """Return the integer class id for the image at `path`.

    The directory name parsed by pathinfo() is looked up in `wnid2id`.
    """
    wnid, _ = pathinfo(path)
    return wnid2id[wnid]

# Sanity check: display the class id looked up for one training path.
pathcls(jpegs[3])

449

In [8]:
def jpeginfo(path):
    """Collect the annotation dictionary for one image.

    If an XML file with the same name as the image (``.JPEG`` replaced by
    ``.xml``) exists, it is parsed with xmltodict and used as the base of
    the result; otherwise the result starts out empty. The class id and
    class name of the image's directory are then added.

    Args:
        path: Image path in the layout understood by pathinfo().

    Returns:
        A dict containing at least ``"cls"`` (integer class id from
        `wnid2id`) and ``"cname"`` (class name from `wnid2cname`), plus the
        parsed XML content when an annotation file is present.
    """
    # The dot is escaped so that only a literal ".JPEG" suffix is replaced.
    xmlpath = re.sub(r"\.JPEG$", ".xml", path)
    # If nothing was substituted, xmlpath is the image itself; never try to
    # parse the image file as XML.
    if xmlpath != path and os.path.exists(xmlpath):
        info = xmltodict.parse(readfile(xmlpath, "r"))
    else:
        info = {}
    folder = pathinfo(path)[0]
    info["cls"] = wnid2id[folder]
    info["cname"] = wnid2cname[folder]
    return info

# Preview the annotations of the first (up to) 100 images. Slicing instead of
# indexing range(100) keeps this working when fewer than 100 images exist.
infos = [jpeginfo(path) for path in jpegs[:100]]
# Defensive: drop missing entries (jpeginfo as defined above always returns a dict).
infos = [info for info in infos if info is not None]
print(simplejson.dumps(infos[0], indent=4))

{
    "cls": 449,
    "cname": "tench, Tinca tinca"
}


In [9]:
from webdataset import writer
from importlib import reload
# Re-execute the writer module; presumably so that local edits to webdataset
# are picked up without restarting the notebook kernel.
reload(writer)

def write_shards(dest, jpegs, maxsize=1e9):
    """Write the given images and their annotations into tar shards.

    The images are processed in random order. For every image, one sample is
    written with the key from pathkey() and three entries: ``jpg`` (the raw
    file bytes), ``json`` (the jpeginfo() dictionary serialized as UTF-8
    JSON) and ``cls`` (the class id as a decimal string). A progress line is
    printed for every 1000th image.

    Args:
        dest: Shard file name pattern handed to ``writer.ShardWriter``,
            e.g. ``"imagenet_train-%06d.tar"``.
        jpegs: Iterable of image paths; it is copied, so the caller's
            sequence is not reordered.
        maxsize: Forwarded to ``writer.ShardWriter`` as ``maxsize``
            (presumably the size limit of a single shard in bytes).
    """
    # list() makes a private copy like .copy() did, and additionally accepts
    # tuples and other iterables.
    jpegs = list(jpegs)
    random.shuffle(jpegs)
    # NOTE(review): encoder=False presumably tells the writer that the sample
    # values are already encoded bytes -- confirm against webdataset.
    sink = writer.ShardWriter(dest, maxsize=maxsize, encoder=False)
    try:
        for i, fname in enumerate(jpegs):
            key = pathkey(fname)
            jpeg = readfile(fname)
            info = jpeginfo(fname)
            cls = pathcls(fname)
            # Defensive fallback in case no annotation dict is available.
            if info is None:
                info = dict(cls=cls)
            # The class stored in the annotations must agree with the path.
            assert cls == info["cls"]
            json_text = simplejson.dumps(info)
            if i % 1000 == 0:
                print(i, key, len(jpeg), json_text[:50])
            sample = dict(__key__=key,
                          jpg=jpeg,
                          json=json_text.encode("utf-8"),
                          cls=str(cls).encode("utf-8"))
            sink.write(sample)
    finally:
        # Close the writer even if reading or writing a sample fails, so the
        # shard that is currently open is not left dangling.
        sink.close()

In [None]:
# Convert the training images collected in `jpegs`, using the default
# maxsize of write_shards(); %06d in the pattern is the shard number.
write_shards("imagenet_train-%06d.tar", jpegs)

# writing imagenet_train-000000.tar 0 0.0 GB 0
0 n02398521_89849 153209 {"cls": 167, "cname": "hippopotamus, hippo, river 
1000 n02113023_5478 119980 {"cls": 197, "cname": "Pembroke, Pembroke Welsh co
2000 n03709823_4491 20618 {"cls": 818, "cname": "mailbag, postbag"}
3000 n03710637_28365 113928 {"annotation": {"folder": "n03710637", "filename":
4000 n01688243_6710 67245 {"cls": 468, "cname": "frilled lizard, Chlamydosau
5000 n03721384_2038 72796 {"cls": 339, "cname": "marimba, xylophone"}
6000 n02268853_5633 111716 {"annotation": {"folder": "n02268853", "filename":
7000 n02277742_4064 97637 {"annotation": {"folder": "n02277742", "filename":
8000 n03220513_13770 139727 {"cls": 897, "cname": "dome"}
# writing imagenet_train-000001.tar 8640 1.0 GB 8640
9000 n02423022_488 110651 {"cls": 12, "cname": "gazelle"}
10000 n04131690_13611 102261 {"annotation": {"folder": "n04131690", "filename":
11000 n03388043_19775 250215 {"annotation": {"folder": "n03388043", "filename":
12000 n03240683_3388 

In [None]:
# Replace the training list with the validation image paths (sorted), then
# report the image count next to the number of XML annotation files.
jpegs = glob.glob("val/*/*.JPEG")
jpegs.sort()
print(len(jpegs), len(glob.glob("val/*/*.xml")))

In [None]:
# Switch pathinfo() to the validation file naming before converting; this
# must be set before write_shards() runs, because pathinfo() reads `mode`
# on every call.
mode = "val"
# maxsize is raised to 1e11 here; presumably so that the whole validation
# set ends up in a single output file -- confirm against the writer.
write_shards("imagenet_val-%06d.tgz", jpegs, maxsize=1e11)