Skip to content

Commit

Permalink
Add TinyImageNet
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Aug 25, 2020
1 parent 02e53f7 commit 52bbe70
Showing 1 changed file with 67 additions and 5 deletions.
72 changes: 67 additions & 5 deletions tensorpack/dataflow/dataset/ilsvrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,22 @@
import os
import tarfile
import tqdm
from pathlib import Path

from ...utils import logger
from ...utils.fs import download, get_dataset_path, mkdir_p
from ...utils.loadcaffe import get_caffe_pb
from ...utils.timer import timed_operation
from ..base import RNGDataFlow

__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files']
__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files', 'TinyImageNet']

# Location of the Caffe ILSVRC12 metadata archive, stored as a (url, int) pair.
# NOTE(review): the integer is presumably the expected archive size in bytes,
# used to validate the download -- confirm against the `download()` call site.
CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008)


class ILSVRCMeta(object):
"""
Provide methods to access metadata for ILSVRC dataset.
Provide methods to access metadata for :class:`ILSVRC12` dataset.
"""

def __init__(self, dir=None):
Expand Down Expand Up @@ -178,8 +179,11 @@ def __iter__(self):

class ILSVRC12(ILSVRC12Files):
"""
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999].
The label map follows the synsets.txt file in http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz.
The ILSVRC12 classification dataset, aka the commonly used 1000 classes ImageNet subset.
This dataflow produces uint8 images of shape [h, w, 3(BGR)], and a label between [0, 999].
The label map follows the synsets.txt file in
http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz,
which can also be queried using :class:`ILSVRCMeta`.
"""
def __init__(self, dir, name, meta_dir=None,
shuffle=None, dir_structure=None):
Expand Down Expand Up @@ -287,6 +291,64 @@ def parse_bbox(fname):
return ret


class TinyImageNet(RNGDataFlow):
    """
    The TinyImageNet classification dataset, with 200 classes and 500 images
    per class. See https://tiny-imagenet.herokuapp.com/.

    It produces [image, label] where image is a 64x64x3(BGR) image, label is an
    integer in [0, 200).

    The label of a class is its (0-based) line index in ``wnids.txt``.
    """
    def __init__(self, dir, name, shuffle=None):
        """
        Args:
            dir (str): the dataset root directory. It must contain ``wnids.txt``
                and a sub-directory named after ``name``. ``~`` is expanded.
            name (str): one of 'train' or 'val'
            shuffle (bool): shuffle the dataset.
                Defaults to True if name=='train'.
        """
        assert name in ['train', 'val'], name
        dir = Path(os.path.expanduser(dir))
        assert os.path.isdir(dir), dir
        self.full_dir = dir / name
        if shuffle is None:
            shuffle = name == 'train'
        self.shuffle = shuffle

        with open(dir / "wnids.txt") as f:
            # Skip blank lines (e.g. a trailing newline at the end of the file)
            # so they are not mistaken for a class id.
            wnids = [wnid for wnid in (line.strip() for line in f) if wnid]
        cls_to_id = {wnid: idx for idx, wnid in enumerate(wnids)}
        assert len(cls_to_id) == 200, len(cls_to_id)

        # List of (image path as str, integer class id).
        self.imglist = []
        if name == 'train':
            # Layout: <dir>/train/<wnid>/images/<image files>
            for clsid, cls in enumerate(wnids):
                cls_dir = self.full_dir / cls / "images"
                # iterdir() yields entries in arbitrary order; sort them so the
                # image list (and therefore any seeded shuffle) is reproducible
                # across machines and filesystems.
                for img in sorted(cls_dir.iterdir()):
                    self.imglist.append((str(img), clsid))
        else:
            # Layout: <dir>/val/images/<image files>, with labels listed in
            # <dir>/val/val_annotations.txt as "<file name> <wnid> ..." lines.
            with open(self.full_dir / "val_annotations.txt") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # Tolerate blank lines instead of raising IndexError.
                        continue
                    img, cls = fields[0], fields[1]
                    img = self.full_dir / "images" / img
                    clsid = cls_to_id[cls]
                    self.imglist.append((str(img), clsid))

    def __len__(self):
        """Return the number of images found for the selected split."""
        return len(self.imglist)

    def __iter__(self):
        """
        Yield ``[image, label]`` for every image in the split, reading each
        image from disk with OpenCV. The order is shuffled when
        ``self.shuffle`` is true.
        """
        idxs = np.arange(len(self.imglist))
        if self.shuffle:
            # NOTE(review): `self.rng` is not set in this class; presumably it
            # is created by RNGDataFlow.reset_state(), which must therefore be
            # called before iterating -- confirm against the base class.
            self.rng.shuffle(idxs)
        for k in idxs:
            fname, label = self.imglist[k]
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            # cv2.imread signals an unreadable file by returning None.
            assert im is not None, fname
            yield [im, label]


try:
import cv2
except ImportError:
Expand All @@ -297,7 +359,7 @@ def parse_bbox(fname):
meta = ILSVRCMeta()
# print(meta.get_synset_words_1000())

ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
ds = TinyImageNet('~/data/tiny-imagenet-200', 'val', shuffle=False)
ds.reset_state()

for _ in ds:
Expand Down

0 comments on commit 52bbe70

Please sign in to comment.