-
Notifications
You must be signed in to change notification settings - Fork 13
/
train_gui.py
457 lines (350 loc) · 20.6 KB
/
train_gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
import tensorflow as tf
import numpy as np
import copy
from PIL import Image
from tqdm import tqdm
from model_layers.models import UNET
from model_layers.model_RPN import RPN
from model_layers.anchor_size import anchor_size
from model_layers.rpn_target import RPNTarget
from model_layers.rpn_proposal import RPNProposal
from model_layers.rpn_loss import RPNLoss
from model_layers.seg_loss import segmentation_loss
from model_layers.marker_watershed import marker_watershed
from model_layers.compute_metrics import compute_metrics
from utils.load_data import load_data_train
from utils.tf_utils import optimizer_fun
from utils.anchors import generate_anchors_reference
from utils.generate_anchors import generate_anchors
from utils.test import generate_gt_boxes
from utils.normalization import whole_image_norm, foreground_norm
from utils.losses import smooth_l1_loss
from utils.image_vis import draw_rpn_bbox_pred, draw_gt_boxes, draw_top_nms_proposals, draw_rpn_bbox_targets, draw_rpn_bbox_pred_only
# inspired from https://github.com/tryolabs/luminoth/blob/master/luminoth/models/fasterrcnn/rpn_test.py
def train_NuSeT(self):
    """Train the full NuSeT model (U-Net backbone + region proposal network).

    Builds a TF1 graph (placeholders, U-Net, RPN, losses, metrics), then runs
    `num_epoch` passes over the training set, validating at the end of each
    epoch. The best checkpoint over the last 10 epochs (by validation mean IU)
    is saved to ./Network/whole_norm.ckpt ('wn' normalization) or
    ./Network/foreground.ckpt ('fg' normalization).

    Args:
        self: GUI/CLI driver object. Reads hyperparameters from self.params
            ('lr', 'optimizer', 'epochs', 'min_score', 'nms_threshold',
            'normalization_method') and, when not self.usingCL, updates the
            Tk widgets self.training_results / self.train_progress_var /
            self.window. For 'wn' runs it also stores the cached validation
            predictions of the best epoch in self.whole_norm_y_pred.

    Returns:
        None. Progress/metrics are printed or pushed to the GUI; the model is
        persisted via tf.train.Saver.
    """
    # Get the training parameters
    learning_rate = self.params['lr']
    optimizer = self.params['optimizer']
    num_epoch = self.params['epochs']
    bbox_min_score = self.params['min_score']
    nms_thresh = self.params['nms_threshold']
    normalization_method = self.params['normalization_method']

    # Load the data
    # x_train, y_train: training images and corresponding labels
    # x_val, y_val: validation images and corresponding labels
    # w_train, w_val: training and validation weight matrices for U-Net
    # bbox_train, bbox_val: bounding box coordinates for train and validation dataset
    x_train, x_val, y_train, y_val, w_train, w_val, bbox_train, bbox_val = load_data_train(self, normalization_method)

    # pred_dict and pred_dict_final save all the temp variables
    pred_dict_final = {}

    # Tensor placeholders; batch size is fixed at 1, spatial dims are dynamic
    # so images of different sizes can be fed one at a time.
    train_initial = tf.placeholder(dtype=tf.float32, shape=[1, None, None, 1])
    labels = tf.placeholder(dtype=tf.float32, shape=[1, None, None, 1])
    # tensor placeholder for weight matrices and ground truth bounding boxes
    # (each gt box row has 5 values — presumably [x1, y1, x2, y2, class]; TODO confirm)
    edge_weights = tf.placeholder(dtype=tf.float32, shape=[1, None, None, 1])
    gt_boxes = tf.placeholder(dtype=tf.float32, shape=[None, 5])

    input_shape = tf.shape(train_initial)
    input_height = input_shape[1]
    input_width = input_shape[2]
    im_shape = tf.cast([input_height, input_width], tf.float32)

    # number of classes needed to be classified, for our case this equals to 2
    # (foreground and background)
    nb_classes = 2

    # feed the initial image to U-Net, we expect 2 outputs:
    # 1. feat_map of shape (1,input_height/16,input_width/16,1024), which will be passed to the
    # region proposal network
    # 2. final_logits of shape(1,input_height,input_width,2), which is the prediction from U-net
    with tf.variable_scope('model_U-Net') as scope:
        final_logits, feat_map = UNET(nb_classes, train_initial)

    # The final_logits has 2 channels for foreground/background softmax scores,
    # then we get prediction with larger score for each pixel
    pred_masks = tf.argmax(final_logits, axis=3)
    pred_masks = tf.reshape(pred_masks,[input_height,input_width])
    pred_masks = tf.to_float(pred_masks)

    # Dynamic anchor base size calculated from median cell lengths
    base_size = anchor_size(tf.reshape(labels,[input_height,input_width]))
    # scales and ratios are used to generate different anchors
    scales = np.array([ 0.5, 1, 2])
    ratios = np.array([ 0.125, 0.25, 0.5, 1, 2, 4, 8])

    # stride is to control how sparse we want to place anchors across the image
    # stride = 16 means to place an anchor every 16 pixels on the original image
    stride = 16

    # Generate the anchor reference with respect to the original image
    ref_anchors = generate_anchors_reference(base_size, ratios, scales)
    num_ref_anchors = scales.shape[0] * ratios.shape[0]

    # Feature-map resolution is the image resolution divided by the stride.
    feat_height = input_height / stride
    feat_width = input_width / stride

    # Generate all the anchors based on ref_anchors
    all_anchors = generate_anchors(ref_anchors, stride, [feat_height,feat_width])
    num_anchors = all_anchors.shape[0]

    with tf.variable_scope('model_RPN') as scope:
        prediction_dict = RPN(feat_map, num_ref_anchors)

    # Get the tensors from the dict
    rpn_cls_prob = prediction_dict['rpn_cls_prob']
    rpn_bbox_pred = prediction_dict['rpn_bbox_pred']

    # Decode RPN outputs into scored proposals (NMS applied inside).
    proposal_prediction = RPNProposal(rpn_cls_prob, rpn_bbox_pred, all_anchors, im_shape, nms_thresh)

    pred_dict_final['all_anchors'] = tf.cast(all_anchors, tf.float32)
    pred_dict_final['gt_bboxes'] = gt_boxes
    prediction_dict['proposals'] = proposal_prediction['proposals']
    prediction_dict['scores'] = proposal_prediction['scores']

    # When training we use a separate module to calculate the target
    # values we want to output.
    (rpn_cls_target, rpn_bbox_target,rpn_max_overlap) = RPNTarget(all_anchors, num_anchors, gt_boxes, im_shape)

    prediction_dict['rpn_cls_target'] = rpn_cls_target
    prediction_dict['rpn_bbox_target'] = rpn_bbox_target
    pred_dict_final['rpn_prediction'] = prediction_dict

    scores = pred_dict_final['rpn_prediction']['scores']
    proposals = pred_dict_final['rpn_prediction']['proposals']

    # Refine the U-Net mask with marker-based watershed seeded by the RPN
    # proposals whose score exceeds bbox_min_score.
    pred_masks_watershed = tf.to_float(marker_watershed(scores, proposals, pred_masks, min_score=bbox_min_score))

    # Loss is defined as rpn loss(class loss + bounding box loss) +
    # segmentation loss(default is the sum of soft dice and cross-entropy)
    rpn_loss = RPNLoss(prediction_dict)
    RPN_loss = rpn_loss['rpn_cls_loss'] + rpn_loss['rpn_reg_loss']
    SEG_loss = segmentation_loss(final_logits, pred_masks_watershed, labels, edge_weights, mode = 'COMBO')
    final_loss = RPN_loss + SEG_loss
    # If training with just U-Net, then only include segmentation loss
    #final_loss = SEG_loss

    # Metrics are pixel accuracy, mean IU, mean accuracy, root mean squared error
    metrics, metrics_op = compute_metrics(pred_masks, labels)

    pred_dict_final['unet_mask'] = pred_masks

    # get the optimizer
    gen_train_op = optimizer_fun(optimizer, final_loss, learning_rate=learning_rate)

    # start point for training, and end point for graph
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    num_batches = len(x_train)
    num_batches_val = len(x_val)

    saver = tf.train.Saver()
    # Restore the model from the trained network
    # saver.restore(sess,'./Network/whole_norm.ckpt')

    # Announce the training phase either on the console (CL mode) or in the GUI.
    if self.usingCL:
        if normalization_method == 'wn':
            print('Start whole image Norm. training ...')
        if normalization_method == 'fg':
            print('Start Foreground Norm. training ...')
    else:
        if normalization_method == 'wn':
            self.training_results.set('Start whole image Norm. training ...')
            self.window.update()
        if normalization_method == 'fg':
            self.training_results.set('Start Foreground Norm. training ...')
            self.window.update()

    # training images indexes will be shuffled at every epoch during training
    idx = np.arange(num_batches)

    best_IU = 0
    if normalization_method == 'wn':
        self.whole_norm_y_pred = []

    # The end-of-epoch print statement below would fail (NameError) when the
    # validation set is empty, unless these accumulators are pre-initialized.
    loss_total = 0
    cls_loss = 0
    reg_loss = 0
    seg_loss = 0
    _mean_IU = 0
    _f1 = 0
    _pixel_accuracy = 0

    for iteration in range(0,num_epoch):
        # The batch pointer to validation data
        j = 0
        # Reset streaming-metric state (tf.metrics local variables) each epoch.
        sess.run(tf.local_variables_initializer())

        # shuffle the sequence of the training data for the current epoch
        np.random.shuffle(idx)

        for i in tqdm(range(0,num_batches)):
            if not self.usingCL:
                self.train_progress_var.set(i/num_batches*100)
                self.window.update()
            # Generate the batch data from training data and training label
            batch_data = x_train[idx[i]]
            batch_data_shape = batch_data.shape
            batch_data = np.reshape(batch_data, [1,batch_data_shape[0],batch_data_shape[1],1])
            batch_label = np.reshape(y_train[idx[i]], [1,batch_data_shape[0],batch_data_shape[1],1])
            batch_edge = np.reshape(w_train[idx[i]], [1,batch_data_shape[0],batch_data_shape[1],1])
            batch_bbox = bbox_train[idx[i]]

            temp_pred = []

            # Skip if this batch does not contain any object (bounding box is null)
            if batch_bbox.size > 0:
                # Here include the optimizer to actually perform learning
                sess.run([gen_train_op], feed_dict={train_initial:batch_data, gt_boxes:batch_bbox, labels:batch_label, edge_weights:batch_edge})

            # Only calculate the accuracy and loss after the training epoch
            if i == num_batches - 1:
                while j < num_batches_val:
                    # Generate the batch data from val data and val label
                    batch_data = x_val[j]
                    batch_data_shape = batch_data.shape
                    batch_data = np.reshape(batch_data, [1,batch_data_shape[0],batch_data_shape[1],1])
                    batch_label = np.reshape(y_val[j], [1,batch_data_shape[0],batch_data_shape[1],1])
                    batch_edge = np.reshape(w_val[j], [1,batch_data_shape[0],batch_data_shape[1],1])
                    batch_bbox = bbox_val[j]

                    # At the last 10 turns of whole image normalization training,
                    # cache the predictions
                    if iteration >= num_epoch - 10 and normalization_method == 'wn':
                        temp_pred.append(sess.run(pred_masks,
                            feed_dict={train_initial:batch_data, gt_boxes:batch_bbox, labels:batch_label, edge_weights:batch_edge}))

                    if batch_bbox.size > 0:
                        # Here get the accuracy and loss for each batch in validation cycle
                        loss_temp, rpnloss_temp, segloss_temp = sess.run([final_loss, rpn_loss, SEG_loss], feed_dict={train_initial:batch_data, gt_boxes:batch_bbox, labels:batch_label, edge_weights:batch_edge})
                        sess.run([metrics_op], feed_dict={train_initial:batch_data, gt_boxes:batch_bbox, labels:batch_label, edge_weights:batch_edge})

                        # Streaming metrics are read once, on the final validation batch.
                        if j == num_batches_val - 1:
                            metrics_all = sess.run(metrics, feed_dict={train_initial:batch_data, gt_boxes:batch_bbox, labels:batch_label, edge_weights:batch_edge})
                            _mean_IU = metrics_all['global']['mean_IU']
                            _pixel_accuracy = metrics_all['global']['pixel_accuracy']
                            # F1 (Dice) derived from IU via F1 = 2*IU / (1 + IU)
                            _f1 = 2 * _mean_IU / (1 + _mean_IU)
                            _rmse = metrics_all['global']['rmse']

                        # Get moving average of metrics and losses
                        if j == 0:
                            loss_total = loss_temp
                            cls_loss = rpnloss_temp['rpn_cls_loss']
                            reg_loss = rpnloss_temp['rpn_reg_loss']
                            seg_loss = segloss_temp
                        else:
                            # Running mean over the j+1 validation batches seen so far.
                            loss_total = (1 - 1 / (j + 1)) * loss_total + 1 / (j + 1) * loss_temp
                            cls_loss = (1 - 1 / (j + 1)) * cls_loss + 1 / (j + 1) * rpnloss_temp['rpn_cls_loss']
                            reg_loss = (1 - 1 / (j + 1)) * reg_loss + 1 / (j + 1) * rpnloss_temp['rpn_reg_loss']
                            seg_loss = (1 - 1 / (j + 1)) * seg_loss + 1 / (j + 1) * segloss_temp

                    j = j + 1

                print('Epoch: %d - loss: %.2f - cls_loss: %.2f - reg_loss: %.2f - seg_loss: %.2f - mean_IU: %.4f - f1: %.4f - pixel_accuracy: %.4f' % (iteration, loss_total, cls_loss, reg_loss, seg_loss, _mean_IU, _f1, _pixel_accuracy))

                if not self.usingCL:
                    self.training_results.set('Epoch ' + str(iteration) +
                        ', loss ' + '{0:.2f}'.format(loss_total) + ', mean IU ' + '{0:.2f}'.format(_mean_IU))
                    self.window.update()

                # Keep track of the best model in the last 10 epoches and use that as the best model
                if iteration >= num_epoch - 10 and normalization_method == 'wn' and _mean_IU > best_IU:
                    best_IU = _mean_IU
                    self.whole_norm_y_pred = copy.deepcopy(temp_pred)
                    saver.save(sess, './Network/whole_norm.ckpt')
                if iteration >= num_epoch - 10 and normalization_method == 'fg' and _mean_IU > best_IU:
                    best_IU = _mean_IU
                    saver.save(sess, './Network/foreground.ckpt')

                temp_pred = []

    sess.close()
# Train the pure U-Net model
def train_UNet(self):
    """Train the pure U-Net model (no region proposal network).

    Builds a TF1 graph with only the segmentation branch of NuSeT, then runs
    `num_epoch` passes over the training set, validating at the end of each
    epoch. The best checkpoint over the last 10 epochs (by validation mean IU)
    is saved to ./Network/UNet_whole_norm.ckpt ('wn' normalization) or
    ./Network/UNet_foreground.ckpt ('fg' normalization).

    Args:
        self: GUI/CLI driver object. Reads hyperparameters from self.params
            ('lr', 'optimizer', 'epochs', 'normalization_method') and, when
            not self.usingCL, updates the Tk widgets self.training_results /
            self.train_progress_var / self.window. For 'wn' runs it also
            stores the cached validation predictions of the best epoch in
            self.whole_norm_y_pred.

    Returns:
        None. Progress/metrics are printed or pushed to the GUI; the model is
        persisted via tf.train.Saver.
    """
    # Get the training parameters
    learning_rate = self.params['lr']
    optimizer = self.params['optimizer']
    num_epoch = self.params['epochs']
    normalization_method = self.params['normalization_method']

    # Load the data
    # x_train, y_train: training images and corresponding labels
    # x_val, y_val: validation images and corresponding labels
    # w_train, w_val: training and validation weight matrices for U-Net
    # bbox_train, bbox_val: not used in unet model
    x_train, x_val, y_train, y_val, w_train, w_val, bbox_train, bbox_val = load_data_train(self, normalization_method)

    # pred_dict and pred_dict_final save all the temp variables
    pred_dict_final = {}

    # Tensor placeholders; batch size is fixed at 1, spatial dims are dynamic
    # so images of different sizes can be fed one at a time.
    train_initial = tf.placeholder(dtype=tf.float32, shape=[1, None, None, 1])
    labels = tf.placeholder(dtype=tf.float32, shape=[1, None, None, 1])
    # tensor placeholder for weight matrices
    edge_weights = tf.placeholder(dtype=tf.float32, shape=[1, None, None, 1])

    input_shape = tf.shape(train_initial)
    input_height = input_shape[1]
    input_width = input_shape[2]
    im_shape = tf.cast([input_height, input_width], tf.float32)

    # number of classes needed to be classified, for our case this equals to 2
    # (foreground and background)
    nb_classes = 2

    # feed the initial image to U-Net, we expect 2 outputs:
    # 1. feat_map of shape (1,input_height/16,input_width/16,1024) (unused here,
    #    only needed by the RPN branch in train_NuSeT)
    # 2. final_logits of shape(1,input_height,input_width,2), which is the prediction from U-net
    with tf.variable_scope('model_U-Net') as scope:
        final_logits, feat_map = UNET(nb_classes, train_initial)

    # The final_logits has 2 channels for foreground/background softmax scores,
    # then we get prediction with larger score for each pixel
    pred_masks = tf.argmax(final_logits, axis=3)
    pred_masks = tf.reshape(pred_masks,[input_height,input_width])
    pred_masks = tf.to_float(pred_masks)

    # Segmentation-only loss (default is the sum of soft dice and cross-entropy)
    SEG_loss = segmentation_loss(final_logits, pred_masks, labels, edge_weights, mode = 'COMBO')
    final_loss = SEG_loss

    # Metrics are pixel accuracy, mean IU, mean accuracy, root mean squared error
    metrics, metrics_op = compute_metrics(pred_masks, labels)

    pred_dict_final['unet_mask'] = pred_masks

    # get the optimizer
    gen_train_op = optimizer_fun(optimizer, final_loss, learning_rate=learning_rate)

    # start point for training, and end point for graph
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    num_batches = len(x_train)
    num_batches_val = len(x_val)

    saver = tf.train.Saver()

    if not self.usingCL:
        self.training_results.set('U-Net: Start training ...')
        self.window.update()

    # training images indexes will be shuffled at every epoch during training
    idx = np.arange(num_batches)

    best_IU = 0
    if normalization_method == 'wn':
        self.whole_norm_y_pred = []

    # Pre-initialize the accumulators read by the end-of-epoch print/GUI update,
    # mirroring train_NuSeT; without this, an empty validation set would raise
    # NameError at the print below.
    loss_total = 0
    _mean_IU = 0
    _f1 = 0
    _pixel_accuracy = 0

    for iteration in range(num_epoch):
        # The batch pointer to validation data
        j = 0
        # Reset streaming-metric state (tf.metrics local variables) each epoch.
        sess.run(tf.local_variables_initializer())

        # shuffle the sequence of the training data for the current epoch
        np.random.shuffle(idx)

        for i in tqdm(range(0,num_batches)):
            if not self.usingCL:
                self.train_progress_var.set(i/num_batches*100)
                self.window.update()
            # Generate the batch data from training data and training label
            batch_data = x_train[idx[i]]
            batch_data_shape = batch_data.shape
            batch_data = np.reshape(batch_data, [1,batch_data_shape[0],batch_data_shape[1],1])
            batch_label = np.reshape(y_train[idx[i]], [1,batch_data_shape[0],batch_data_shape[1],1])
            batch_edge = np.reshape(w_train[idx[i]], [1,batch_data_shape[0],batch_data_shape[1],1])

            temp_pred = []

            # Here include the optimizer to actually perform learning
            sess.run([gen_train_op], feed_dict={train_initial:batch_data, labels:batch_label, edge_weights:batch_edge})

            # Only calculate the accuracy and loss after the training epoch
            if i == num_batches - 1:
                while j < num_batches_val:
                    # Generate the batch data from val data and val label
                    batch_data = x_val[j]
                    batch_data_shape = batch_data.shape
                    batch_data = np.reshape(batch_data, [1,batch_data_shape[0],batch_data_shape[1],1])
                    batch_label = np.reshape(y_val[j], [1,batch_data_shape[0],batch_data_shape[1],1])
                    batch_edge = np.reshape(w_val[j], [1,batch_data_shape[0],batch_data_shape[1],1])

                    # At the last 10 turns of whole image normalization training,
                    # cache the predictions
                    if iteration >= num_epoch - 10 and normalization_method == 'wn':
                        temp_pred.append(sess.run(pred_masks,
                            feed_dict={train_initial:batch_data, labels:batch_label, edge_weights:batch_edge}))

                    loss_temp = sess.run(final_loss, feed_dict={train_initial:batch_data, labels:batch_label, edge_weights:batch_edge})
                    sess.run([metrics_op], feed_dict={train_initial:batch_data, labels:batch_label, edge_weights:batch_edge})

                    # Streaming metrics are read once, on the final validation batch.
                    if j == num_batches_val - 1:
                        metrics_all = sess.run(metrics, feed_dict={train_initial:batch_data, labels:batch_label, edge_weights:batch_edge})
                        _mean_IU = metrics_all['global']['mean_IU']
                        _pixel_accuracy = metrics_all['global']['pixel_accuracy']
                        # F1 (Dice) derived from IU via F1 = 2*IU / (1 + IU)
                        _f1 = 2 * _mean_IU / (1 + _mean_IU)
                        _rmse = metrics_all['global']['rmse']

                    # Get moving average of metrics and losses
                    if j == 0:
                        loss_total = loss_temp
                    else:
                        # Running mean over the j+1 validation batches seen so far.
                        loss_total = (1 - 1 / (j + 1)) * loss_total + 1 / (j + 1) * loss_temp

                    j = j + 1

                print('Epoch: %d - loss: %.2f - mean_IU: %.4f - f1: %.4f - pixel_accuracy: %.4f' % (iteration, loss_total, _mean_IU, _f1, _pixel_accuracy))

                if not self.usingCL:
                    self.training_results.set('Epoch ' + str(iteration) +
                        ', loss ' + '{0:.2f}'.format(loss_total) + ', mean IU ' + '{0:.2f}'.format(_mean_IU))
                    self.window.update()

                # Keep track of the best model in the last 10 epoches and use that as the best model
                if iteration >= num_epoch - 10 and normalization_method == 'wn' and _mean_IU > best_IU:
                    best_IU = _mean_IU
                    self.whole_norm_y_pred = copy.deepcopy(temp_pred)
                    saver.save(sess, './Network/UNet_whole_norm.ckpt')
                if iteration >= num_epoch - 10 and normalization_method == 'fg' and _mean_IU > best_IU:
                    best_IU = _mean_IU
                    saver.save(sess, './Network/UNet_foreground.ckpt')

    sess.close()