-
Notifications
You must be signed in to change notification settings - Fork 3
/
ethzshapes.py
44 lines (35 loc) · 1.89 KB
/
ethzshapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
from vipy.util import remkdir, isjpg
from vipy.image import ImageDetection
import vipy.downloader
# Source archive for the ETHZ Shape Classes dataset (v1.2) and the SHA-1
# checksum passed to the downloader alongside it.
URL = 'http://www.vision.ee.ethz.ch/datasets_extra/ethz_shape_classes_v12.tgz'
SHA1 = 'ae9b8fad2d170e098e5126ea9181d0843505a84b'

# Folder under the user's datadir that holds the per-category image folders
# (presumably the top-level directory produced by unpacking the archive).
SUBDIR = 'ETHZShapeClasses-V1.2'

# Category names; each one is also used as the folder name of that
# category's images (and ground-truth files) under SUBDIR.
LABELS = [
    'Applelogos',
    'Bottles',
    'Giraffes',
    'Mugs',
    'Swans',
]
class ETHZShapes(object):
    """Loader for the ETHZ Shape Classes (v1.2) dataset.

    Images and ground-truth files are expected under
    ``<datadir>/<SUBDIR>/<category>`` for each category in LABELS.
    """

    def __init__(self, datadir):
        """ETHZShapes, provide a datadir='/path/to/store/ethzshapes'.

        The directory is passed through vipy.util.remkdir (presumably
        creating it if it does not exist yet) and stored as self.datadir.
        """
        self.datadir = remkdir(datadir)

    def __repr__(self):
        return '<vipy.data.ethzshapes: "%s">' % self.datadir

    def download_and_unpack(self):
        """Fetch the dataset archive into datadir.

        Delegates to vipy.downloader.download_and_unpack with the module's
        archive URL and expected SHA-1 checksum.
        """
        vipy.downloader.download_and_unpack(URL, self.datadir, sha1=SHA1)

    def dataset(self):
        """Return a list of ImageDetection objects, one per ground-truth box.

        For every category in LABELS, each non-hidden JPEG in
        ``<datadir>/<SUBDIR>/<category>`` is paired with its ground-truth file
        (see _groundtruth_path). Each non-blank line of that file holds
        ``xmin ymin xmax ymax`` and yields one ImageDetection labelled with
        the category. Images are visited in sorted filename order, so the
        result is deterministic.

        Raises:
            OSError: if a category directory or a ground-truth file is missing.
            ValueError: if a non-blank ground-truth line does not contain
                exactly four numeric fields.
        """
        imlist = []
        for category in LABELS:
            imdir = os.path.join(self.datadir, SUBDIR, category)
            # os.listdir returns entries in arbitrary order; sort so the
            # dataset order does not depend on the filesystem.
            for filename in sorted(os.listdir(imdir)):
                if not isjpg(filename) or filename.startswith('.'):
                    continue  # skip non-JPEGs and hidden files
                im = os.path.join(imdir, filename)
                gtfile = self._groundtruth_path(imdir, filename, category)
                for (xmin, ymin, xmax, ymax) in self._read_boxes(gtfile):
                    imlist.append(ImageDetection(filename=im, category=category,
                                                 xmin=xmin, ymin=ymin,
                                                 xmax=xmax, ymax=ymax))
        return imlist

    @staticmethod
    def _groundtruth_path(imdir, filename, category):
        """Return the ground-truth file path for image `filename` in `imdir`.

        Prefers ``<stem>_<category lowercased>.groundtruth``. If that file
        does not exist, it falls back to
        ``<stem>_<category lowercased>s.groundtruth`` without checking that
        the fallback exists.
        """
        stem = os.path.splitext(filename)[0] + '_' + category.lower()
        gtfile = os.path.join(imdir, stem + '.groundtruth')
        if not os.path.isfile(gtfile):
            # Plural hack: presumably some ground-truth files carry an extra
            # trailing 's' on the category name.
            gtfile = os.path.join(imdir, stem + 's.groundtruth')
        return gtfile

    @staticmethod
    def _read_boxes(gtfile):
        """Parse a ground-truth file into (xmin, ymin, xmax, ymax) float tuples.

        Blank or whitespace-only lines are skipped. The file is closed before
        returning, even if parsing raises.
        """
        boxes = []
        with open(gtfile, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                (xmin, ymin, xmax, ymax) = fields
                boxes.append((float(xmin), float(ymin), float(xmax), float(ymax)))
        return boxes