forked from david-gpu/srez
-
Notifications
You must be signed in to change notification settings - Fork 1
/
srez_train.py
129 lines (99 loc) · 4.23 KB
/
srez_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
import os.path
import scipy.misc
import tensorflow as tf
import time
import sys
# Global TensorFlow flag container. The flag definitions themselves are not
# visible in this file; the functions below read train_dir, checkpoint_dir,
# learning_rate_start, learning_rate_half_life, train_batch_iterations,
# train_time, summary_period and checkpoint_period from it.
FLAGS = tf.app.flags.FLAGS
def _summarize_progress(train_data, feature, label, gene_output, batch, suffix, max_samples=8):
    """Write a PNG comparing upscaling baselines, generator output and labels.

    Every sample becomes one row of the saved image, laid out left to right
    as: nearest-neighbour upscale of ``feature``, bicubic upscale of
    ``feature``, ``gene_output`` and ``label``. The first three are clipped
    to [0, 1]; the label is used as given. Rows are stacked vertically.

    Args:
        train_data: Object exposing the TensorFlow session as ``sess``.
        feature: Batch of input images; resized to the label's height/width.
        label: Batch of target images; ``label.shape[1]`` and
            ``label.shape[2]`` give the size everything is resized to.
        gene_output: Batch of generator outputs. It is concatenated with the
            resized features and the label along axis 2, so it presumably
            has the label's batch size, height and channel count.
        batch: Batch counter, embedded in the file name.
        suffix: Tag appended to the file name.
        max_samples: Upper bound on the number of rows written. Fewer rows
            are written when the batch itself is smaller.

    Returns:
        Path of the PNG file that was written (inside ``FLAGS.train_dir``).
    """
    td = train_data
    size = [label.shape[1], label.shape[2]]

    def _clip_unit(tensor):
        # Keep pixel values inside the displayable [0, 1] range.
        return tf.maximum(tf.minimum(tensor, 1.0), 0.0)

    # NOTE(review): these ops are created on every call, so the default graph
    # grows each time a summary is written.
    nearest = _clip_unit(tf.image.resize_nearest_neighbor(feature, size))
    bicubic = _clip_unit(tf.image.resize_bicubic(feature, size))
    clipped = _clip_unit(gene_output)

    # Side-by-side comparison: concatenate the four variants along the width.
    image = tf.concat(axis=2, values=[nearest, bicubic, clipped, label])

    # Never index past the end of the batch: a batch with fewer than
    # max_samples images used to fail with an out-of-bounds slice.
    batch_size = image.get_shape().as_list()[0]
    num_rows = max_samples if batch_size is None else min(max_samples, batch_size)

    # Stack the per-sample strips vertically into one tall image.
    image = tf.concat(axis=0, values=[image[i, :, :, :] for i in range(num_rows)])
    image = td.sess.run(image)

    filename = 'batch%06d_%s.png' % (batch, suffix)
    filename = os.path.join(FLAGS.train_dir, filename)
    # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2; this call
    # needs an older SciPy (with Pillow installed) to work.
    scipy.misc.toimage(image, cmin=0., cmax=1.).save(filename)
    print(" Saved %s" % (filename,))
    return filename
def _move_checkpoint_files(src_prefix, dst_prefix):
    """Replace all files of checkpoint ``dst_prefix`` by those of ``src_prefix``.

    A checkpoint is a group of files sharing a path prefix (for the V2 format
    ``.index``, ``.data-*`` and ``.meta``), so every file that starts with
    the prefix is handled. The operation is best effort: file-system errors
    are reported but do not abort training.
    """
    src_base = os.path.basename(src_prefix)
    try:
        # Delete oldest checkpoint
        for stale in tf.gfile.Glob(dst_prefix + '*'):
            tf.gfile.Remove(stale)
        # Rename old checkpoint, keeping each file's own extension
        for path in tf.gfile.Glob(src_prefix + '*'):
            extension = os.path.basename(path)[len(src_base):]
            tf.gfile.Rename(path, dst_prefix + extension, overwrite=True)
    except tf.errors.OpError as err:
        print(" Warning: could not rotate checkpoint files: %s" % (err,))


def _save_checkpoint(train_data, batch):
    """Save the session's variables, keeping the previous checkpoint as backup.

    The checkpoint currently stored under ``checkpoint_new.txt`` is moved to
    ``checkpoint_old.txt`` (replacing whatever was there) and a fresh
    checkpoint is then written to ``checkpoint_new.txt``. Both live in
    ``FLAGS.checkpoint_dir``.

    Args:
        train_data: Object exposing the TensorFlow session as ``sess``.
        batch: Current batch counter. Unused; kept for call compatibility.
    """
    td = train_data

    oldname = os.path.join(FLAGS.checkpoint_dir, 'checkpoint_old.txt')
    newname = os.path.join(FLAGS.checkpoint_dir, 'checkpoint_new.txt')

    # Rotate: new -> old. On the first call there is nothing to rotate.
    _move_checkpoint_files(newname, oldname)

    # Generate new checkpoint
    # NOTE(review): a new Saver is built on every call, which adds save and
    # restore ops to the graph each time.
    saver = tf.train.Saver()
    saver.save(td.sess, newname)

    print(" Checkpoint saved")
def train_model(train_data):
    """Run the generator/discriminator training loop until the budget is spent.

    The budget is ``FLAGS.train_batch_iterations`` batches, or, when that
    flag is -1, ``FLAGS.train_time`` minutes of wall-clock time. Both limits
    are checked after every batch, so the loop stops on the first batch that
    reaches the limit rather than at the next progress report.

    During training the function:
      * prints a progress line every 100 batches;
      * halves the learning rate every ``FLAGS.learning_rate_half_life``
        batches (independently of the reporting period);
      * writes a comparison image every ``FLAGS.summary_period`` batches;
      * saves a checkpoint every ``FLAGS.checkpoint_period`` batches and once
        more at the end, unless the last batch was just checkpointed.

    Args:
        train_data: Object exposing ``sess``, ``learning_rate``, the
            ``gene_minimize``/``disc_minimize`` training ops, the
            ``gene_loss``/``disc_real_loss``/``disc_fake_loss`` tensors,
            ``test_features``/``test_labels`` and the
            ``gene_minput``/``gene_moutput`` pair used for evaluation.
    """
    td = train_data

    td.sess.run(tf.global_variables_initializer())

    lrval = FLAGS.learning_rate_start
    start_time = time.time()
    batch = 0
    last_checkpoint_batch = None
    fixed_iterations = FLAGS.train_batch_iterations != -1

    # Cache test features and labels (they are small)
    test_feature, test_label = td.sess.run([td.test_features, td.test_labels])

    # The fetch list never changes, so build it once.
    ops = [td.gene_minimize, td.disc_minimize,
           td.gene_loss, td.disc_real_loss, td.disc_fake_loss]

    done = False
    while not done:
        batch += 1

        feed_dict = {td.learning_rate: lrval}
        _, _, gene_loss, disc_real_loss, disc_fake_loss = td.sess.run(ops, feed_dict=feed_dict)

        # Elapsed wall-clock time in minutes (whole seconds, as a float).
        elapsed = int(time.time() - start_time) / 60.0

        if batch % 100 == 0:
            # Show we are alive
            if fixed_iterations:
                progress = int(100 * batch / FLAGS.train_batch_iterations)
                # float() keeps the ratio fractional under Python 2 as well.
                eta = int(elapsed * (float(FLAGS.train_batch_iterations) / batch - 1.))
            else:
                progress = int(100 * elapsed / FLAGS.train_time)
                eta = FLAGS.train_time - elapsed
            print('Progress[%3d%%], ETA[%4dm], Batch [%4d], G_Loss[%3.3f], D_Real_Loss[%3.3f], D_Fake_Loss[%3.3f]' %
                  (progress, eta, batch, gene_loss, disc_real_loss, disc_fake_loss))
            sys.stdout.flush()

        # Finished? Checked on every batch so the loop cannot overshoot the
        # limit while waiting for the next progress report.
        if fixed_iterations:
            done = batch >= FLAGS.train_batch_iterations
        else:
            done = elapsed >= FLAGS.train_time

        # Update learning rate
        if batch % FLAGS.learning_rate_half_life == 0:
            lrval *= .5

        if batch % FLAGS.summary_period == 0:
            # Show progress with test features
            feed_dict = {td.gene_minput: test_feature}
            gene_output = td.sess.run(td.gene_moutput, feed_dict=feed_dict)
            _summarize_progress(td, test_feature, test_label, gene_output, batch, 'out')

        if batch % FLAGS.checkpoint_period == 0:
            # Save checkpoint
            _save_checkpoint(td, batch)
            last_checkpoint_batch = batch

    # Saving the same state twice would only overwrite the backup checkpoint
    # with a copy of the newest one.
    if last_checkpoint_batch != batch:
        _save_checkpoint(td, batch)

    print('Finished training!')
    sys.stdout.flush()