From 53bb646ddc22289a6e43e8bc5047c05589dc98f0 Mon Sep 17 00:00:00 2001 From: CraigLee Date: Wed, 14 Jun 2017 20:27:55 +0800 Subject: [PATCH] Introduce two new functions save_npz_dict and load_npz_dict to avoid a potential broadcasting error in save_npz. --- tensorlayer/files.py | 53 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tensorlayer/files.py b/tensorlayer/files.py index 5480060a0..c930fe249 100644 --- a/tensorlayer/files.py +++ b/tensorlayer/files.py @@ -561,6 +561,40 @@ def save_npz(save_list=[], name='model.npz', sess=None): # np.savez(name, **rename_dict) # print('Model is saved to: %s' % name) +def save_npz_dict(save_list=[], name='model.npz', sess=None): + """Input parameters and the file name, save parameters as a dictionary into .npz file. Use tl.utils.load_npz_dict() to restore. + + Parameters + ---------- + save_list : a list + Parameters want to be saved. + name : a string or None + The name of the .npz file. + sess : None or Session + + Notes + ----- + This function tries to avoid a potential broadcasting error raised by numpy. + + """ + ## save params into a list + save_list_var = [] + if sess: + save_list_var = sess.run(save_list) + else: + try: + for k, value in enumerate(save_list): + save_list_var.append(value.eval()) + except: + print(" Fail to save model, Hint: pass the session into this function, save_npz_dict(network.all_params, name='model.npz', sess=sess)") + save_var_dict = {str(idx):val for idx, val in enumerate(save_list_var)} + np.savez(name, **save_var_dict) + save_list_var = None + save_var_dict = None + del save_list_var + del save_var_dict + print("[*] %s saved" % name) + def load_npz(path='', name='model.npz'): """Load the parameters of a Model saved by tl.files.save_npz(). @@ -602,6 +636,25 @@ def load_npz(path='', name='model.npz'): # exit() # return d.items()[0][1]['params'] +def load_npz_dict(path='', name='model.npz'): + """Load the parameters of a Model saved by tl.files.save_npz_dict(). + + Parameters + ---------- + path : a string + Folder path to .npz file. + name : a string or None + The name of the .npz file. + + Returns + -------- + params : list + A list of parameters in order. + """ + d = np.load( path+name ) + saved_list_var = [val[1] for val in sorted(d.items(), key=lambda tup: int(tup[0]))] + return saved_list_var + def assign_params(sess, params, network): """Assign the given parameters to the TensorLayer network.