-
Notifications
You must be signed in to change notification settings - Fork 0
/
deblurring_train.py
110 lines (89 loc) · 4.22 KB
/
deblurring_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import tensorflow as tf
import numpy as np
import resnet_model
import losses
import data_provider
import utils
from tensorflow.python.platform import tf_logging as logging
from pathlib import Path
# Shorthand aliases for the TF-Slim API and the global flag container.
slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS

# --- Optimizer hyper-parameters -------------------------------------------
tf.app.flags.DEFINE_float('initial_learning_rate', 0.0005,
                          '''Initial learning rate.''')
tf.app.flags.DEFINE_float('num_epochs_per_decay', 5.0,
                          '''Epochs after which learning rate decays.''')
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.97,
                          '''Learning rate decay factor.''')

# --- Input pipeline --------------------------------------------------------
tf.app.flags.DEFINE_integer('batch_size', 32, '''The batch size to use.''')
tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
                            '''How many preprocess threads to use.''')

# --- Checkpointing / logging ----------------------------------------------
tf.app.flags.DEFINE_string('train_dir', 'ckpt/train',
                           '''Directory where to write event logs '''
                           '''and checkpoint.''')
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', '',
                           '''If specified, restore this pretrained model '''
                           '''before beginning any training.''')
tf.app.flags.DEFINE_string(
    'pretrained_resnet_checkpoint_path', '',
    '''If specified, restore this pretrained resnet '''
    '''before beginning any training.'''
    '''This restores only the weights of the resnet model''')

# --- Training schedule / placement ----------------------------------------
tf.app.flags.DEFINE_integer('max_steps', 100000,
                            '''Number of batches to run.''')
tf.app.flags.DEFINE_string('train_device', '/gpu:0',
                           '''Device to train with.''')

# The decay to use for the moving average.
# NOTE(review): not referenced anywhere in this file — presumably intended
# for an ExponentialMovingAverage that was never wired up; confirm.
MOVING_AVERAGE_DECAY = 0.9999
def restore_resnet(sess, path):
    """Restore pretrained ResNet-50 weights into the deblurring graph.

    Maps variables under the graph scope ``net/multiscale/resnet_v1_50`` to
    their names in a stock resnet_v1_50 checkpoint and restores them.

    Args:
      sess: the `tf.Session` to restore the variables into.
      path: filesystem path of the pretrained resnet checkpoint.
    """
    def name_in_checkpoint(var):
        # Drop the leading 'net/multiscale/' scope prefix and the ':0'
        # device suffix so the graph name matches the checkpoint name.
        name = '/'.join(var.name.split('/')[2:])
        name = name.split(':')[0]
        if 'Adam' in name:
            # Optimizer slot variables are absent from the pretrained
            # checkpoint, so skip them.
            return None
        return name

    # Build the restore map, calling name_in_checkpoint once per variable
    # (the original comprehension evaluated it twice: filter + key).
    variables_to_restore = {}
    for var in slim.get_variables_to_restore(
            include=["net/multiscale/resnet_v1_50"]):
        name = name_in_checkpoint(var)
        if name is not None:
            variables_to_restore[name] = var
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, path)
def train():
    """Build the multiscale deblurring graph and run the slim training loop.

    Reads hyper-parameters and checkpoint paths from FLAGS. Constructs the
    data pipeline, the multiscale resnet deblurring network, per-scale
    smooth-L1 losses, and summaries, then hands the train op to
    ``slim.learning.train``.
    """
    g = tf.Graph()
    with g.as_default():
        # Load datasets: `images` are the blurred inputs, `deblurred` the
        # ground-truth sharp targets.
        provider = data_provider.Deblurring()
        images, deblurred = provider.get('deblurred')

        # Define model graph (batch norm / dropout in training mode).
        with tf.variable_scope('net'):
            with slim.arg_scope([slim.batch_norm, slim.layers.dropout],
                                is_training=True):
                scales = [1, 2, 4]
                prediction, pyramid = resnet_model.multiscale_deblurring_net(
                    images, scales=scales)

        # Add a smooth-L1 loss at every pyramid scale and on the combined
        # output, each with its own summary.
        for net, level_name in zip([prediction] + pyramid, ['pred'] + scales):
            loss = losses.smooth_l1(net, deblurred)
            slim.losses.add_loss(loss)
            tf.scalar_summary('losses/loss at {}'.format(level_name), loss)

        total_loss = slim.losses.get_total_loss()
        tf.scalar_summary('losses/total loss', total_loss)
        tf.image_summary('blurred', images)
        tf.image_summary('deblurred', deblurred)
        tf.image_summary('pred', prediction)

        optimizer = tf.train.AdamOptimizer(FLAGS.initial_learning_rate)
        train_op = slim.learning.create_train_op(total_loss,
                                                 optimizer,
                                                 summarize_gradients=True)

        # BUG FIX: the original restored checkpoints inside a throwaway
        # tf.Session, but slim.learning.train creates its OWN session, so
        # the restored weights were silently discarded. The supported
        # pattern is to pass an init_fn, which slim invokes inside the
        # training session after variable initialization.
        def init_fn(sess):
            # Restore only the resnet backbone weights, if requested.
            if FLAGS.pretrained_resnet_checkpoint_path:
                restore_resnet(sess, FLAGS.pretrained_resnet_checkpoint_path)
            # Restore a full pretrained model, if requested.
            if FLAGS.pretrained_model_checkpoint_path:
                variables_to_restore = slim.get_variables_to_restore()
                saver = tf.train.Saver(variables_to_restore)
                saver.restore(sess, FLAGS.pretrained_model_checkpoint_path)

        logging.set_verbosity(1)
        # BUG FIX: FLAGS.max_steps was defined but never used; pass it so
        # training actually stops after the requested number of batches.
        slim.learning.train(train_op,
                            FLAGS.train_dir,
                            init_fn=init_fn,
                            number_of_steps=FLAGS.max_steps,
                            save_summaries_secs=60,
                            save_interval_secs=600)
# Script entry point: build the graph and start training when run directly.
if __name__ == '__main__':
    train()