Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OMNIGLOT Dataset #46

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The following dataset loaders are available:
- `ImageFolder <#imagefolder>`__
- `Imagenet-12 <#imagenet-12>`__
- `CIFAR10 and CIFAR100 <#cifar>`__

- `OMNIGLOT <#omniglot>`__
Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass
from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded
(python multiprocessing) using standard torch.utils.data.DataLoader.
Expand Down Expand Up @@ -187,6 +187,16 @@ here <https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#downloa
`Here is an
example <https://github.com/pytorch/examples/blob/27e2a46c1d1505324032b1d94fc6ce24d5b67e97/imagenet/main.py#L48-L62>`__.

OMNIGLOT
~~~~~~~~

``dset.OMNIGLOT(root, [transform=None, target_transform=None, download=False])``

The dataset is composed of pairs: ``(filename, category_idx)``. Each category corresponds to one character in one alphabet. The mapping from class names (``alphabet/character``) to class indices can be accessed through ``dataset.idx_classes``.
The dataset can be used with ``transform=transforms.FilenameToPILImage()`` to obtain pairs of ``(PIL Image, category_idx)``.

From: `Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.`

Models
======

Expand Down
13 changes: 13 additions & 0 deletions test/test_omniglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

print('Omniglot')
# Each item starts as an image file path: open it as a PIL image, then
# convert that image to a tensor.
pipeline = transforms.Compose([
    transforms.FilenameToPILImage(),
    transforms.ToTensor(),
])
# download=True fetches and unpacks the archives under ../data if they are
# not already there.
omniglot = dset.OMNIGLOT("../data", download=True, transform=pipeline)

# Show the class-name -> index mapping, then one sample item.
print(omniglot.idx_classes)
print(omniglot[3])
# print('\n\nCifar 100')
# a = dset.CIFAR100(root="abc/def/ghi", download=True)

# print(a[3])
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .mnist import MNIST
from .omniglot import OMNIGLOT

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100',
'MNIST')
'MNIST','OMNIGLOT')
115 changes: 115 additions & 0 deletions torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import torch
import json
import codecs
import numpy as np
from PIL import Image

class OMNIGLOT(data.Dataset):
    """Omniglot dataset of handwritten characters.

    The items are (filename, category). ``filename`` is the path of a PNG
    image (or whatever ``transform`` returns for that path) and ``category``
    is the integer index of the image's class (or whatever
    ``target_transform`` returns for it). The index of all the categories
    can be found in ``self.idx_classes``.

    Args:

    - root: the directory where the dataset will be stored
    - transform: how to transform the input
    - target_transform: how to transform the target
    - download: need to download the dataset

    Raises:

    - RuntimeError: if the extracted image folders are not found under
      ``root`` (after the optional download).
    """

    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               + ' You can use download=True to download it')

        # One (filename, class name, directory) tuple per image found.
        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        # Class name -> integer index.
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        Without a ``transform`` the image is the path of the image file;
        without a ``target_transform`` the target is the class index.
        """
        filename, class_name, directory = self.all_items[index][:3]
        img = os.path.join(directory, filename)
        target = self.idx_classes[class_name]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.all_items)

    def _check_exists(self):
        """Tell whether both extracted image folders are present under root."""
        processed = os.path.join(self.root, self.processed_folder)
        return os.path.exists(os.path.join(processed, "images_evaluation")) and \
            os.path.exists(os.path.join(processed, "images_background"))

    def download(self):
        """Download the archives into the raw folder and unzip them.

        Does nothing when the extracted folders already exist.
        """
        from six.moves import urllib
        import zipfile

        if self._check_exists():
            return

        # Create each folder on its own: with a single shared try block, an
        # already existing raw folder would skip creating the processed one.
        for folder in (self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        file_processed = os.path.join(self.root, self.processed_folder)
        for url in self.urls:
            print('== Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            # Close the connection even if reading or writing fails.
            response = urllib.request.urlopen(url)
            try:
                with open(file_path, 'wb') as f:
                    f.write(response.read())
            finally:
                response.close()
            print("== Unzip from " + file_path + " to " + file_processed)
            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(file_processed)
        print("Download finished.")

def find_classes(root_dir):
    """List every PNG file found under ``root_dir``.

    Args:

    - root_dir: the directory to search recursively

    Returns a list of ``(filename, class_name, directory)`` tuples, where
    ``directory`` is the folder containing the file and ``class_name`` is the
    last two components of that folder joined with ``/``. Directories and
    files are visited in sorted order, so the result (and any class indexing
    derived from it) is the same on every run and every machine. A missing
    ``root_dir`` yields an empty list.
    """
    found = []
    for (root, dirs, files) in os.walk(root_dir):
        # Sorting in place makes os.walk descend in a stable order; its
        # default order depends on the filesystem.
        dirs.sort()
        # Use os.path instead of splitting on '/' so other path separators
        # are handled too. The class name itself always uses '/'.
        parent, leaf = os.path.split(root)
        class_name = os.path.basename(parent) + "/" + leaf
        for f in sorted(files):
            if f.endswith("png"):
                found.append((f, class_name, root))
    print("== Found %d items " % len(found))
    return found

def index_classes(items):
    """Number the class names found in ``items``.

    Each item is a sequence whose second element is a class name. Class
    names receive consecutive integers starting at 0, in order of first
    appearance. Returns the resulting ``{class_name: index}`` dictionary.
    """
    class_to_index = {}
    for item in items:
        # setdefault leaves already-numbered classes untouched; a new class
        # gets the next free index (the current size of the mapping).
        class_to_index.setdefault(item[1], len(class_to_index))
    print("== Found %d classes"% len(class_to_index))
    return class_to_index
7 changes: 7 additions & 0 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def __call__(self, img):
img = t(img)
return img

class FilenameToPILImage(object):
    """Transform that opens an image file and returns it as an RGB PIL image."""

    def __call__(self, filename):
        opened = Image.open(filename)
        return opened.convert('RGB')

class ToTensor(object):
"""Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
Expand Down