Skip to content

Commit 98bc853

Browse files
committed
consistent checkpoint filepattern
1 parent 3f1554d commit 98bc853

File tree

3 files changed

+23
-15
lines changed

3 files changed

+23
-15
lines changed

returnn/engine/base.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from returnn.util.basic import BackendEngine, model_epoch_from_filename, get_model_filename_postfix
1212
from returnn.log import log
1313
from returnn.pretrain import Pretrain
14+
from returnn.tf.util.basic import get_checkpoint_filepattern
1415

1516

1617
class EngineBase(object):
@@ -72,19 +73,11 @@ def get_epoch_model(cls, config):
7273
start_epoch = int(start_epoch_mode)
7374
assert start_epoch >= 1
7475

75-
load_model_epoch_filename = config.value('load', '')
76-
if load_model_epoch_filename.endswith(".meta"):
77-
load_model_epoch_filename = load_model_epoch_filename[:-len(".meta")]
78-
elif load_model_epoch_filename.endswith(".index"):
79-
load_model_epoch_filename = load_model_epoch_filename[:-len(".index")]
76+
load_model_epoch_filename = get_checkpoint_filepattern(config.value('load', ''))
8077
if load_model_epoch_filename:
8178
assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())
8279

83-
import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
84-
if import_model_train_epoch1.endswith(".meta"):
85-
import_model_train_epoch1 = import_model_train_epoch1[:-len(".meta")]
86-
elif import_model_train_epoch1.endswith(".index"):
87-
import_model_train_epoch1 = import_model_train_epoch1[:-len(".index")]
80+
import_model_train_epoch1 = get_checkpoint_filepattern(config.value('import_model_train_epoch1', ''))
8881
if import_model_train_epoch1:
8982
assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())
9083

returnn/tf/network.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from returnn.tf.layers.basic import LayerBase, get_layer_class
1616
import returnn.tf.compat as tf_compat
1717
import returnn.tf.util.basic as tf_util
18-
from returnn.tf.util.basic import Data, DimensionTag, reuse_name_scope, VariableAssigner
18+
from returnn.tf.util.basic import Data, DimensionTag, reuse_name_scope, VariableAssigner, get_checkpoint_filepattern
1919

2020

2121
class DataNotFound(Exception):
@@ -3473,7 +3473,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34733473
ignore_params=(), ignore_params_prefixes=(), var_name_mapping=None,
34743474
network=None):
34753475
"""
3476-
:param str filename: filepattern for NewCheckpointReader
3476+
:param str filename: filepattern for NewCheckpointReader or .index/.meta file path
34773477
:param list[tf.Variable|tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject] saveable_params:
34783478
:param str params_prefix: expect that all vars in saveable_params have this prefix, and remove it
34793479
:param str load_if_prefix: if given, only load variables with a name containing this string.
@@ -3486,7 +3486,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
34863486
renamed vars in the checkpoint
34873487
:param TFNetwork network:
34883488
"""
3489-
self.filename = filename
3489+
self.filepattern = get_checkpoint_filepattern(filename)
34903490
self.network = network
34913491
self.ignore_missing = ignore_missing
34923492
self.params_prefix = params_prefix
@@ -3510,7 +3510,7 @@ def __init__(self, filename, saveable_params, params_prefix="", load_if_prefix="
35103510
continue
35113511
self.saveable_params.append(param)
35123512
assert count > 0, "%s: no saveable vars" % self
3513-
self.reader = tf_compat.v1.train.NewCheckpointReader(filename)
3513+
self.reader = tf_compat.v1.train.NewCheckpointReader(self.filepattern)
35143514
self.net_vars = [v for v in self.saveable_params if isinstance(v, tf.Variable)]
35153515
self.net_saveables = [v for v in self.saveable_params if not isinstance(v, tf.Variable)]
35163516
# All variables in the checkpoint:
@@ -3918,7 +3918,7 @@ def get_lazy_dict(self):
39183918
if self.ignore_missing and v_name not in var_name_map:
39193919
print(
39203920
"Warning, did not find match for var %r (%r, params_prefix %r, load_if_prefix %r) in checkpoint %r." % (
3921-
v, v_name, self.params_prefix, self.load_if_prefix, self.filename), file=log.v3)
3921+
v, v_name, self.params_prefix, self.load_if_prefix, self.filepattern), file=log.v3)
39223922
continue
39233923
variable_values[v] = self.VariableValue(value=var_name_map[v_name]())
39243924
assert variable_values, "no vars to load; saveable vars are %r. load_if_prefix %r." % (

returnn/tf/util/basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,21 @@ def have_min_tf_version(version):
5959
return tf_version >= version
6060

6161

62+
def get_checkpoint_filepattern(filepath):
  """
  Strips a trailing ``.meta`` or ``.index`` extension from a checkpoint path,
  if one is present; other paths are returned unchanged.

  :param str filepath:
  :return: CheckpointLoader compatible filepattern
  :rtype: str
  """
  for ext in (".meta", ".index"):
    if filepath.endswith(ext):
      return filepath[:-len(ext)]
  return filepath
75+
76+
6277
class CustomUpdate(object):
6378
"""
6479
Custom updates will be handled by :class:`TFUpdater`.

0 commit comments

Comments
 (0)