diff --git a/returnn/engine/base.py b/returnn/engine/base.py
index d3bacf829..cc7dbfca2 100644
--- a/returnn/engine/base.py
+++ b/returnn/engine/base.py
@@ -8,9 +8,9 @@
 import os
 import sys
 import typing
-from returnn.util.basic import BackendEngine, model_epoch_from_filename, get_model_filename_postfix
 from returnn.log import log
 from returnn.pretrain import Pretrain
+from returnn.util import basic as util
 
 
 class EngineBase(object):
@@ -52,7 +52,7 @@ def get_existing_models(cls, config):
         if os.path.exists(fn):
           file_list[epoch] = fn
           break
-        if BackendEngine.is_tensorflow_selected():
+        if util.BackendEngine.is_tensorflow_selected():
           if os.path.exists(fn + ".index"):
             file_list[epoch] = fn
             break
@@ -72,32 +72,24 @@ def get_epoch_model(cls, config):
       start_epoch = int(start_epoch_mode)
       assert start_epoch >= 1
 
-    load_model_epoch_filename = config.value('load', '')
-    if load_model_epoch_filename.endswith(".meta"):
-      load_model_epoch_filename = load_model_epoch_filename[:-len(".meta")]
-    elif load_model_epoch_filename.endswith(".index"):
-      load_model_epoch_filename = load_model_epoch_filename[:-len(".index")]
+    load_model_epoch_filename = util.get_checkpoint_filepattern(config.value('load', ''))
     if load_model_epoch_filename:
-      assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())
+      assert os.path.exists(load_model_epoch_filename + util.get_model_filename_postfix())
 
-    import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
-    if import_model_train_epoch1.endswith(".meta"):
-      import_model_train_epoch1 = import_model_train_epoch1[:-len(".meta")]
-    elif import_model_train_epoch1.endswith(".index"):
-      import_model_train_epoch1 = import_model_train_epoch1[:-len(".index")]
+    import_model_train_epoch1 = util.get_checkpoint_filepattern(config.value('import_model_train_epoch1', ''))
     if import_model_train_epoch1:
-      assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())
+      assert os.path.exists(import_model_train_epoch1 + util.get_model_filename_postfix())
 
     existing_models = cls.get_existing_models(config)
     load_epoch = config.int("load_epoch", -1)
     if load_model_epoch_filename:
       if load_epoch <= 0:
-        load_epoch = model_epoch_from_filename(load_model_epoch_filename)
+        load_epoch = util.model_epoch_from_filename(load_model_epoch_filename)
     else:
       if load_epoch > 0:  # ignore if load_epoch == 0
         assert load_epoch in existing_models
         load_model_epoch_filename = existing_models[load_epoch]
-        assert model_epoch_from_filename(load_model_epoch_filename) == load_epoch
+        assert util.model_epoch_from_filename(load_model_epoch_filename) == load_epoch
 
     # Only use this when we don't train.
     # For training, we first consider existing models before we take the 'load' into account when in auto epoch mode.
diff --git a/returnn/tf/network.py b/returnn/tf/network.py
index c5a26e524..326bb9b2a 100644
--- a/returnn/tf/network.py
+++ b/returnn/tf/network.py
@@ -16,6 +16,7 @@
 import returnn.tf.compat as tf_compat
 import returnn.tf.util.basic as tf_util
 from returnn.tf.util.basic import Data, DimensionTag, reuse_name_scope, VariableAssigner
+from returnn.util import basic as util
 
 
 class DataNotFound(Exception):
@@ -3473,7 +3474,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
                ignore_params=(), ignore_params_prefixes=(), var_name_mapping=None,
                network=None):
     """
-    :param str filename: filepattern for NewCheckpointReader
+    :param str filename: filepattern for NewCheckpointReader or .index/.meta file path
     :param list[tf.Variable|tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject] saveable_params:
     :param str params_prefix: expect that all vars in saveable_params have this prefix, and remove it
     :param str load_if_prefix: if given, only load variables with a name containing this string.
@@ -3486,7 +3487,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
       renamed vars in the checkpoint
     :param TFNetwork network:
     """
-    self.filename = filename
+    self.filepattern = util.get_checkpoint_filepattern(filename)
     self.network = network
     self.ignore_missing = ignore_missing
     self.params_prefix = params_prefix
@@ -3510,7 +3511,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
        continue
      self.saveable_params.append(param)
    assert count > 0, "%s: no saveable vars" % self
-    self.reader = tf_compat.v1.train.NewCheckpointReader(filename)
+    self.reader = tf_compat.v1.train.NewCheckpointReader(self.filepattern)
     self.net_vars = [v for v in self.saveable_params if isinstance(v, tf.Variable)]
     self.net_saveables = [v for v in self.saveable_params if not isinstance(v, tf.Variable)]
     # All variables in the checkpoint:
@@ -3918,7 +3919,7 @@ def get_lazy_dict(self):
       if self.ignore_missing and v_name not in var_name_map:
         print(
           "Warning, did not find match for var %r (%r, params_prefix %r, load_if_prefix %r) in checkpoint %r." % (
-            v, v_name, self.params_prefix, self.load_if_prefix, self.filename), file=log.v3)
+            v, v_name, self.params_prefix, self.load_if_prefix, self.filepattern), file=log.v3)
         continue
       variable_values[v] = self.VariableValue(value=var_name_map[v_name]())
     assert variable_values, "no vars to load; saveable vars are %r. load_if_prefix %r." % (
diff --git a/returnn/util/basic.py b/returnn/util/basic.py
index 3e872b6d6..8bee2e2a5 100644
--- a/returnn/util/basic.py
+++ b/returnn/util/basic.py
@@ -288,6 +288,21 @@ def get_model_filename_postfix():
   return ""
 
 
+def get_checkpoint_filepattern(filepath):
+  """
+  Removes optional .index or .meta extension
+
+  :param str filepath:
+  :return: CheckpointLoader compatible filepattern
+  :rtype: str
+  """
+  if filepath.endswith(".meta"):
+    return filepath[:-len(".meta")]
+  elif filepath.endswith(".index"):
+    return filepath[:-len(".index")]
+  return filepath
+
+
 def sys_cmd_out_lines(s):
   """
   :param str s: shell command
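
Note on usage: with this change, configs can point 'load' or 'import_model_train_epoch1' directly at a TF checkpoint's .index or .meta file, since the extension is now stripped before the path reaches NewCheckpointReader. A minimal sketch of the new helper's behavior (the "net-model/network.163" paths below are made up for illustration):

    from returnn.util import basic as util

    # ".meta" and ".index" extensions are stripped, leaving the checkpoint filepattern:
    assert util.get_checkpoint_filepattern("net-model/network.163.meta") == "net-model/network.163"
    assert util.get_checkpoint_filepattern("net-model/network.163.index") == "net-model/network.163"
    # Any other path passes through unchanged:
    assert util.get_checkpoint_filepattern("net-model/network.163") == "net-model/network.163"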