Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions returnn/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import os
import sys
import typing
from returnn.util.basic import BackendEngine, model_epoch_from_filename, get_model_filename_postfix
from returnn.log import log
from returnn.pretrain import Pretrain
from returnn.util import basic as util


class EngineBase(object):
Expand Down Expand Up @@ -52,7 +52,7 @@ def get_existing_models(cls, config):
if os.path.exists(fn):
file_list[epoch] = fn
break
if BackendEngine.is_tensorflow_selected():
if util.BackendEngine.is_tensorflow_selected():
if os.path.exists(fn + ".index"):
file_list[epoch] = fn
break
Expand All @@ -72,32 +72,24 @@ def get_epoch_model(cls, config):
start_epoch = int(start_epoch_mode)
assert start_epoch >= 1

load_model_epoch_filename = config.value('load', '')
if load_model_epoch_filename.endswith(".meta"):
load_model_epoch_filename = load_model_epoch_filename[:-len(".meta")]
elif load_model_epoch_filename.endswith(".index"):
load_model_epoch_filename = load_model_epoch_filename[:-len(".index")]
load_model_epoch_filename = util.get_checkpoint_filepattern(config.value('load', ''))
if load_model_epoch_filename:
assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())
assert os.path.exists(load_model_epoch_filename + util.get_model_filename_postfix())

import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
if import_model_train_epoch1.endswith(".meta"):
import_model_train_epoch1 = import_model_train_epoch1[:-len(".meta")]
elif import_model_train_epoch1.endswith(".index"):
import_model_train_epoch1 = import_model_train_epoch1[:-len(".index")]
import_model_train_epoch1 = util.get_checkpoint_filepattern(config.value('import_model_train_epoch1', ''))
if import_model_train_epoch1:
assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())
assert os.path.exists(import_model_train_epoch1 + util.get_model_filename_postfix())

existing_models = cls.get_existing_models(config)
load_epoch = config.int("load_epoch", -1)
if load_model_epoch_filename:
if load_epoch <= 0:
load_epoch = model_epoch_from_filename(load_model_epoch_filename)
load_epoch = util.model_epoch_from_filename(load_model_epoch_filename)
else:
if load_epoch > 0: # ignore if load_epoch == 0
assert load_epoch in existing_models
load_model_epoch_filename = existing_models[load_epoch]
assert model_epoch_from_filename(load_model_epoch_filename) == load_epoch
assert util.model_epoch_from_filename(load_model_epoch_filename) == load_epoch

# Only use this when we don't train.
# For training, we first consider existing models before we take the 'load' into account when in auto epoch mode.
Expand Down
9 changes: 5 additions & 4 deletions returnn/tf/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import returnn.tf.compat as tf_compat
import returnn.tf.util.basic as tf_util
from returnn.tf.util.basic import Data, DimensionTag, reuse_name_scope, VariableAssigner
from returnn.util import basic as util


class DataNotFound(Exception):
Expand Down Expand Up @@ -3473,7 +3474,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
ignore_params=(), ignore_params_prefixes=(), var_name_mapping=None,
network=None):
"""
:param str filename: filepattern for NewCheckpointReader
:param str filename: filepattern for NewCheckpointReader or .index/.meta file path
:param list[tf.Variable|tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject] saveable_params:
:param str params_prefix: expect that all vars in saveable_params have this prefix, and remove it
:param str load_if_prefix: if given, only load variables with a name containing this string.
Expand All @@ -3486,7 +3487,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
renamed vars in the checkpoint
:param TFNetwork network:
"""
self.filename = filename
self.filepattern = util.get_checkpoint_filepattern(filename)
self.network = network
self.ignore_missing = ignore_missing
self.params_prefix = params_prefix
Expand All @@ -3510,7 +3511,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
continue
self.saveable_params.append(param)
assert count > 0, "%s: no saveable vars" % self
self.reader = tf_compat.v1.train.NewCheckpointReader(filename)
self.reader = tf_compat.v1.train.NewCheckpointReader(self.filepattern)
self.net_vars = [v for v in self.saveable_params if isinstance(v, tf.Variable)]
self.net_saveables = [v for v in self.saveable_params if not isinstance(v, tf.Variable)]
# All variables in the checkpoint:
Expand Down Expand Up @@ -3918,7 +3919,7 @@ def get_lazy_dict(self):
if self.ignore_missing and v_name not in var_name_map:
print(
"Warning, did not find match for var %r (%r, params_prefix %r, load_if_prefix %r) in checkpoint %r." % (
v, v_name, self.params_prefix, self.load_if_prefix, self.filename), file=log.v3)
v, v_name, self.params_prefix, self.load_if_prefix, self.filepattern), file=log.v3)
continue
variable_values[v] = self.VariableValue(value=var_name_map[v_name]())
assert variable_values, "no vars to load; saveable vars are %r. load_if_prefix %r." % (
Expand Down
15 changes: 15 additions & 0 deletions returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,21 @@ def get_model_filename_postfix():
return ""


def get_checkpoint_filepattern(filepath):
  """
  Strips a trailing TensorFlow checkpoint extension (``.index`` or ``.meta``),
  if present, so the result can be used as a checkpoint filepattern
  (e.g. for ``NewCheckpointReader``). Any other path is returned unchanged.

  :param str filepath: path to a checkpoint file, possibly with extension
  :return: CheckpointLoader compatible filepattern
  :rtype: str
  """
  for suffix in (".meta", ".index"):
    if filepath.endswith(suffix):
      return filepath[:-len(suffix)]
  return filepath


def sys_cmd_out_lines(s):
"""
:param str s: shell command
Expand Down