diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index e165b957c67..731461e8a8d 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -12,22 +12,23 @@ class LSUNClass(data.Dataset): - def __init__(self, db_path, transform=None, target_transform=None): + def __init__(self, root, transform=None, target_transform=None): import lmdb - self.db_path = db_path - self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False, + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + + self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) with self.env.begin(write=False) as txn: self.length = txn.stat()['entries'] - cache_file = '_cache_' + db_path.replace('/', '_') + cache_file = '_cache_' + root.replace('/', '_') if os.path.isfile(cache_file): self.keys = pickle.load(open(cache_file, "rb")) else: with self.env.begin(write=False) as txn: self.keys = [key for key, _ in txn.cursor()] pickle.dump(self.keys, open(cache_file, "wb")) - self.transform = transform - self.target_transform = target_transform def __getitem__(self, index): img, target = None, None @@ -60,7 +61,7 @@ class LSUN(data.Dataset): `LSUN `_ dataset. Args: - db_path (string): Root directory for the database files. + root (string): Root directory for the database files. classes (string or list): One of {'train', 'val', 'test'} or a list of categories to load. e,g. ['bedroom_train', 'church_train']. transform (callable, optional): A function/transform that takes in an PIL image @@ -69,13 +70,16 @@ class LSUN(data.Dataset): target and transforms it. """ - def __init__(self, db_path, classes='train', + def __init__(self, root, classes='train', transform=None, target_transform=None): categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower'] dset_opts = ['train', 'val', 'test'] - self.db_path = db_path + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + if type(classes) == str and classes in dset_opts: if classes == 'test': classes = [classes] @@ -102,7 +106,7 @@ def __init__(self, db_path, classes='train', self.dbs = [] for c in self.classes: self.dbs.append(LSUNClass( - db_path=db_path + '/' + c + '_lmdb', + root=root + '/' + c + '_lmdb', transform=transform)) self.indices = [] @@ -112,7 +116,6 @@ def __init__(self, db_path, classes='train', self.indices.append(count) self.length = count - self.target_transform = target_transform def __getitem__(self, index): """ @@ -146,6 +149,7 @@ def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) fmt_str += ' Root Location: {}\n'.format(self.root) + fmt_str += ' Classes: {}\n'.format(self.classes) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): '