/
train.py
executable file
·162 lines (133 loc) · 5.96 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""Train a Fast R-CNN network."""
import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os
from caffe.proto import caffe_pb2
import google.protobuf.text_format as pb2_text_format
class SolverWrapper(object):
"""A simple wrapper around Caffe's solver.
This wrapper gives us control over he snapshotting process, which we
use to unnormalize the learned bounding-box regression weights.
"""
def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
"""Initialize the SolverWrapper."""
self.output_dir = output_dir
if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
# RPN can only use precomputed normalization because there are no
# fixed statistics to compute a priori
assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
if cfg.TRAIN.BBOX_REG:
print 'Computing bounding-box regression targets...'
self.bbox_means, self.bbox_stds = \
rdl_roidb.add_bbox_regression_targets(roidb)
print 'done'
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print ('Loading pretrained model '
'weights from {:s}').format(pretrained_model)
self.solver.net.copy_from(pretrained_model)
self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2_text_format.Merge(f.read(), self.solver_param)
self.solver.net.layers[0].set_roidb(roidb)
def snapshot(self):
"""Take a snapshot of the network after unnormalizing the learned
bounding-box regression weights. This enables easy use at test-time.
"""
net = self.solver.net
scale_bbox_params = (cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
net.params.has_key('bbox_pred'))
if scale_bbox_params:
# save original values
orig_0 = net.params['bbox_pred'][0].data.copy()
orig_1 = net.params['bbox_pred'][1].data.copy()
# scale and shift with bbox reg unnormalization; then save snapshot
net.params['bbox_pred'][0].data[...] = \
(net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
net.params['bbox_pred'][1].data[...] = \
(net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
filename = (self.solver_param.snapshot_prefix + infix +
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
filename = os.path.join(self.output_dir, filename)
net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename)
if scale_bbox_params:
# restore net to original state
net.params['bbox_pred'][0].data[...] = orig_0
net.params['bbox_pred'][1].data[...] = orig_1
return filename
def train_model(self, max_iters):
"""Network training loop."""
last_snapshot_iter = -1
timer = Timer()
model_paths = []
while self.solver.iter < max_iters:
# Make one SGD update
timer.tic()
self.solver.step(1)
timer.toc()
if self.solver.iter % (10 * self.solver_param.display) == 0:
print 'speed: {:.3f}s / iter'.format(timer.average_time)
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
model_paths.append(self.snapshot())
if last_snapshot_iter != self.solver.iter:
model_paths.append(self.snapshot())
return model_paths
def get_training_roidb(imdb):
"""Returns a roidb (Region of Interest database) for use in training."""
if cfg.TRAIN.USE_FLIPPED:
print 'Appending horizontally-flipped training examples...'
imdb.append_flipped_images()
print 'done'
print 'Preparing training data...'
rdl_roidb.prepare_roidb(imdb)
print 'done'
return imdb.roidb
def filter_roidb(roidb):
"""Remove roidb entries that have no usable RoIs."""
def is_valid(entry):
# Valid images have:
# (1) At least one foreground RoI OR
# (2) At least one background RoI
overlaps = entry['max_overlaps']
# find boxes with sufficient overlap
fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
# Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
(overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
# image is only valid if such boxes exist
valid = len(fg_inds) > 0 or len(bg_inds) > 0
return valid
num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
num_after = len(filtered_roidb)
print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
num, num_after)
return filtered_roidb
def train_net(solver_prototxt, roidb, output_dir,
pretrained_model=None, max_iters=40000):
"""Train a Fast R-CNN network."""
roidb = filter_roidb(roidb)
sw = SolverWrapper(solver_prototxt, roidb, output_dir,
pretrained_model=pretrained_model)
print 'Solving...'
model_paths = sw.train_model(max_iters)
print 'done solving'
return model_paths