From 9eff9a695f7f135175ef8d41fae1e07d4cf8901f Mon Sep 17 00:00:00 2001 From: Yabin Zheng Date: Thu, 22 Dec 2016 15:38:53 +0800 Subject: [PATCH] Update files.py, speed up model saving and restoring process. The run() and eval() will run the whole graph from scratch, so combining ops to an array and executing together will result in significant speed-up. --- tensorlayer/files.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorlayer/files.py b/tensorlayer/files.py index 0e9253baa..7b9568a7d 100644 --- a/tensorlayer/files.py +++ b/tensorlayer/files.py @@ -646,14 +646,14 @@ def save_npz(save_list=[], name='model.npz', sess=None): """ ## save params into a list save_list_var = [] - for k, value in enumerate(save_list): - if sess: - save_list_var.append( sess.run(value) ) - else: - try: - save_list_var.append( value.eval() ) - except: - print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)") + if sess: + save_list_var = sess.run(save_list) + else: + try: + for k, value in enumerate(save_list): + save_list_var.append(value.eval()) + except: + print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)") np.savez(name, params=save_list_var) save_list_var = None del save_list_var @@ -734,9 +734,10 @@ def assign_params(sess, params, network): ---------- - `Assign value to a TensorFlow variable `_ """ + ops = [] for idx, param in enumerate(params): - assign_op = network.all_params[idx].assign(param) - sess.run(assign_op) + ops.append(network.all_params[idx].assign(param)) + sess.run(ops)