-
Notifications
You must be signed in to change notification settings - Fork 23
/
roidb.py
214 lines (182 loc) · 7.24 KB
/
roidb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""Transform a roidb into a trainable roidb by adding a bunch of metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datasets
import numpy as np
from model.utils.config import cfg
from datasets.factory import get_imdb
import PIL
import pdb
import collections
def prepare_roidb(imdb):
    """Enrich ``imdb.roidb`` in place with derived per-image training fields.

    For every index ``i`` in ``range(len(imdb.image_index))`` the entry
    ``imdb.roidb[i]`` receives:

    * ``'img_id'``       -- ``imdb.image_id_at(i)``
    * ``'image'``        -- ``imdb.image_path_at(i)``
    * ``'width'`` / ``'height'`` -- image size read with PIL; skipped when
      ``imdb.name`` starts with ``'coco'`` or ``'vg'``
    * ``'max_overlaps'`` -- row-wise maximum of the dense ``'gt_overlaps'``
    * ``'max_classes'``  -- column index (class) of that maximum

    Args:
        imdb: dataset object exposing ``roidb``, ``name``, ``num_images``,
            ``image_index``, ``image_id_at(i)`` and ``image_path_at(i)``.
            Each roidb entry must hold a ``'gt_overlaps'`` object with a
            ``toarray()`` method.

    Returns:
        None. The roidb entries are modified in place.

    Raises:
        AssertionError: if a row with zero maximum overlap is not assigned
            class 0, or a row with positive overlap is assigned class 0.
    """
    roidb = imdb.roidb

    # Datasets named coco*/vg* are not given width/height here.
    # NOTE(review): presumably those datasets already store the image size
    # in their roidb entries -- confirm against the dataset classes.
    needs_size = not imdb.name.startswith(('coco', 'vg'))
    if needs_size:
        sizes = []
        for i in range(imdb.num_images):
            # Close each image file as soon as its size has been read instead
            # of relying on garbage collection to release the handle.
            with PIL.Image.open(imdb.image_path_at(i)) as img:
                sizes.append(img.size)

    for i in range(len(imdb.image_index)):
        entry = roidb[i]
        entry['img_id'] = imdb.image_id_at(i)
        entry['image'] = imdb.image_path_at(i)
        if needs_size:
            entry['width'] = sizes[i][0]
            entry['height'] = sizes[i][1]

        # need gt_overlaps as a dense array for argmax
        gt_overlaps = entry['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        entry['max_classes'] = max_classes
        entry['max_overlaps'] = max_overlaps

        # sanity checks
        # max overlap of 0 => class should be zero (background)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0), \
            'zero-overlap ROI assigned to a non-background class'
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0), \
            'positive-overlap ROI assigned to the background class'
def update_keyvalue(rdb, idx):
    """Return a copy of ``rdb`` reduced to the single annotation ``idx``.

    Only the ``'gt_classes'`` and ``'boxes'`` values are replaced; every
    other key is carried over unchanged by ``rdb.copy()`` (a shallow copy,
    so those values are shared with ``rdb``). ``rdb`` itself is not modified.

    Args:
        rdb: roidb entry (mapping) holding ``'gt_classes'`` and ``'boxes'``
            as lists or NumPy arrays.
        idx: position of the annotation to keep.

    Returns:
        The new entry. A list value becomes the one-element list
        ``[value[idx]]``; an array value becomes ``np.array(value[idx])``
        with the original dtype. Values of any other type are left as-is.
    """
    r = rdb.copy()
    keys = ['gt_classes', 'boxes']
    for k in keys:
        if isinstance(r[k], list):
            r[k] = [rdb[k][idx]]
        elif isinstance(r[k], np.ndarray):
            # Index the stored array, not the key string: the previous
            # `rdb[k[idx]]` looked up a single character of the key name.
            # NOTE(review): unlike the list branch this yields the bare
            # element (one dimension fewer than the input) -- confirm that
            # callers do not expect a leading length-1 axis.
            r[k] = np.array(rdb[k][idx], dtype=r[k].dtype)
    return r
def _subset_roidb_entry(rdb, keep):
    """Build a new roidb entry from ``rdb`` holding only annotations ``keep``."""
    return {
        'boxes': np.array([rdb['boxes'][i] for i in keep], dtype=np.uint16),
        'gt_classes': np.array([rdb['gt_classes'][i] for i in keep],
                               dtype=np.int32),
        # kept as a plain list of per-annotation rows, not re-packed
        'gt_overlaps': [rdb['gt_overlaps'][i] for i in keep],
        'flipped': rdb['flipped'],
        'img_id': rdb['img_id'],
        'image': rdb['image'],
        'width': rdb['width'],
        'height': rdb['height'],
        'max_classes': np.array([rdb['max_classes'][i] for i in keep]),
        'need_crop': rdb['need_crop'],
        'max_overlaps': np.array([rdb['max_overlaps'][i] for i in keep],
                                 dtype=np.float32),
    }


def filter_class_roidb(roidb, shot, imdb, max_uncapped_class=15):
    """Limit the number of annotations kept per class across a roidb.

    Annotations whose class id is ``<= max_uncapped_class`` are always kept.
    For every other class at most ``shot`` annotations are kept, counted
    over the whole roidb in iteration order. Images left without any
    annotation are dropped.

    The roidb is treated as two halves of equal length: entry ``i`` of the
    first half is paired with entry ``i + len(roidb) // 2``, and the same
    annotation indices are selected from both.
    NOTE(review): this presumes flipped copies were appended after the
    originals; with an un-flipped roidb the second half would be paired
    incorrectly -- confirm against the caller.

    Args:
        roidb: list of roidb entries providing ``'boxes'``, ``'gt_classes'``,
            ``'gt_overlaps'``, ``'max_classes'``, ``'max_overlaps'``,
            ``'flipped'``, ``'img_id'``, ``'image'``, ``'width'``,
            ``'height'`` and ``'need_crop'``.
        shot: maximum number of annotations kept for each capped class.
        imdb: dataset object; only ``len(imdb.classes)`` is used, to
            pre-seed the per-class counters.
        max_uncapped_class: largest class id that is exempt from the
            ``shot`` limit. Defaults to 15, the previously hard-coded value.
            NOTE(review): presumably the last base class of a few-shot
            split -- confirm.

    Returns:
        A new list in which each kept image contributes two consecutive
        entries: the filtered first-half entry followed by its filtered
        second-half partner.
    """
    class_count = collections.defaultdict(int)
    for cls in range(1, len(imdb.classes)):
        class_count[cls] = 0

    new_roidb = []
    length = len(roidb) // 2  # consider the flipped
    for idx, rdb in enumerate(roidb[:length]):
        rdb_flipped = roidb[idx + length]

        # Indices of the annotations that survive the per-class cap.
        keep = []
        for i, cls_id in enumerate(rdb['gt_classes']):
            if cls_id <= max_uncapped_class or class_count[cls_id] < shot:
                keep.append(i)
                class_count[cls_id] += 1

        if keep:
            new_roidb.append(_subset_roidb_entry(rdb, keep))
            new_roidb.append(_subset_roidb_entry(rdb_flipped, keep))
    return new_roidb
def rank_roidb_ratio(roidb):
    """Rank roidb entries by their clipped width/height aspect ratio.

    Each entry's ratio ``width / height`` is clipped to ``[0.5, 2.0]``.
    As a side effect every entry receives a ``'need_crop'`` flag: 1 when the
    ratio had to be clipped, 0 otherwise.

    Args:
        roidb: sequence of entries providing ``'width'`` and ``'height'``.

    Returns:
        ``(ratio_list, ratio_index)`` where ``ratio_index`` is the
        ``np.argsort`` order of the clipped ratios and ``ratio_list`` holds
        the clipped ratios in that (ascending) order. ``ratio_list`` is
        always float64, including when every ratio was clipped or the
        roidb is empty.
    """
    ratio_large = 2.0  # largest ratio to preserve.
    ratio_small = 0.5  # smallest ratio to preserve.

    ratio_list = []
    for entry in roidb:
        ratio = entry['width'] / float(entry['height'])
        if ratio > ratio_large:
            entry['need_crop'] = 1
            ratio = ratio_large
        elif ratio < ratio_small:
            entry['need_crop'] = 1
            ratio = ratio_small
        else:
            entry['need_crop'] = 0
        ratio_list.append(ratio)

    # Pin the dtype: with an integer upper bound, a roidb in which every
    # image was clipped used to produce an integer array.
    ratio_list = np.array(ratio_list, dtype=np.float64)
    ratio_index = np.argsort(ratio_list)
    return ratio_list[ratio_index], ratio_index
def filter_roidb(roidb):
    """Remove, in place, every entry whose ``'boxes'`` is empty.

    Args:
        roidb: list of entries, each providing a sized ``'boxes'`` value.

    Returns:
        The same list object that was passed in, after filtering. The
        number of entries before and after is printed.
    """
    # filter the image without bounding box.
    print('before filtering, there are %d images...' % (len(roidb)))

    # One pass plus slice assignment keeps the caller's list object while
    # avoiding repeated `del roidb[i]`, which shifts the tail every time.
    roidb[:] = [entry for entry in roidb if len(entry['boxes']) != 0]

    print('after filtering, there are %d images...' % (len(roidb)))
    return roidb
def combined_roidb(imdb_names, training=True):
    """Load and merge the roidbs of one or more datasets.

    Args:
        imdb_names: dataset name, or several names joined with ``'+'``.
        training: when true, entries without boxes are removed with
            ``filter_roidb`` before ranking.

    Returns:
        ``(imdb, roidb, ratio_list, ratio_index)``. ``roidb`` is the first
        dataset's roidb extended with the others; ``ratio_list`` and
        ``ratio_index`` come from ``rank_roidb_ratio(roidb)``. For a single
        name ``imdb`` is ``get_imdb(imdb_names)``; for several names it is a
        generic ``datasets.imdb.imdb`` built with the classes of the second
        named dataset.
    """

    def _training_roidb_for(dataset):
        """Returns a roidb (Region of Interest database) for use in training."""
        if cfg.TRAIN.USE_FLIPPED:
            print('Appending horizontally-flipped training examples...')
            dataset.append_flipped_images()
            print('done')

        print('Preparing training data...')
        prepare_roidb(dataset)
        print('done')
        return dataset.roidb

    def _load_roidb(name):
        # Fetch the dataset, select its proposal source, then prepare it.
        dataset = get_imdb(name)
        print('Loaded dataset `{:s}` for training'.format(dataset.name))
        dataset.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        return _training_roidb_for(dataset)

    names = imdb_names.split('+')
    per_dataset = []
    for name in names:
        per_dataset.append(_load_roidb(name))

    # The first roidb is extended in place with all the others.
    roidb = per_dataset[0]
    if len(per_dataset) == 1:
        imdb = get_imdb(imdb_names)
    else:
        for extra in per_dataset[1:]:
            roidb.extend(extra)
        # The combined imdb takes its class list from the second dataset.
        second = get_imdb(names[1])
        imdb = datasets.imdb.imdb(imdb_names, second.classes)

    if training:
        roidb = filter_roidb(roidb)

    ratio_list, ratio_index = rank_roidb_ratio(roidb)
    return imdb, roidb, ratio_list, ratio_index