From 1ad594bee2e2a1da9c4b9d6ab1846e325289f5ef Mon Sep 17 00:00:00 2001 From: neoglez Date: Sat, 28 Oct 2017 14:06:31 +0200 Subject: [PATCH 1/4] First commit for semeion dataset SEMEION Handwritten Digits Data Set http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit --- torchvision/datasets/semeion.py | 157 ++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 torchvision/datasets/semeion.py diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py new file mode 100644 index 00000000000..86f5519d713 --- /dev/null +++ b/torchvision/datasets/semeion.py @@ -0,0 +1,157 @@ +from __future__ import print_function +from PIL import Image +import os +import os.path +import errno +import numpy as np +import sys +if sys.version_info[0] == 2: + import cPickle as pickle +else: + import pickle + +import torch.utils.data as data +from .utils import download_url, check_integrity + + +class SEMEION(data.Dataset): + """`SEMEION `_ Dataset. + Args: + root (string): Root directory of dataset where directory + ``semeion.py`` exists. + train (bool, optional): If True, creates dataset from training set, otherwise + creates from test set. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + base_folder = 'semeion-py' + url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data" + filename = "semeion.data" + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ] + + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'], + ] + + def __init__(self, root, train=True, + transform=None, target_transform=None, + download=False): + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + self.train = train # training set or test set + + if download: + self.download() + + #if not self._check_integrity(): + #raise RuntimeError('Dataset not found or corrupted.' + + #' You can use download=True to download it') + + # now load the picked numpy arrays + if self.train: + self.train_data = [] + self.train_labels = [] + for fentry in self.train_list: + f = fentry[0] + file = os.path.join(self.root, self.base_folder, f) + fo = open(file, 'rb') + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') + self.train_data.append(entry['data']) + if 'labels' in entry: + self.train_labels += entry['labels'] + else: + self.train_labels += entry['fine_labels'] + fo.close() + + self.train_data = np.concatenate(self.train_data) + self.train_data = self.train_data.reshape((50000, 3, 32, 32)) + self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC + else: + f = self.test_list[0][0] + file = os.path.join(self.root, self.base_folder, f) + fo = open(file, 'rb') + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') + self.test_data = entry['data'] + if 'labels' in entry: + self.test_labels = entry['labels'] + else: + self.test_labels = entry['fine_labels'] + fo.close() + self.test_data = self.test_data.reshape((10000, 3, 32, 32)) + self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is index of the target class. + """ + if self.train: + img, target = self.train_data[index], self.train_labels[index] + else: + img, target = self.test_data[index], self.test_labels[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + if self.train: + return len(self.train_data) + else: + return len(self.test_data) + + def _check_integrity(self): + root = self.root + for fentry in (self.train_list + self.test_list): + filename, md5 = fentry[0], fentry[1] + fpath = os.path.join(root, self.base_folder, filename) + if not check_integrity(fpath, md5): + return False + return True + + def download(self): + import tarfile + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + root = self.root + download_url(self.url, root, self.filename, self.tgz_md5) + + # extract file + cwd = os.getcwd() + tar = tarfile.open(os.path.join(root, self.filename), "r:gz") + os.chdir(root) + tar.extractall() + tar.close() + os.chdir(cwd) From 1a69cb2ad93ce816a696562b24e0f4e7f3f20eb2 Mon Sep 17 00:00:00 2001 From: neoglez Date: Mon, 6 Nov 2017 14:43:38 +0100 Subject: [PATCH 2/4] SEMEION Class --- torchvision/datasets/__init__.py | 3 +- torchvision/datasets/semeion.py | 144 ++++++++++++------------------- 2 files changed, 59 insertions(+), 88 deletions(-) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 1edbc49d88f..9fab55190cc 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -7,9 +7,10 @@ from .svhn import SVHN from .phototour import PhotoTour from .fakedata import FakeData +from .semeion import SEMEION __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'FashionMNIST', - 'MNIST', 'STL10', 'SVHN', 'PhotoTour') + 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION') diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index 86f5519d713..c12b4a8e7e1 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -19,8 +19,6 @@ class SEMEION(data.Dataset): Args: root (string): Root directory of dataset where directory ``semeion.py`` exists. - train (bool, optional): If True, creates dataset from training set, otherwise - creates from test set. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the @@ -29,75 +27,60 @@ class SEMEION(data.Dataset): puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ - base_folder = 'semeion-py' url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data" filename = "semeion.data" - tgz_md5 = 'c58f30108f718f92721af3b95e74349a' - train_list = [ - ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], - ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], - ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], - ['data_batch_4', '634d18415352ddfa80567beed471001a'], - ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], - ] - - test_list = [ - ['test_batch', '40351d587109b95175f43aff81a1287e'], - ] - - def __init__(self, root, train=True, - transform=None, target_transform=None, - download=False): + md5_checksum = 'cb545d371d2ce14ec121470795a77432' + + + def __init__(self, root, transform=None, target_transform=None, + download=True): self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform - self.train = train # training set or test set if download: self.download() - #if not self._check_integrity(): - #raise RuntimeError('Dataset not found or corrupted.' + - #' You can use download=True to download it') - - # now load the picked numpy arrays - if self.train: - self.train_data = [] - self.train_labels = [] - for fentry in self.train_list: - f = fentry[0] - file = os.path.join(self.root, self.base_folder, f) - fo = open(file, 'rb') - if sys.version_info[0] == 2: - entry = pickle.load(fo) - else: - entry = pickle.load(fo, encoding='latin1') - self.train_data.append(entry['data']) - if 'labels' in entry: - self.train_labels += entry['labels'] - else: - self.train_labels += entry['fine_labels'] - fo.close() - - self.train_data = np.concatenate(self.train_data) - self.train_data = self.train_data.reshape((50000, 3, 32, 32)) - self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC - else: - f = self.test_list[0][0] - file = os.path.join(self.root, self.base_folder, f) - fo = open(file, 'rb') - if sys.version_info[0] == 2: - entry = pickle.load(fo) - else: - entry = pickle.load(fo, encoding='latin1') - self.test_data = entry['data'] - if 'labels' in entry: - self.test_labels = entry['labels'] - else: - self.test_labels = entry['fine_labels'] - fo.close() - self.test_data = self.test_data.reshape((10000, 3, 32, 32)) - self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + self.data = [] + self.labels = [] + fp = os.path.join(root, self.filename) + file = open(fp, 'r') + data = file.read() + file.close() + dataSplitted = data.split("\n")[:-1] + datasetLength = len(dataSplitted) + i = 0 + while i < datasetLength: + # Get the 'i-th' row + strings = dataSplitted[i] + + # Split row into numbers(string), and avoid blank at the end + stringsSplitted = (strings[:-1]).split(" ") + + # Get data (which ends at column 256th), then in a numpy array. + rawData = stringsSplitted[:256] + dataFloat = [float(j) for j in rawData] + img = np.array(dataFloat[:16]) + j = 16 + k = 0 + while j < len(dataFloat): + temp = np.array(dataFloat[k:j]) + img = np.vstack((img,temp)) + + k = j + j += 16 + + self.data.append(img) + + # Get label and convert it into numbers, then in a numpy array. + labelString = stringsSplitted[256:] + labelInt = [int(i) for i in labelString] + self.labels.append(np.array(labelInt)) + i += 1 def __getitem__(self, index): """ @@ -106,14 +89,14 @@ def __getitem__(self, index): Returns: tuple: (image, target) where target is index of the target class. """ - if self.train: - img, target = self.train_data[index], self.train_labels[index] - else: - img, target = self.test_data[index], self.test_labels[index] + img, target = self.data[index], self.labels[index] # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img) + # convert value to 8 bit unsigned integer + # color (white #255) the pixels + img = img.astype('uint8')*255 + img = Image.fromarray(img, mode='L') if self.transform is not None: img = self.transform(img) @@ -124,34 +107,21 @@ def __getitem__(self, index): return img, target def __len__(self): - if self.train: - return len(self.train_data) - else: - return len(self.test_data) + return len(self.data) def _check_integrity(self): root = self.root - for fentry in (self.train_list + self.test_list): - filename, md5 = fentry[0], fentry[1] - fpath = os.path.join(root, self.base_folder, filename) - if not check_integrity(fpath, md5): - return False + fpath = os.path.join(root, self.filename) + if not check_integrity(fpath, self.md5_checksum): + return False return True def download(self): - import tarfile - if self._check_integrity(): print('Files already downloaded and verified') return root = self.root - download_url(self.url, root, self.filename, self.tgz_md5) - - # extract file - cwd = os.getcwd() - tar = tarfile.open(os.path.join(root, self.filename), "r:gz") - os.chdir(root) - tar.extractall() - tar.close() - os.chdir(cwd) + download_url(self.url, root, self.filename, self.md5_checksum) + + From 42b5869684f4f2e47c993397cfaa01cbd02543b7 Mon Sep 17 00:00:00 2001 From: neoglez Date: Mon, 6 Nov 2017 21:08:00 +0100 Subject: [PATCH 3/4] Fix Lint errors --- torchvision/datasets/semeion.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index c12b4a8e7e1..9d97388c523 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -32,8 +32,7 @@ class SEMEION(data.Dataset): md5_checksum = 'cb545d371d2ce14ec121470795a77432' - def __init__(self, root, transform=None, target_transform=None, - download=True): + def __init__(self, root, transform=None, target_transform=None, download=True): self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform @@ -56,11 +55,11 @@ def __init__(self, root, transform=None, target_transform=None, i = 0 while i < datasetLength: # Get the 'i-th' row - strings = dataSplitted[i] - + strings = dataSplitted[i] + # Split row into numbers(string), and avoid blank at the end stringsSplitted = (strings[:-1]).split(" ") - + # Get data (which ends at column 256th), then in a numpy array. rawData = stringsSplitted[:256] dataFloat = [float(j) for j in rawData] @@ -69,16 +68,16 @@ def __init__(self, root, transform=None, target_transform=None, k = 0 while j < len(dataFloat): temp = np.array(dataFloat[k:j]) - img = np.vstack((img,temp)) - - k = j + img = np.vstack((img, temp)) + + k = j j += 16 - + self.data.append(img) - + # Get label and convert it into numbers, then in a numpy array. labelString = stringsSplitted[256:] - labelInt = [int(i) for i in labelString] + labelInt = [int(index) for index in labelString] self.labels.append(np.array(labelInt)) i += 1 @@ -95,7 +94,7 @@ def __getitem__(self, index): # to return a PIL Image # convert value to 8 bit unsigned integer # color (white #255) the pixels - img = img.astype('uint8')*255 + img = img.astype('uint8') * 255 img = Image.fromarray(img, mode='L') if self.transform is not None: @@ -123,5 +122,3 @@ def download(self): root = self.root download_url(self.url, root, self.filename, self.md5_checksum) - - From e918a9ef0b82764321e910d073ce895daf4171e3 Mon Sep 17 00:00:00 2001 From: neoglez Date: Mon, 6 Nov 2017 21:52:36 +0100 Subject: [PATCH 4/4] More linting errors --- torchvision/datasets/semeion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index 9d97388c523..07592b64bae 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -31,7 +31,6 @@ class SEMEION(data.Dataset): filename = "semeion.data" md5_checksum = 'cb545d371d2ce14ec121470795a77432' - def __init__(self, root, transform=None, target_transform=None, download=True): self.root = os.path.expanduser(root) self.transform = transform @@ -55,7 +54,7 @@ def __init__(self, root, transform=None, target_transform=None, download=True): i = 0 while i < datasetLength: # Get the 'i-th' row - strings = dataSplitted[i] + strings = dataSplitted[i] # Split row into numbers(string), and avoid blank at the end stringsSplitted = (strings[:-1]).split(" ")