Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions torchvision/datasets/lsun.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,23 @@


class LSUNClass(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
def __init__(self, root, transform=None, target_transform=None):
import lmdb
self.db_path = db_path
self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False,
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform

self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
self.length = txn.stat()['entries']
cache_file = '_cache_' + db_path.replace('/', '_')
cache_file = '_cache_' + root.replace('/', '_')
if os.path.isfile(cache_file):
self.keys = pickle.load(open(cache_file, "rb"))
else:
with self.env.begin(write=False) as txn:
self.keys = [key for key, _ in txn.cursor()]
pickle.dump(self.keys, open(cache_file, "wb"))
self.transform = transform
self.target_transform = target_transform

def __getitem__(self, index):
img, target = None, None
Expand Down Expand Up @@ -60,7 +61,7 @@ class LSUN(data.Dataset):
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.

Args:
db_path (string): Root directory for the database files.
root (string): Root directory for the database files.
classes (string or list): One of {'train', 'val', 'test'} or a list of
categories to load. e,g. ['bedroom_train', 'church_train'].
transform (callable, optional): A function/transform that takes in an PIL image
Expand All @@ -69,13 +70,16 @@ class LSUN(data.Dataset):
target and transforms it.
"""

def __init__(self, db_path, classes='train',
def __init__(self, root, classes='train',
transform=None, target_transform=None):
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test']
self.db_path = db_path
self.root = os.path.expanduser(root)
self.transform = transform

This comment was marked as off-topic.

This comment was marked as off-topic.

self.target_transform = target_transform

if type(classes) == str and classes in dset_opts:
if classes == 'test':
classes = [classes]
Expand All @@ -102,7 +106,7 @@ def __init__(self, db_path, classes='train',
self.dbs = []
for c in self.classes:
self.dbs.append(LSUNClass(
db_path=db_path + '/' + c + '_lmdb',
root=root + '/' + c + '_lmdb',
transform=transform))

self.indices = []
Expand All @@ -112,7 +116,6 @@ def __init__(self, db_path, classes='train',
self.indices.append(count)

self.length = count
self.target_transform = target_transform

def __getitem__(self, index):
"""
Expand Down Expand Up @@ -146,6 +149,7 @@ def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
fmt_str += ' Classes: {}\n'.format(self.classes)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
Expand Down