Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions tensorlayer/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,40 @@ def save_npz(save_list=[], name='model.npz', sess=None):
# np.savez(name, **rename_dict)
# print('Model is saved to: %s' % name)

def save_npz_dict(save_list=[], name='model.npz', sess=None):
"""Input parameters and the file name, save parameters as a dictionary into .npz file. Use tl.utils.load_npz_dict() to restore.

Parameters
----------
save_list : a list
Parameters want to be saved.
name : a string or None
The name of the .npz file.
sess : None or Session

Notes
-----
This function tries to avoid a potential broadcasting error raised by numpy.

"""
## save params into a list
save_list_var = []
if sess:
save_list_var = sess.run(save_list)
else:
try:
for k, value in enumerate(save_list):
save_list_var.append(value.eval())
except:
print(" Fail to save model, Hint: pass the session into this function, save_npz_dict(network.all_params, name='model.npz', sess=sess)")
save_var_dict = {str(idx):val for idx, val in enumerate(save_list_var)}
np.savez(name, **save_var_dict)
save_list_var = None
save_var_dict = None
del save_list_var
del save_var_dict
print("[*] %s saved" % name)

def load_npz(path='', name='model.npz'):
"""Load the parameters of a Model saved by tl.files.save_npz().

Expand Down Expand Up @@ -602,6 +636,25 @@ def load_npz(path='', name='model.npz'):
# exit()
# return d.items()[0][1]['params']

def load_npz_dict(path='', name='model.npz'):
"""Load the parameters of a Model saved by tl.files.save_npz_dict().

Parameters
----------
path : a string
Folder path to .npz file.
name : a string or None
The name of the .npz file.

Returns
--------
params : list
A list of parameters in order.
"""
d = np.load( path+name )
saved_list_var = [val[1] for val in sorted(d.items(), key=lambda tup: int(tup[0]))]
return saved_list_var

def assign_params(sess, params, network):
"""Assign the given parameters to the TensorLayer network.

Expand Down