-
Notifications
You must be signed in to change notification settings - Fork 3
/
cifar.py
115 lines (84 loc) · 4.33 KB
/
cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import numpy as np
import vipy
import pickle
# Default archive URL and MD5 string for CIFAR-10; used as the `url` and `md5`
# defaults of CIFAR10.__init__, which forwards both to vipy.downloader.download_and_unpack.
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
# Default archive URL and MD5 string for CIFAR-100; used the same way by CIFAR100.__init__.
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
class CIFAR10():
    """Loader that exposes the CIFAR-10 python batches as vipy datasets.

    On construction the archive is downloaded and unpacked into `outdir` when the
    first training batch is not already present, the class names are read from
    `batches.meta`, and every train and test image is loaded into memory.

    Example:

    >>> D = vipy.data.cifar.CIFAR10('/path/to/outdir')
    >>> d = D.trainset()
    >>> im = d[0].mindim(512).show()

    """

    def __init__(self, outdir, name='cifar10', url=CIFAR10_URL, md5=CIFAR10_MD5):
        """Prepare the data directory and load all batches.

        Args:
            outdir: directory that holds (or will hold) the unpacked archive.
            name: label used in `__repr__` and as the prefix of the dataset ids.
            url: archive location handed to the downloader.
            md5: checksum string handed to the downloader together with `url`.

        Raises:
            AssertionError: if an expected batch or metadata file is missing
                after the (optional) download step.
        """
        self._name = name
        self._datadir = vipy.util.remkdir(outdir)
        self._subdir = 'cifar-10-batches-py'

        # Every path is derived from self._datadir (the directory the archive is
        # unpacked into) so the existence check and the loaders agree on location.
        batchdir = os.path.join(self._datadir, self._subdir)

        if not os.path.exists(os.path.join(batchdir, 'data_batch_1')):
            print('[vipy.data.cifar10]: downloading CIFAR-10 to "%s"' % self._datadir)
            vipy.downloader.download_and_unpack(url, self._datadir, md5=md5)

        self._train_archives = [os.path.join(batchdir, 'data_batch_%d' % k) for k in range(1, 6)]
        self._test_archives = [os.path.join(batchdir, 'test_batch')]

        meta = self._unpickle(os.path.join(batchdir, 'batches.meta'))
        self._classes = [x.decode("utf-8") for x in meta[b'label_names']]

        self._trainset()
        self._testset()

    def __repr__(self):
        return '<vipy.data.%s: %s>' % (self._name, self._datadir)

    def classes(self):
        """Return the list of class name strings, indexed by integer label."""
        return self._classes

    def trainset(self):
        """Return a vipy.dataset.Dataset of ImageCategory objects for the training images, created with the id string '<name>_train'."""
        return self._as_dataset(self._trainset, self._trainlabels, '%s_train' % self._name)

    def testset(self):
        """Return a vipy.dataset.Dataset of ImageCategory objects for the test images, created with the id string '<name>_test'."""
        return self._as_dataset(self._testset, self._testlabels, '%s_test' % self._name)

    def _as_dataset(self, images, labels, dataset_id):
        """Wrap parallel image/label lists into a vipy dataset, mapping each integer label to its class name."""
        return vipy.dataset.Dataset([vipy.image.ImageCategory(category=self._classes[y], array=x, colorspace='rgb')
                                     for (x, y) in zip(images, labels)], dataset_id)

    @staticmethod
    def _unpickle(path):
        """Return the object pickled in the file at `path`; string keys come back as bytes because of encoding='bytes'."""
        assert os.path.exists(path), 'CIFAR file not found: "%s"' % path
        # NOTE(review): pickle.load can execute arbitrary code.  The md5 is only
        # handed to the downloader when the archive is fetched; files that already
        # exist on disk are unpickled as-is, so `outdir` must be trusted.
        with open(path, 'rb') as f:
            return pickle.load(f, encoding='bytes')

    def _load_archives(self, archives, labelkey):
        """Read the pickled batch files in `archives` and return `(images, labels)`.

        `images` is a list with one array per row of the stacked `b'data'`
        matrices, each row reshaped to (3, 32, 32) and transposed to (32, 32, 3).
        `labels` is the flat list of the entries stored under `labelkey`, in the
        same order as the images.
        """
        (rows, labels) = ([], [])
        for path in archives:
            batch = self._unpickle(path)
            rows.append(batch[b'data'])
            labels.extend(batch[labelkey])

        stacked = np.vstack(rows)
        # One reshape/transpose over the whole stack; per image this is the same
        # as reshape(3, 32, 32) followed by transpose(1, 2, 0).
        images = stacked.reshape(stacked.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)
        return (list(images), labels)

    def _trainset(self, labelkey=b'labels'):
        """Load the training batches into `self._trainset` / `self._trainlabels` and return self.

        NOTE: the assignment replaces this method on the instance with the image
        list (existing behaviour that `trainset()` and subclasses rely on), so it
        can only be called once per instance.
        """
        (self._trainset, self._trainlabels) = self._load_archives(self._train_archives, labelkey)
        return self

    def _testset(self, labelkey=b'labels'):
        """Load the test batches into `self._testset` / `self._testlabels` and return self.

        NOTE: like `_trainset`, the assignment replaces this method on the
        instance with the image list, so it can only be called once per instance.
        """
        (self._testset, self._testlabels) = self._load_archives(self._test_archives, labelkey)
        return self
class CIFAR100(CIFAR10):
    """Loader that exposes the CIFAR-100 python archive as vipy datasets.

    Reuses the CIFAR10 accessors and batch loaders, but has its own constructor
    (CIFAR10.__init__ is intentionally not called) because the archive uses a
    different sub-directory, different file names and different metadata keys.
    `classes()` returns the fine label names; the coarse label names are kept
    in `self._coarse_classes`.
    """

    def __init__(self, datadir, name='cifar100', url=CIFAR100_URL, md5=CIFAR100_MD5):
        """Prepare the data directory and load the train and test files.

        Args:
            datadir: directory that holds (or will hold) the unpacked archive.
            name: label used in `__repr__` and as the prefix of the dataset ids.
            url: archive location handed to the downloader.
            md5: checksum string handed to the downloader together with `url`.

        Raises:
            AssertionError: if the metadata file (or, in the inherited loaders,
                a data file) is missing after the (optional) download step.
        """
        self._name = name
        self._datadir = vipy.util.remkdir(datadir)
        self._subdir = 'cifar-100-python'

        # Every path is derived from self._datadir (the directory the archive is
        # unpacked into) so the existence check and the loaders agree on location.
        archivedir = os.path.join(self._datadir, self._subdir)

        if not os.path.exists(os.path.join(archivedir, 'train')):
            print('[vipy.data.cifar100]: downloading CIFAR-100 to "%s"' % self._datadir)
            vipy.downloader.download_and_unpack(url, self._datadir, md5=md5)

        self._train_archives = [os.path.join(archivedir, 'train')]
        self._test_archives = [os.path.join(archivedir, 'test')]

        metafile = os.path.join(archivedir, 'meta')
        assert os.path.exists(metafile), 'CIFAR-100 metadata file not found: "%s"' % metafile
        # NOTE(review): pickle.load can execute arbitrary code; files already on
        # disk are unpickled without any checksum, so `datadir` must be trusted.
        with open(metafile, 'rb') as fo:
            meta = pickle.load(fo, encoding='bytes')
        self._classes = [x.decode("utf-8") for x in meta[b'fine_label_names']]
        self._coarse_classes = [x.decode("utf-8") for x in meta[b'coarse_label_names']]

        self._trainset()
        self._testset()

    def _trainset(self):
        """Load the training file via the parent loader, reading labels from the b'fine_labels' key."""
        return super()._trainset(labelkey=b'fine_labels')

    def _testset(self):
        """Load the test file via the parent loader, reading labels from the b'fine_labels' key."""
        return super()._testset(labelkey=b'fine_labels')