Skip to content

Commit

Permalink
load_checkpoint_vars supports npz
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Aug 8, 2020
1 parent 43a44c1 commit 57f542d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tensorpack/tfutils/varmanip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
__all__ = ['SessionUpdate', 'dump_session_params',
'load_chkpt_vars', 'save_chkpt_vars',
'load_checkpoint_vars', 'save_checkpoint_vars',
'get_checkpoint_path']
'get_checkpoint_path', 'get_all_checkpoints']


def get_savename_from_varname(
Expand Down Expand Up @@ -251,6 +251,10 @@ def load_checkpoint_vars(path):
Returns:
dict: a name:value dict
"""
if path.endswith(".npz"):
ret = dict(np.load(path))
ret = {get_op_tensor_name(k)[0]: v for k, v in ret.items()}
return ret
path = get_checkpoint_path(path)
reader = tfv1.train.NewCheckpointReader(path)
var_names = reader.get_variable_to_shape_map().keys()
Expand Down

0 comments on commit 57f542d

Please sign in to comment.