benchmark_predict.py
#!/usr/bin/env python

import os
import time

import numpy as np
import tensorflow as tf
# Not used in this benchmark script, but kept from the original imports
from tensorflow.contrib.session_bundle import exporter
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("batch_size", 10, "The batch size to train")
flags.DEFINE_integer("epoch_number", 10, "Number of epochs to run trainer")
flags.DEFINE_integer("steps_to_validate", 1,
                     "Steps to validate and print loss")
flags.DEFINE_string("checkpoint_dir", "./checkpoint/",
                    "Indicates the checkpoint directory")
flags.DEFINE_string("model_path", "./model/", "The export path of the model")
flags.DEFINE_integer("export_version", 1, "The version number of the model")
flags.DEFINE_integer("benchmark_batch_size", 1,
                     "The batch size of each predict request")
flags.DEFINE_integer("benchmark_test_number", 10000,
                     "The number of predict requests to run")
def main():
  # Define training data (unused by the benchmark itself)
  x = np.ones(FLAGS.batch_size)
  y = np.ones(FLAGS.batch_size)

  # Define the linear model: predict = w * X + b
  X = tf.placeholder(tf.float32, shape=[None])
  Y = tf.placeholder(tf.float32, shape=[None])
  w = tf.Variable(1.0, name="weight")
  b = tf.Variable(1.0, name="bias")
  loss = tf.square(Y - tf.mul(X, w) - b)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
  predict_op = tf.mul(X, w) + b

  saver = tf.train.Saver()
  checkpoint_dir = FLAGS.checkpoint_dir
  checkpoint_file = checkpoint_dir + "/checkpoint.ckpt"
  if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

  # Start the session
  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    # Restore the latest checkpoint if one exists
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      print("Continue training from the model {}".format(
          ckpt.model_checkpoint_path))
      saver.restore(sess, ckpt.model_checkpoint_path)

    # Run the benchmark: issue predict requests and measure average latency
    request_number = FLAGS.benchmark_test_number
    batch_size = FLAGS.benchmark_batch_size
    predict_x = np.ones(batch_size)

    start_time = time.time()
    for i in range(request_number):
      sess.run(predict_op, feed_dict={X: predict_x})
    end_time = time.time()

    print("Average latency is: {} ms".format(
        (end_time - start_time) * 1000 / request_number))


if __name__ == "__main__":
  main()
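
A possible way to invoke the script from the command line; the flag names come from the definitions above, while the specific values shown here are only illustrative:

    python benchmark_predict.py --benchmark_batch_size=1 --benchmark_test_number=10000 --checkpoint_dir=./checkpoint/

If a checkpoint exists in --checkpoint_dir it is restored before the benchmark loop; otherwise the freshly initialized variables are used, which does not affect the latency measurement.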