Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ torchvision.datasets
All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e., they have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
For example: ::

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
Expand All @@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.


.. currentmodule:: torchvision.datasets
.. currentmodule:: torchvision.datasets


MNIST
Expand Down Expand Up @@ -78,14 +78,6 @@ ImageFolder
:members: __getitem__
:special-members:

DatasetFolder
~~~~~~~~~~~~~

.. autoclass:: DatasetFolder
:members: __getitem__
:special-members:



Imagenet-12
~~~~~~~~~~~
Expand Down Expand Up @@ -129,3 +121,4 @@ PhotoTour
.. autoclass:: PhotoTour
:members: __getitem__
:special-members:

1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def find_version(*file_paths):
'pillow >= 4.1.1',
'six',
'torch',
'mock',
]

setup(
Expand Down
Binary file removed test/assets/dataset/a/a1.png
Binary file not shown.
Binary file removed test/assets/dataset/a/a2.png
Binary file not shown.
Binary file removed test/assets/dataset/a/a3.png
Binary file not shown.
Binary file removed test/assets/dataset/b/b1.png
Binary file not shown.
Binary file removed test/assets/dataset/b/b2.png
Binary file not shown.
Binary file removed test/assets/dataset/b/b3.png
Binary file not shown.
Binary file removed test/assets/dataset/b/b4.png
Binary file not shown.
59 changes: 0 additions & 59 deletions test/test_folder.py

This file was deleted.

4 changes: 2 additions & 2 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder, DatasetFolder
from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
Expand All @@ -11,7 +11,7 @@
from .omniglot import Omniglot

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
Expand Down
142 changes: 52 additions & 90 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import torch.utils.data as data

from PIL import Image

import os
import os.path

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an allowed extension.
def is_image_file(filename):
"""Checks if a file is an image.

Args:
filename (string): path to a file
Expand All @@ -16,7 +17,7 @@ def has_file_allowed_extension(filename, extensions):
bool: True if the filename ends with a known image extension
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions)
return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)


def find_classes(dir):
Expand All @@ -26,7 +27,7 @@ def find_classes(dir):
return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
def make_dataset(dir, class_to_idx):
images = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
Expand All @@ -36,93 +37,14 @@ def make_dataset(dir, class_to_idx, extensions):

for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):
if is_image_file(fname):
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item)

return images


class DatasetFolder(data.Dataset):
    """Generic dataset over sample files grouped into one folder per class: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g., ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        # Discover the classes first, then collect every matching file under them.
        class_names, index_by_class = find_classes(root)
        found = make_dataset(root, index_by_class, extensions)
        if len(found) == 0:
            # An empty dataset is treated as a configuration error.
            message = ("Found 0 files in subfolders of: " + root + "\n"
                       "Supported extensions are: " + ",".join(extensions))
            raise RuntimeError(message)

        # Construction arguments, kept for later use by __getitem__ / __repr__.
        self.root = root
        self.loader = loader
        self.extensions = extensions

        # Results of the directory scan.
        self.classes = class_names
        self.class_to_idx = index_by_class
        self.samples = found

        # Optional per-item hooks; None means "leave the value unchanged".
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load one sample and pair it with its class index.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        sample_path, label = self.samples[index]
        loaded = self.loader(sample_path)

        sample_hook = self.transform
        if sample_hook is not None:
            loaded = sample_hook(loaded)

        label_hook = self.target_transform
        if label_hook is not None:
            label = label_hook(label)

        return loaded, label

    def __len__(self):
        """Return the number of (sample path, class_index) entries."""
        entries = self.samples
        return len(entries)

    def __repr__(self):
        """Build a multi-line summary: class name, size, root and both transforms."""

        def _labelled(label, obj):
            # Continuation lines of a multi-line repr are padded to sit under
            # the first line, i.e. by the width of the label.
            pad = ' ' * len(label)
            return label + obj.__repr__().replace('\n', '\n' + pad)

        parts = [
            'Dataset ' + self.__class__.__name__,
            ' Number of datapoints: {}'.format(self.__len__()),
            ' Root Location: {}'.format(self.root),
            _labelled(' Transforms (if any): ', self.transform),
            _labelled(' Target Transforms (if any): ', self.target_transform),
        ]
        return '\n'.join(parts)


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
Expand All @@ -147,7 +69,7 @@ def default_loader(path):
return pil_loader(path)


class ImageFolder(DatasetFolder):
class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::

root/dog/xxx.png
Expand All @@ -171,9 +93,49 @@ class ImageFolder(DatasetFolder):
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""

def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs = self.samples
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.target_transform = target_transform
self.loader = loader

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.imgs)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str