Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ MNIST

.. autoclass:: MNIST

CORe50
~~~~~~

.. autoclass:: CORE50

Fashion-MNIST
~~~~~~~~~~~~~

Expand Down Expand Up @@ -102,7 +107,7 @@ STL10
:special-members:

SVHN
~~~~~
~~~~


.. autoclass:: SVHN
Expand Down
Empty file added test/test_core50.py
Empty file.
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .core50 import CORE50
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
Expand All @@ -9,7 +10,7 @@
from .fakedata import FakeData
from .semeion import SEMEION

__all__ = ('LSUN', 'LSUNClass',
__all__ = ('CORE50', 'LSUN', 'LSUNClass',
'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'FashionMNIST',
Expand Down
229 changes: 229 additions & 0 deletions torchvision/datasets/core50.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

################################################################################
# Date: 22-11-2017 #
# Author: Vincenzo Lomonaco #
# E-mail: vincenzo.lomonaco@unibo.it #
# Website: vincenzolomonaco.com #
################################################################################

""" CORe50 Dataset class """

from __future__ import print_function
import os
import os.path
import torch.utils.data as data
from .utils import download_url, check_integrity
from torchvision.datasets.folder import pil_loader


class CORE50(data.Dataset):
    """`CORE50 <https://vlomonaco.github.io/core50/>`_ Dataset, specifically
    designed for Continuous Learning and Robotic Vision applications.
    For more information and additional materials visit the official
    website `CORE50 <https://vlomonaco.github.io/core50/>`_.

    Args:
        root (string): Root directory of the dataset where the ``CORe50``
            dataset exists or should be downloaded.
        check_integrity (bool, optional): If True check the integrity of the
            Dataset before trying to load it.
        scenario (string, optional): One of the three scenarios of the CORe50
            benchmark ``ni``, ``nc`` or ``nic``.
        train (bool, optional): If True, creates the dataset from the training
            set, otherwise creates from test set.
        img_size (string, optional): One of the two img sizes available among
            ``128x128`` or ``350x350``.
        cumul (bool, optional): If True the cumulative scenario is assumed, the
            incremental scenario otherwise. Practically speaking ``cumul=True``
            means that for batch=i also batch=0,...i-1 will be added to the
            available training data.
        run (int, optional): One of the 10 runs (from 0 to 9) in which the
            training batch order is changed as in the official benchmark.
        batch (int, optional): One of the training incremental batches from 0 to
            max-batch - 1. Remember that for the ``ni``, ``nc`` and ``nic`` we
            have respectively 8, 9 and 79 incremental batches. If
            ``train=False`` this parameter will be ignored.
        transform (callable, optional): A function/transform that takes in an
            PIL image and returns a transformed version. E.g,
            ``transforms.ToTensor()``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.

    Raises:
        ValueError: If ``scenario``, ``img_size`` or ``batch`` is not one of
            the supported values.
        RuntimeError: If the dataset is missing or corrupted and
            ``download=False``.

    Example:

        .. code:: python

            training_data = datasets.CORE50(
                '~/data/core50', transform=transforms.ToTensor(), download=True
            )
            training_loader = torch.utils.data.DataLoader(
                training_data, batch_size=128, shuffle=True, num_workers=4
            )
            test_data = datasets.CORE50(
                '~/data/core50', transform=transforms.ToTensor(), train=False,
                download=True
            )
            test_loader = torch.utils.data.DataLoader(
                training_data, batch_size=128, shuffle=True, num_workers=4
            )

            for batch in training_loader:
                imgs, labels = batch
                ...

        This is the simplest way of using the Dataset with the common Train/Test
        split. If you want to use the benchmark as in the original CORe50 paper
        (that is for continuous learning) you need to play with the parameters
        ``scenario``, ``cumul``, ``run`` and ``batch`` hence creating a number
        of Dataset objects (one for each incremental training batch and one for
        the test set).

    """
    # Number of incremental training batches for each benchmark scenario;
    # used to validate the ``batch`` constructor argument.
    ntrain_batch = {
        'ni': 8,
        'nc': 9,
        'nic': 79
    }
    urls = {
        '128x128': 'http://bias.csr.unibo.it/maltoni/download/core50/'
                   'core50_128x128.zip',
        '350x350': 'http://bias.csr.unibo.it/maltoni/download/core50/'
                   'core50_350x350.zip',
        'filelists': 'https://vlomonaco.github.io/core50/data/'
                     'batches_filelists.zip'
    }
    filenames = {
        '128x128': 'core50_128x128.zip',
        '350x350': 'core50_350x350.zip',
        'filelists': 'batches_filelists.zip'
    }
    md5s = {
        'core50_128x128.zip': '745f3373fed08d69343f1058ee559e13',
        'core50_350x350.zip': 'e304258739d6cd4b47e19adfa08e7571',
        'batches_filelists.zip': 'e3297508a8998ba0c99a83d6b36bde62'
    }

    def __init__(self, root, check_integrity=True, scenario='ni', train=True,
                 img_size='128x128', run=0, batch=7, cumul=True, transform=None,
                 target_transform=None, download=False):

        # Fail fast on arguments that would otherwise surface as obscure
        # KeyErrors / missing-file errors later on.
        if scenario not in self.ntrain_batch:
            raise ValueError("scenario must be one of {}, got {!r}"
                             .format(sorted(self.ntrain_batch), scenario))
        if img_size not in ('128x128', '350x350'):
            raise ValueError("img_size must be '128x128' or '350x350', "
                             "got {!r}".format(img_size))
        if train and not 0 <= batch < self.ntrain_batch[scenario]:
            raise ValueError("batch must be in [0, {}) for scenario {!r}, "
                             "got {}".format(self.ntrain_batch[scenario],
                                             scenario, batch))

        self.root = os.path.expanduser(root)
        self.img_size = img_size
        self.scenario = scenario
        self.run = run
        self.batch = batch
        self.transform = transform
        self.target_transform = target_transform

        # To be filled while parsing the filelist below.
        self.fpath = None
        self.img_paths = []
        self.labels = []

        # Downloading files if needed
        if download:
            self.download()

        if check_integrity:
            print("Making sure CORe50 exists and it's not corrupted...")
            if not self._check_integrity():
                raise RuntimeError('Dataset not found or corrupted.' +
                                   ' You can use download=True to download it')

        suffix = 'cum' if cumul else 'inc'

        if train:
            self.fpath = os.path.join(
                '{}_{}'.format(scenario.upper(), suffix),
                'run{}'.format(run),
                'train_batch_{:02d}_filelist.txt'.format(batch)
            )
        else:
            # There is a single test filelist per run (the last batch).
            self.fpath = os.path.join(
                '{}_{}'.format(scenario.upper(), suffix),
                'run{}'.format(run),
                'test_filelist.txt'
            )

        # Loading the filelist: one "<relative img path> <label>" pair per
        # non-empty line. ``[:-4]`` strips the '.zip' extension to get the
        # extracted directory name.
        path = os.path.join(self.root, self.filenames['filelists'][:-4],
                            self.fpath)
        with open(path, 'r') as f:
            for line in f:
                if line.strip():
                    img_path, label = line.split()
                    self.labels.append(int(label))
                    self.img_paths.append(img_path)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """

        fpath = self.img_paths[index]
        target = self.labels[index]
        # ``[:-4]`` strips '.zip' to get the extracted image directory.
        img = pil_loader(
            os.path.join(self.root, self.filenames[self.img_size][:-4], fpath)
        )

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images listed in the loaded filelist."""
        return len(self.labels)

    def _check_integrity(self):
        """Check md5 of the downloaded archives actually needed.

        Only the archive matching ``self.img_size`` plus the filelists archive
        are verified. Note: unlike a previous version, this does NOT mutate the
        class-level ``md5s`` dict (doing so broke repeated calls and any second
        instance using the other image size).
        """
        needed = (self.filenames[self.img_size], self.filenames['filelists'])
        for filename in needed:
            fpath = os.path.join(self.root, filename)
            if not check_integrity(fpath, self.md5s[filename]):
                return False
        return True

    def download(self):
        """Download and extract the CORe50 archives if not already present."""
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root

        # Downloading the dataset and filelists
        for name in (self.img_size, 'filelists'):
            filename = self.filenames[name]
            download_url(self.urls[name], root, filename, self.md5s[filename])

            # Extract into root directly; no need to chdir, and the context
            # manager guarantees the archive handle is closed on error.
            with zipfile.ZipFile(os.path.join(root, filename), 'r') as archive:
                archive.extractall(root)
20 changes: 18 additions & 2 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os.path
import hashlib
import errno
import sys
import time


def check_integrity(fpath, md5):
Expand All @@ -18,6 +20,20 @@ def check_integrity(fpath, md5):
return True


def reporthook(count, block_size, total_size):
    """Progress callback for ``urllib.request.urlretrieve``.

    Prints (and overwrites, via ``\\r``) a one-line progress report with
    percentage, downloaded megabytes, speed and elapsed time.

    Args:
        count (int): number of blocks transferred so far (0 on the first call).
        block_size (int): size of each block in bytes.
        total_size (int): total download size in bytes; may be <= 0 when the
            server does not report a Content-Length.
    """
    global start_time
    # Initialize the timer on the first callback. Guard against callers that
    # never invoke us with count == 0, which previously raised NameError.
    if count == 0 or 'start_time' not in globals():
        start_time = time.time()
        if count == 0:
            return
    duration = time.time() - start_time
    if duration <= 0:
        # First blocks can arrive within the clock resolution; avoid
        # ZeroDivisionError in the speed computation.
        duration = 1e-6
    progress_size = int(count * block_size)
    speed = int(progress_size / (1024 * duration))
    if total_size > 0:
        percent = min(int(count * block_size * 100 / total_size), 100)
    else:
        percent = 0  # unknown total size
    sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
                     (percent, progress_size / (1024 * 1024), speed, duration))
    sys.stdout.flush()


def download_url(url, root, filename, md5):
from six.moves import urllib

Expand All @@ -38,10 +54,10 @@ def download_url(url, root, filename, md5):
else:
try:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(url, fpath)
urllib.request.urlretrieve(url, fpath, reporthook)
except:
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(url, fpath)
urllib.request.urlretrieve(url, fpath, reporthook)