1515from returnn .tf .layers .basic import LayerBase , get_layer_class
1616import returnn .tf .compat as tf_compat
1717import returnn .tf .util .basic as tf_util
18- from returnn .tf .util .basic import Data , DimensionTag , reuse_name_scope , VariableAssigner
18+ from returnn .tf .util .basic import Data , DimensionTag , reuse_name_scope , VariableAssigner , get_checkpoint_filepattern
1919
2020
2121class DataNotFound (Exception ):
@@ -3473,7 +3473,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34733473 ignore_params = (), ignore_params_prefixes = (), var_name_mapping = None ,
34743474 network = None ):
34753475 """
3476- :param str filename: filepattern for NewCheckpointReader
3476+ :param str filename: filepattern for NewCheckpointReader or .index/.meta file path
34773477 :param list[tf.Variable|tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject] saveable_params:
34783478 :param str params_prefix: expect that all vars in saveable_params have this prefix, and remove it
34793479 :param str load_if_prefix: if given, only load variables with a name containing this string.
@@ -3486,7 +3486,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34863486 renamed vars in the checkpoint
34873487 :param TFNetwork network:
34883488 """
3489- self .filename = filename
3489+ self .filepattern = get_checkpoint_filepattern ( filename )
34903490 self .network = network
34913491 self .ignore_missing = ignore_missing
34923492 self .params_prefix = params_prefix
@@ -3510,7 +3510,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
35103510 continue
35113511 self .saveable_params .append (param )
35123512 assert count > 0 , "%s: no saveable vars" % self
3513- self .reader = tf_compat .v1 .train .NewCheckpointReader (filename )
3513+ self .reader = tf_compat .v1 .train .NewCheckpointReader (self . filepattern )
35143514 self .net_vars = [v for v in self .saveable_params if isinstance (v , tf .Variable )]
35153515 self .net_saveables = [v for v in self .saveable_params if not isinstance (v , tf .Variable )]
35163516 # All variables in the checkpoint:
@@ -3918,7 +3918,7 @@ def get_lazy_dict(self):
39183918 if self .ignore_missing and v_name not in var_name_map :
39193919 print (
39203920 "Warning, did not find match for var %r (%r, params_prefix %r, load_if_prefix %r) in checkpoint %r." % (
3921- v , v_name , self .params_prefix , self .load_if_prefix , self .filename ), file = log .v3 )
3921+ v , v_name , self .params_prefix , self .load_if_prefix , self .filepattern ), file = log .v3 )
39223922 continue
39233923 variable_values [v ] = self .VariableValue (value = var_name_map [v_name ]())
39243924 assert variable_values , "no vars to load; saveable vars are %r. load_if_prefix %r." % (
0 commit comments