diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 230f9ae4627..3f61ae7cbd8 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -4,11 +4,11 @@ torchvision.datasets All datasets are subclasses of :class:`torch.utils.data.Dataset` i.e, they have ``__getitem__`` and ``__len__`` methods implemented. Hence, they can all be passed to a :class:`torch.utils.data.DataLoader` -which can load multiple samples parallelly using ``torch.multiprocessing`` workers. +which can load multiple samples parallelly using ``torch.multiprocessing`` workers. For example: :: - + imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/') - data_loader = torch.utils.data.DataLoader(imagenet_data, + data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True, num_workers=args.nThreads) @@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments: ``transform`` and ``target_transform`` to transform the input and target respectively. -.. currentmodule:: torchvision.datasets +.. currentmodule:: torchvision.datasets MNIST @@ -78,14 +78,6 @@ ImageFolder :members: __getitem__ :special-members: -DatasetFolder -~~~~~~~~~~~~~ - -.. autoclass:: DatasetFolder - :members: __getitem__ - :special-members: - - Imagenet-12 ~~~~~~~~~~~ @@ -129,3 +121,4 @@ PhotoTour .. autoclass:: PhotoTour :members: __getitem__ :special-members: + diff --git a/setup.py b/setup.py index 59805be6268..0f46586deec 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,6 @@ def find_version(*file_paths): 'pillow >= 4.1.1', 'six', 'torch', - 'mock', ] setup( diff --git a/test/assets/dataset/a/a1.png b/test/assets/dataset/a/a1.png deleted file mode 100644 index 52d23cd374e..00000000000 Binary files a/test/assets/dataset/a/a1.png and /dev/null differ diff --git a/test/assets/dataset/a/a2.png b/test/assets/dataset/a/a2.png deleted file mode 100644 index b6ff2059d4c..00000000000 Binary files a/test/assets/dataset/a/a2.png and /dev/null differ diff --git a/test/assets/dataset/a/a3.png b/test/assets/dataset/a/a3.png deleted file mode 100644 index 86cc0511536..00000000000 Binary files a/test/assets/dataset/a/a3.png and /dev/null differ diff --git a/test/assets/dataset/b/b1.png b/test/assets/dataset/b/b1.png deleted file mode 100644 index d542ba0ff9d..00000000000 Binary files a/test/assets/dataset/b/b1.png and /dev/null differ diff --git a/test/assets/dataset/b/b2.png b/test/assets/dataset/b/b2.png deleted file mode 100644 index 36edbf49912..00000000000 Binary files a/test/assets/dataset/b/b2.png and /dev/null differ diff --git a/test/assets/dataset/b/b3.png b/test/assets/dataset/b/b3.png deleted file mode 100644 index b44b8719b3d..00000000000 Binary files a/test/assets/dataset/b/b3.png and /dev/null differ diff --git a/test/assets/dataset/b/b4.png b/test/assets/dataset/b/b4.png deleted file mode 100644 index 11e24914ffa..00000000000 Binary files a/test/assets/dataset/b/b4.png and /dev/null differ diff --git a/test/test_folder.py b/test/test_folder.py deleted file mode 100644 index 7abd815b060..00000000000 --- a/test/test_folder.py +++ /dev/null @@ -1,59 +0,0 @@ -import unittest -try: - from unittest.mock import Mock -except ImportError as e: - from mock import Mock - -import os - -from torchvision.datasets import ImageFolder - - -class Tester(unittest.TestCase): - root = 'test/assets/dataset/' - classes = ['a', 'b'] - class_a_images = [os.path.join('test/assets/dataset/a/', path) for path in ['a1.png', 'a2.png', 'a3.png']] - class_b_images = 
[os.path.join('test/assets/dataset/b/', path) for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']] - - def test_image_folder(self): - dataset = ImageFolder(Tester.root, loader=lambda x: x) - self.assertEqual(sorted(Tester.classes), sorted(dataset.classes)) - for cls in Tester.classes: - self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]]) - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - imgs_a = [(img_path, class_a_idx)for img_path in Tester.class_a_images] - imgs_b = [(img_path, class_b_idx)for img_path in Tester.class_b_images] - imgs = sorted(imgs_a + imgs_b) - self.assertEqual(imgs, dataset.imgs) - - outputs = sorted([dataset[i] for i in range(len(dataset))]) - self.assertEqual(imgs, outputs) - - def test_transform(self): - return_value = 'test/assets/dataset/a/a1.png' - transform = Mock(return_value=return_value) - dataset = ImageFolder(Tester.root, loader=lambda x: x, transform=transform) - outputs = [dataset[i][0] for i in range(len(dataset))] - self.assertEqual([return_value] * len(outputs), outputs) - - imgs = sorted(Tester.class_a_images + Tester.class_b_images) - args = [call[0][0] for call in transform.call_args_list] - self.assertEqual(imgs, sorted(args)) - - def test_target_transform(self): - return_value = 1 - target_transform = Mock(return_value=return_value) - dataset = ImageFolder(Tester.root, loader=lambda x: x, target_transform=target_transform) - outputs = [dataset[i][1] for i in range(len(dataset))] - self.assertEqual([return_value] * len(outputs), outputs) - - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - targets = sorted([class_a_idx] * len(Tester.class_a_images) + - [class_b_idx] * len(Tester.class_b_images)) - args = [call[0][0] for call in target_transform.call_args_list] - self.assertEqual(targets, sorted(args)) - -if __name__ == '__main__': - unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index e2d2801216a..1cb604a79f6 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,5 +1,5 @@ from .lsun import LSUN, LSUNClass -from .folder import ImageFolder, DatasetFolder +from .folder import ImageFolder from .coco import CocoCaptions, CocoDetection from .cifar import CIFAR10, CIFAR100 from .stl10 import STL10 @@ -11,7 +11,7 @@ from .omniglot import Omniglot __all__ = ('LSUN', 'LSUNClass', - 'ImageFolder', 'DatasetFolder', 'FakeData', + 'ImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index a6bb5742bc5..6fef78ee5ad 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,13 +1,14 @@ import torch.utils.data as data from PIL import Image - import os import os.path +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] + -def has_file_allowed_extension(filename, extensions): - """Checks if a file is an allowed extension. +def is_image_file(filename): + """Checks if a file is an image. 
Args: filename (string): path to a file @@ -16,7 +17,7 @@ def has_file_allowed_extension(filename, extensions): bool: True if the filename ends with a known image extension """ filename_lower = filename.lower() - return any(filename_lower.endswith(ext) for ext in extensions) + return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) def find_classes(dir): @@ -26,7 +27,7 @@ def find_classes(dir): return classes, class_to_idx -def make_dataset(dir, class_to_idx, extensions): +def make_dataset(dir, class_to_idx): images = [] dir = os.path.expanduser(dir) for target in sorted(os.listdir(dir)): @@ -36,7 +37,7 @@ def make_dataset(dir, class_to_idx, extensions): for root, _, fnames in sorted(os.walk(d)): for fname in sorted(fnames): - if has_file_allowed_extension(fname, extensions): + if is_image_file(fname): path = os.path.join(root, fname) item = (path, class_to_idx[target]) images.append(item) @@ -44,85 +45,6 @@ def make_dataset(dir, class_to_idx, extensions): return images -class DatasetFolder(data.Dataset): - """A generic data loader where the samples are arranged in this way: :: - - root/class_x/xxx.ext - root/class_x/xxy.ext - root/class_x/xxz.ext - - root/class_y/123.ext - root/class_y/nsdf3.ext - root/class_y/asd932_.ext - - Args: - root (string): Root directory path. - loader (callable): A function to load a sample given its path. - extensions (list[string]): A list of allowed extensions. - transform (callable, optional): A function/transform that takes in - a sample and returns a transformed version. - E.g, ``transforms.RandomCrop`` for images. - target_transform (callable, optional): A function/transform that takes - in the target and transforms it. - - Attributes: - classes (list): List of the class names. - class_to_idx (dict): Dict with items (class_name, class_index). - samples (list): List of (sample path, class_index) tuples - """ - - def __init__(self, root, loader, extensions, transform=None, target_transform=None): - classes, class_to_idx = find_classes(root) - samples = make_dataset(root, class_to_idx, extensions) - if len(samples) == 0: - raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" - "Supported extensions are: " + ",".join(extensions))) - - self.root = root - self.loader = loader - self.extensions = extensions - - self.classes = classes - self.class_to_idx = class_to_idx - self.samples = samples - - self.transform = transform - self.target_transform = target_transform - - def __getitem__(self, index): - """ - Args: - index (int): Index - - Returns: - tuple: (sample, target) where target is class_index of the target class. 
- """ - path, target = self.samples[index] - sample = self.loader(path) - if self.transform is not None: - sample = self.transform(sample) - if self.target_transform is not None: - target = self.target_transform(target) - - return sample, target - - def __len__(self): - return len(self.samples) - - def __repr__(self): - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - tmp = ' Target Transforms (if any): ' - fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - return fmt_str - - -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] - - def pil_loader(path): # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, 'rb') as f: @@ -147,7 +69,7 @@ def default_loader(path): return pil_loader(path) -class ImageFolder(DatasetFolder): +class ImageFolder(data.Dataset): """A generic data loader where the images are arranged in this way: :: root/dog/xxx.png @@ -171,9 +93,49 @@ class ImageFolder(DatasetFolder): class_to_idx (dict): Dict with items (class_name, class_index). imgs (list): List of (image path, class_index) tuples """ + def __init__(self, root, transform=None, target_transform=None, loader=default_loader): - super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, - transform=transform, - target_transform=target_transform) - self.imgs = self.samples + classes, class_to_idx = find_classes(root) + imgs = make_dataset(root, class_to_idx) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.classes = classes + self.class_to_idx = class_to_idx + self.transform = transform + self.target_transform = target_transform + self.loader = loader + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is class_index of the target class. + """ + path, target = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.imgs) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str
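
A minimal usage sketch of the ``ImageFolder`` interface restored by this change. The path, transform settings, and class names below are illustrative only and are not part of the patch; substitute your own directory of class sub-folders: ::

    import torch
    import torchvision
    from torchvision import transforms

    # Assumed ImageNet-style layout: path/to/imagenet_root/<class_name>/<image>.png
    dataset = torchvision.datasets.ImageFolder(
        'path/to/imagenet_root/',
        transform=transforms.Compose([
            transforms.Resize((224, 224)),   # fixed size so samples can be batched
            transforms.ToTensor(),
        ]))

    # classes and class_to_idx are built from the sub-folder names by find_classes
    print(dataset.classes)         # e.g. ['cat', 'dog', ...]
    print(dataset.class_to_idx)    # e.g. {'cat': 0, 'dog': 1, ...}

    # the dataset plugs straight into a DataLoader, as in the docs example above
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                              shuffle=True, num_workers=2)
    images, targets = next(iter(data_loader))   # images: [4, 3, 224, 224], targets: class indices

Without a ``transform``, each ``dataset[i]`` returns a ``(PIL image, class_index)`` pair produced by ``default_loader``; the fixed-size ``Resize`` above is only there so the ``DataLoader`` can stack samples into a batch.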