-
Notifications
You must be signed in to change notification settings - Fork 714
/
Copy pathtrain.py
executable file
·79 lines (66 loc) · 2.65 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python
import numpy as np
import os
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter
import time
# Command-line flags. `FLAGS` is read by main() for the training
# hyper-parameters and for the checkpoint / export locations.
flags = tf.app.flags
FLAGS = flags.FLAGS

# Training configuration.
flags.DEFINE_integer("batch_size", 10, "The batch size to train")
flags.DEFINE_integer("epoch_number", 10, "Number of epochs to run trainer")
flags.DEFINE_integer("steps_to_validate", 1,
                     "Steps to validate and print loss")

# Where checkpoints are written to and resumed from.
flags.DEFINE_string("checkpoint_dir", "./checkpoint/",
                    "indicates the checkpoint directory")

# Where the trained model is exported, and under which version number.
flags.DEFINE_string("model_path", "./model/", "The export path of the model")
flags.DEFINE_integer("export_version", 1, "The version number of the model")
def main():
    """Train a tiny linear model, checkpoint it, and export it.

    The model is ``prediction = X * w + b`` with scalar variables ``w`` and
    ``b`` (both initialised to 1.0), fitted by gradient descent (learning
    rate 0.01) on squared error. The training data is ``FLAGS.batch_size``
    ones for both features and labels.

    Behaviour driven by flags:
      * ``FLAGS.epoch_number`` training steps are run.
      * Every ``FLAGS.steps_to_validate`` steps the elapsed time is printed
        and a checkpoint is written under ``FLAGS.checkpoint_dir``. A
        non-positive value disables this periodic logging/checkpointing.
      * If ``FLAGS.checkpoint_dir`` already holds a checkpoint, training
        resumes from it.
      * After training, the model is exported to ``FLAGS.model_path`` with
        version ``FLAGS.export_version``.
    """
    # Define training data
    x = np.ones(FLAGS.batch_size)
    y = np.ones(FLAGS.batch_size)

    # Define the model
    X = tf.placeholder(tf.float32, shape=[None])
    Y = tf.placeholder(tf.float32, shape=[None])
    w = tf.Variable(1.0, name="weight")
    b = tf.Variable(1.0, name="bias")
    loss = tf.square(Y - tf.mul(X, w) - b)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    predict_op = tf.mul(X, w) + b

    saver = tf.train.Saver()
    checkpoint_dir = FLAGS.checkpoint_dir
    # os.path.join avoids the doubled separator that plain concatenation
    # produces when the directory flag already ends with "/".
    checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.ckpt")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Start the session
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        # Restoring after initialisation overwrites the fresh values with
        # the checkpointed ones, so training continues where it stopped.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Continue training from the model {}".format(
                ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Start training
        steps_to_validate = FLAGS.steps_to_validate
        start_time = time.time()
        for epoch in range(FLAGS.epoch_number):
            sess.run(train_op, feed_dict={X: x, Y: y})

            # Start validating. The positivity check prevents a
            # ZeroDivisionError when the flag is set to 0.
            if steps_to_validate > 0 and epoch % steps_to_validate == 0:
                end_time = time.time()
                print("[{}] Epoch: {}".format(end_time - start_time, epoch))

                saver.save(sess, checkpoint_file)
                # Reset the timer so each report covers one interval only.
                start_time = end_time

        # Print model variables
        w_value, b_value = sess.run([w, b])
        print("The model of w: {}, b: {}".format(w_value, b_value))

        # Export the model
        print("Exporting trained model to {}".format(FLAGS.model_path))
        model_exporter = exporter.Exporter(saver)
        model_exporter.init(
            sess.graph.as_graph_def(),
            named_graph_signatures={
                'inputs': exporter.generic_signature({"features": X}),
                'outputs': exporter.generic_signature(
                    {"prediction": predict_op}),
            })
        model_exporter.export(FLAGS.model_path,
                              tf.constant(FLAGS.export_version), sess)
        # Function-call form works on both Python 2 and 3; the original
        # print statement was a SyntaxError on Python 3.
        print('Done exporting!')
# Script entry point: run training and export only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()