From 74c4b7a8b28123ee8cdac8939b2f60fcb8935f44 Mon Sep 17 00:00:00 2001 From: Jaesuny Date: Wed, 21 Mar 2018 15:33:23 +0900 Subject: [PATCH 1/4] Fix uninitialized instance variables --- torchvision/datasets/lsun.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index e165b957c67..870afe00d05 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -76,6 +76,9 @@ def __init__(self, db_path, classes='train', 'living_room', 'restaurant', 'tower'] dset_opts = ['train', 'val', 'test'] self.db_path = db_path + self.transform = transform + self.target_transform = target_transform + if type(classes) == str and classes in dset_opts: if classes == 'test': classes = [classes] @@ -112,7 +115,6 @@ def __init__(self, db_path, classes='train', self.indices.append(count) self.length = count - self.target_transform = target_transform def __getitem__(self, index): """ @@ -145,7 +147,7 @@ def __len__(self): def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) + fmt_str += ' Root Location: {}\n'.format(self.db_path) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): ' From 6069179cf1b605e012b9584af9973928749bfa4a Mon Sep 17 00:00:00 2001 From: Jaesuny Date: Wed, 21 Mar 2018 15:40:12 +0900 Subject: [PATCH 2/4] Maintain consistency with other dataset classes --- torchvision/datasets/lsun.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index 870afe00d05..2a9fe7b3858 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -12,22 +12,23 @@ class LSUNClass(data.Dataset): - def __init__(self, db_path, transform=None, target_transform=None): + def __init__(self, root, transform=None, target_transform=None): import lmdb - self.db_path = db_path - self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False, + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + + self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) with self.env.begin(write=False) as txn: self.length = txn.stat()['entries'] - cache_file = '_cache_' + db_path.replace('/', '_') + cache_file = '_cache_' + root.replace('/', '_') if os.path.isfile(cache_file): self.keys = pickle.load(open(cache_file, "rb")) else: with self.env.begin(write=False) as txn: self.keys = [key for key, _ in txn.cursor()] pickle.dump(self.keys, open(cache_file, "wb")) - self.transform = transform - self.target_transform = target_transform def __getitem__(self, index): img, target = None, None @@ -60,7 +61,7 @@ class LSUN(data.Dataset): `LSUN `_ dataset. Args: - db_path (string): Root directory for the database files. + root (string): Root directory for the database files. classes (string or list): One of {'train', 'val', 'test'} or a list of categories to load. e,g. ['bedroom_train', 'church_train']. transform (callable, optional): A function/transform that takes in an PIL image @@ -69,15 +70,16 @@ class LSUN(data.Dataset): target and transforms it. """ - def __init__(self, db_path, classes='train', + def __init__(self, root, classes='train', transform=None, target_transform=None): categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower'] dset_opts = ['train', 'val', 'test'] - self.db_path = db_path + self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform + self.classes = classes if type(classes) == str and classes in dset_opts: if classes == 'test': @@ -105,7 +107,7 @@ def __init__(self, db_path, classes='train', self.dbs = [] for c in self.classes: self.dbs.append(LSUNClass( - db_path=db_path + '/' + c + '_lmdb', + root=root + '/' + c + '_lmdb', transform=transform)) self.indices = [] @@ -147,7 +149,8 @@ def __len__(self): def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.db_path) + fmt_str += ' Root Location: {}\n'.format(self.root) + fmt_str += ' Classes: {}\n'.format(self.classes) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): ' From 4e852d7991ab50c53ecc6f4bee4d656f92e1080c Mon Sep 17 00:00:00 2001 From: Jason Park Date: Wed, 21 Mar 2018 22:33:55 +0900 Subject: [PATCH 3/4] Fix double assignment --- torchvision/datasets/lsun.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index 2a9fe7b3858..1727a94a932 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -101,7 +101,6 @@ def __init__(self, root, classes='train', 'Options are: ' + str(dset_opts))) else: raise(ValueError('Unknown option for classes')) - self.classes = classes # for each class, create an LSUNClassDataset self.dbs = [] From 5fcea0cde8af4beb019141b50511312dc99eaf6f Mon Sep 17 00:00:00 2001 From: Jason Park Date: Thu, 22 Mar 2018 02:04:16 +0900 Subject: [PATCH 4/4] Fix initialization of self.classes --- torchvision/datasets/lsun.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index 1727a94a932..731461e8a8d 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -79,7 +79,6 @@ def __init__(self, root, classes='train', self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform - self.classes = classes if type(classes) == str and classes in dset_opts: if classes == 'test': @@ -101,6 +100,7 @@ def __init__(self, root, classes='train', 'Options are: ' + str(dset_opts))) else: raise(ValueError('Unknown option for classes')) + self.classes = classes # for each class, create an LSUNClassDataset self.dbs = []