Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ def set_image_backend(backend):
"""
Specifies the package used to load images.

Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
Intel IPP library. It is generally faster than PIL, but does not support as
many operations.

Args:
backend (string): name of the image backend
backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
The :mod:`accimage` package uses the Intel IPP library. It is
generally faster than PIL, but does not support as many operations.
"""
global _image_backend
if backend not in ['PIL', 'accimage']:
Expand Down
23 changes: 23 additions & 0 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@


class CIFAR10(data.Dataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

"""
base_folder = 'cifar-10-batches-py'
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
Expand Down Expand Up @@ -86,6 +102,13 @@ def __init__(self, root, train=True,
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
Expand Down
60 changes: 60 additions & 0 deletions torchvision/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,43 @@


class CocoCaptions(data.Dataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.

Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.

Example:

.. code:: python

import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)

Output: ::

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']

"""
def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO
self.root = root
Expand All @@ -15,6 +51,13 @@ def __init__(self, root, annFile, transform=None, target_transform=None):
self.target_transform = target_transform

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: Tuple (image, target). target is a list of captions for the image.
"""
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
Expand All @@ -37,6 +80,16 @@ def __len__(self):


class CocoDetection(data.Dataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO
Expand All @@ -47,6 +100,13 @@ def __init__(self, root, annFile, transform=None, target_transform=None):
self.target_transform = target_transform

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
"""
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
Expand Down
30 changes: 30 additions & 0 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,29 @@ def default_loader(path):


class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.

Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""

def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
Expand All @@ -81,6 +104,13 @@ def __init__(self, root, transform=None, target_transform=None,
self.loader = loader

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
Expand Down
20 changes: 17 additions & 3 deletions torchvision/datasets/lsun.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class LSUNClass(data.Dataset):

def __init__(self, db_path, transform=None, target_transform=None):
import lmdb
self.db_path = db_path
Expand Down Expand Up @@ -58,8 +57,16 @@ def __repr__(self):

class LSUN(data.Dataset):
"""
db_path = root directory for the database files
classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.

Args:
db_path (string): Root directory for the database files.
classes (string or list): One of {'train', 'val', 'test'} or a list of
categories to load. e,g. ['bedroom_train', 'church_train'].
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

def __init__(self, db_path, classes='train',
Expand Down Expand Up @@ -108,6 +115,13 @@ def __init__(self, db_path, classes='train',
self.target_transform = target_transform

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: Tuple (image, target) where target is the index of the target category.
"""
target = 0
sub = 0
for ind in self.indices:
Expand Down
25 changes: 23 additions & 2 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,25 @@
import os.path
import errno
import torch
import json
import codecs
import numpy as np


class MNIST(data.Dataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

Args:
root (string): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
Expand Down Expand Up @@ -42,6 +55,13 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
Expand Down Expand Up @@ -70,6 +90,7 @@ def _check_exists(self):
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip

Expand Down
20 changes: 20 additions & 0 deletions torchvision/datasets/phototour.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@


class PhotoTour(data.Dataset):
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.


Args:
root (string): Root directory where images are.
name (string): Name of the dataset to load.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

"""
urls = {
'notredame': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip',
Expand Down Expand Up @@ -59,6 +72,13 @@ def __init__(self, root, name, train=True, transform=None, download=False):
self.data, self.labels, self.matches = torch.load(self.data_file)

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (data1, data2, matches)
"""
if self.train:
data = self.data[index]
if self.transform is not None:
Expand Down
23 changes: 23 additions & 0 deletions torchvision/datasets/stl10.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,22 @@


class STL10(CIFAR10):
"""`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``stl10_binary`` exists.
split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
Accordingly dataset is selected.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

"""
base_folder = 'stl10_binary'
url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
filename = "stl10_binary.tar.gz"
Expand Down Expand Up @@ -67,6 +83,13 @@ def __init__(self, root, split='train',
self.classes = f.read().splitlines()

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.labels is not None:
img, target = self.data[index], int(self.labels[index])
else:
Expand Down
25 changes: 23 additions & 2 deletions torchvision/datasets/svhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
from .utils import download_url, check_integrity


class SVHN(data.Dataset):
"""`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``SVHN`` exists.
split (string): One of {'train', 'test', 'extra'}.
Accordingly dataset is selected. 'extra' is Extra training set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

"""
url = ""
filename = ""
file_md5 = ""
Expand Down Expand Up @@ -56,6 +70,13 @@ def __init__(self, root, split='train',
self.data = np.transpose(self.data, (3, 2, 0, 1))

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.labels[index]

# doing this so that it is consistent with all other datasets
Expand Down
Loading