Skip to content

Commit

Permalink
add places365 dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Nov 1, 2020
1 parent ecf525d commit 36bdc18
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tensorpack/dataflow/dataset/ilsvrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ def __init__(self, dir, name, meta_dir=None,
dir (str): A directory containing a subdir named ``name``,
containing the images in a structure described below.
name (str): One of 'train' or 'val' or 'test'.
shuffle (bool): shuffle the dataset.
Defaults to True if name=='train'.
shuffle (bool): shuffle the dataset. Defaults to True if name=='train'.
dir_structure (str): One of 'original' or 'train'.
The directory structure for the 'val' directory.
'original' means the original decompressed directory, which only has list of image files (as below).
Expand Down Expand Up @@ -354,6 +353,7 @@ def __iter__(self):
except ImportError:
from ...utils.develop import create_dummy_class
ILSVRC12 = create_dummy_class('ILSVRC12', 'cv2') # noqa
TinyImageNet = create_dummy_class('TinyImageNet', 'cv2') # noqa

if __name__ == '__main__':
meta = ILSVRCMeta()
Expand Down
81 changes: 81 additions & 0 deletions tensorpack/dataflow/dataset/places.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#-*- coding: utf-8 -*-

import os
import numpy as np

from ...utils import logger
from ..base import RNGDataFlow


class Places365Standard(RNGDataFlow):
"""
The Places365-Standard Dataset, in low resolution format only.
Produces BGR images of shape (256, 256, 3) in range [0, 255].
"""
def __init__(self, dir, name, shuffle=None):
"""
Args:
dir: path to the Places365-Standard dataset in its "easy directory
structure". See http://places2.csail.mit.edu/download.html
name: one of "train" or "val"
shuffle (bool): shuffle the dataset. Defaults to True if name=='train'.
"""
assert name in ['train', 'val'], name
dir = os.path.expanduser(dir)
assert os.path.isdir(dir), dir
self.name = name
if shuffle is None:
shuffle = name == 'train'
self.shuffle = shuffle

label_file = os.path.join(dir, name + ".txt")
all_files = []
labels = set()
with open(label_file) as f:
for line in f:
filepath = os.path.join(dir, line.strip())
line = line.strip().split("/")
label = line[1]
all_files.append((filepath, label))
labels.add(label)
self._labels = sorted(list(labels))
# class ids are sorted alphabetically:
# https://github.com/CSAILVision/places365/blob/master/categories_places365.txt
labelmap = {label: id for id, label in enumerate(self._labels)}
self._files = [(path, labelmap[x]) for path, x in all_files]
logger.info("Found {} images in {}.".format(len(self._files), label_file))

def get_label_names(self):
"""
Returns:
[str]: name of each class.
"""
return self._labels

def __len__(self):
return len(self._files)

def __iter__(self):
idxs = np.arange(len(self._files))
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
fname, label = self._files[k]
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
yield [im, label]


try:
import cv2
except ImportError:
from ...utils.develop import create_dummy_class
Places365Standard = create_dummy_class('Places365Standard', 'cv2') # noqa

if __name__ == '__main__':
from tensorpack.dataflow import PrintData
ds = Places365Standard("~/data/places365_standard/", 'train')
ds = PrintData(ds, num=100)
ds.reset_state()
for k in ds:
pass

0 comments on commit 36bdc18

Please sign in to comment.