Skip to content

Commit cff8e17

Browse files
k-w-wserach24
authored andcommitted
Save Keras metadata in a separate proto and raise deprecation warnings when loading a SavedModel with tf.saved_model.save().
PiperOrigin-RevId: 339760831 Change-Id: I8980807eb4f2f0f1a8c4420b7e4c386842f5ebf9
1 parent 13563e8 commit cff8e17

File tree

6 files changed

+117
-17
lines changed

6 files changed

+117
-17
lines changed

tensorflow/python/keras/saving/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ py_library(
4949
deps = [
5050
"//tensorflow/python:lib",
5151
"//tensorflow/python:math_ops",
52+
"//tensorflow/python:platform",
5253
"//tensorflow/python:saver",
5354
"//tensorflow/python:tensor_spec",
5455
"//tensorflow/python/eager:def_function",

tensorflow/python/keras/saving/saved_model/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@
2626
# Keys for the serialization cache.
2727
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
2828
KERAS_CACHE_KEY = 'keras_serialized_attributes'
29+
30+
31+
# Name of Keras metadata file stored in the SavedModel.
32+
SAVED_METADATA_PATH = 'keras_metadata.pb'

tensorflow/python/keras/saving/saved_model/load.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import os
2021
import re
2122
import types
2223

24+
from google.protobuf import message
25+
2326
from tensorflow.core.framework import versions_pb2
2427
from tensorflow.python.eager import context
2528
from tensorflow.python.eager import function as defun
@@ -38,6 +41,7 @@
3841
from tensorflow.python.keras.utils import generic_utils
3942
from tensorflow.python.keras.utils import metrics_utils
4043
from tensorflow.python.keras.utils.generic_utils import LazyLoader
44+
from tensorflow.python.platform import gfile
4145
from tensorflow.python.platform import tf_logging as logging
4246
from tensorflow.python.saved_model import load as tf_load
4347
from tensorflow.python.saved_model import loader_impl
@@ -121,13 +125,26 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
121125
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
122126
# TODO(kathywu): Add code to load from objects that contain all endpoints
123127

124-
# The Keras metadata file is not yet saved, so create it from the SavedModel.
128+
# Look for metadata file or parse the SavedModel
125129
metadata = saved_metadata_pb2.SavedMetadata()
126130
meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
127131
object_graph_def = meta_graph_def.object_graph_def
128-
# TODO(kathywu): When the keras metadata file is saved, load it directly
129-
# instead of calling the _read_legacy_metadata function.
130-
_read_legacy_metadata(object_graph_def, metadata)
132+
path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
133+
if gfile.Exists(path_to_metadata_pb):
134+
try:
135+
with gfile.GFile(path_to_metadata_pb, 'rb') as f:
136+
file_content = f.read()
137+
metadata.ParseFromString(file_content)
138+
except message.DecodeError as e:
139+
raise IOError('Cannot parse keras metadata {}: {}.'
140+
.format(path_to_metadata_pb, str(e)))
141+
else:
142+
logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
143+
'Keras model. Please ensure that you are saving the model '
144+
'with model.save() or tf.keras.models.save_model(), *NOT* '
145+
'tf.saved_model.save(). To confirm, there should be a file '
146+
'named "keras_metadata.pb" in the SavedModel directory.')
147+
_read_legacy_metadata(object_graph_def, metadata)
131148

132149
if not metadata.nodes:
133150
# When there are no Keras objects, return the results from the core loader

tensorflow/python/keras/saving/saved_model/save.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@
1818
from __future__ import print_function
1919

2020
import os
21+
22+
from tensorflow.core.framework import versions_pb2
2123
from tensorflow.python.distribute import distribution_strategy_context
2224
from tensorflow.python.keras import backend as K
25+
from tensorflow.python.keras.protobuf import saved_metadata_pb2
2326
from tensorflow.python.keras.saving import saving_utils
27+
from tensorflow.python.keras.saving.saved_model import constants
2428
from tensorflow.python.keras.saving.saved_model import save_impl
2529
from tensorflow.python.keras.saving.saved_model import utils
2630
from tensorflow.python.keras.utils.generic_utils import LazyLoader
2731
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
32+
from tensorflow.python.platform import gfile
2833
from tensorflow.python.saved_model import save as save_lib
2934

35+
3036
# To avoid circular dependencies between keras/engine and keras/saving,
3137
# code in keras/saving must delay imports.
3238

@@ -86,7 +92,39 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
8692
# we use the default replica context here.
8793
with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
8894
with utils.keras_option_scope(save_traces):
89-
save_lib.save(model, filepath, signatures, options)
95+
saved_nodes, node_paths = save_lib.save_and_return_nodes(
96+
model, filepath, signatures, options)
97+
98+
# Save all metadata to a separate file in the SavedModel directory.
99+
metadata = generate_keras_metadata(saved_nodes, node_paths)
100+
101+
with gfile.GFile(
102+
os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
103+
w.write(metadata.SerializeToString(deterministic=True))
90104

91105
if not include_optimizer:
92106
model.optimizer = orig_optimizer
107+
108+
109+
def generate_keras_metadata(saved_nodes, node_paths):
110+
"""Constructs a KerasMetadata proto with the metadata of each keras object."""
111+
metadata = saved_metadata_pb2.SavedMetadata()
112+
113+
for node_id, node in enumerate(saved_nodes):
114+
if isinstance(node, base_layer.Layer):
115+
path = node_paths[node]
116+
if not path:
117+
node_path = "root"
118+
else:
119+
node_path = "root.{}".format(
120+
".".join([ref.name for ref in path]))
121+
122+
metadata.nodes.add(
123+
node_id=node_id,
124+
node_path=node_path,
125+
version=versions_pb2.VersionDef(
126+
producer=1, min_consumer=1, bad_consumers=[]),
127+
identifier=node._object_identifier, # pylint: disable=protected-access
128+
metadata=node._tracking_metadata) # pylint: disable=protected-access
129+
130+
return metadata

tensorflow/python/saved_model/save.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,9 @@ def __init__(self, checkpoint_view, options, wrapped_functions=None):
183183
"""
184184
self.options = options
185185
self.checkpoint_view = checkpoint_view
186-
trackable_objects, node_ids, slot_variables = (
187-
self.checkpoint_view.objects_ids_and_slot_variables())
186+
trackable_objects, path_to_root, node_ids, slot_variables = (
187+
self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
188+
self.node_paths = path_to_root
188189
self.nodes = trackable_objects
189190
self.node_ids = node_ids
190191
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
@@ -1029,15 +1030,40 @@ def serve():
10291030
May not be called from within a function body.
10301031
@end_compatibility
10311032
"""
1033+
save_and_return_nodes(obj, export_dir, signatures, options,
1034+
raise_metadata_warning=True)
1035+
1036+
1037+
def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
1038+
raise_metadata_warning=False):
1039+
"""Saves a SavedModel while returning all saved nodes and their paths.
1040+
1041+
Please see `tf.saved_model.save` for details.
1042+
1043+
Args:
1044+
obj: A trackable object to export.
1045+
export_dir: A directory in which to write the SavedModel.
1046+
signatures: A function or dictionary of functions to save in the SavedModel
1047+
as signatures.
1048+
options: `tf.saved_model.SaveOptions` object for configuring save options.
1049+
raise_metadata_warning: Whether to raise the metadata warning. This arg will
1050+
be removed in TF 2.5.
1051+
1052+
Returns:
1053+
A tuple of (a list of saved nodes in the order they are serialized to the
1054+
`SavedObjectGraph`, dictionary mapping nodes to one possible path from
1055+
the root node to the key node)
1056+
"""
10321057
options = options or save_options.SaveOptions()
10331058
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
10341059
# compatible (no sessions) and share it with this export API rather than
10351060
# making a SavedModel proto and writing it directly.
10361061
saved_model = saved_model_pb2.SavedModel()
10371062
meta_graph_def = saved_model.meta_graphs.add()
10381063

1039-
_, exported_graph, object_saver, asset_info = _build_meta_graph(
1040-
obj, signatures, options, meta_graph_def)
1064+
_, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
1065+
_build_meta_graph(obj, signatures, options, meta_graph_def,
1066+
raise_metadata_warning))
10411067
saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
10421068

10431069
# Write the checkpoint, copy assets into the assets directory, and write out
@@ -1077,6 +1103,8 @@ def serve():
10771103
# constants in the saved graph.
10781104
ops.dismantle_graph(exported_graph)
10791105

1106+
return saved_nodes, node_paths
1107+
10801108

10811109
def export_meta_graph(obj, filename, signatures=None, options=None):
10821110
"""Exports the MetaGraph proto of the `obj` to a file.
@@ -1103,7 +1131,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
11031131
"""
11041132
options = options or save_options.SaveOptions()
11051133
export_dir = os.path.dirname(filename)
1106-
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
1134+
meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
11071135
obj, signatures, options)
11081136

11091137
file_io.atomic_write_string_to_file(
@@ -1122,7 +1150,8 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
11221150
def _build_meta_graph_impl(obj,
11231151
signatures,
11241152
options,
1125-
meta_graph_def=None):
1153+
meta_graph_def=None,
1154+
raise_metadata_warning=True):
11261155
"""Creates a MetaGraph containing the resources and functions of an object."""
11271156
if ops.inside_function():
11281157
raise AssertionError(
@@ -1170,7 +1199,7 @@ def _build_meta_graph_impl(obj,
11701199
saveable_view, asset_info.asset_index)
11711200
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
11721201

1173-
if saved_object_metadata:
1202+
if saved_object_metadata and raise_metadata_warning:
11741203
tf_logging.warn(
11751204
'FOR KERAS USERS: The object that you are saving contains one or more '
11761205
'Keras models or layers. If you are loading the SavedModel with '
@@ -1186,13 +1215,15 @@ def _build_meta_graph_impl(obj,
11861215
'metadta field will be deprecated soon, so please move the metadata to '
11871216
'a different file.')
11881217

1189-
return (meta_graph_def, exported_graph, object_saver, asset_info)
1218+
return (meta_graph_def, exported_graph, object_saver, asset_info,
1219+
saveable_view.nodes, saveable_view.node_paths)
11901220

11911221

11921222
def _build_meta_graph(obj,
11931223
signatures,
11941224
options,
1195-
meta_graph_def=None):
1225+
meta_graph_def=None,
1226+
raise_metadata_warning=True):
11961227
"""Creates a MetaGraph under a save context.
11971228
11981229
Args:
@@ -1205,6 +1236,8 @@ def _build_meta_graph(obj,
12051236
options: `tf.saved_model.SaveOptions` object that specifies options for
12061237
saving.
12071238
meta_graph_def: Optional, the MetaGraphDef proto fill.
1239+
raise_metadata_warning: Whether to raise a warning when user objects contain
1240+
non-empty metadata.
12081241
12091242
Raises:
12101243
AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
@@ -1218,4 +1251,5 @@ def _build_meta_graph(obj,
12181251
"""
12191252

12201253
with save_context.save_context(options):
1221-
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
1254+
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
1255+
raise_metadata_warning)

tensorflow/python/training/tracking/graph_view.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def frozen_saveable_objects(self, object_map=None, to_graph=None,
430430
name=base.OBJECT_GRAPH_PROTO_KEY))
431431
return named_saveable_objects
432432

433-
def objects_ids_and_slot_variables(self):
433+
def objects_ids_and_slot_variables_and_paths(self):
434434
"""Traverse the object graph and list all accessible objects.
435435
436436
Looks for `Trackable` objects which are dependencies of
@@ -439,7 +439,8 @@ def objects_ids_and_slot_variables(self):
439439
(i.e. if they would be saved with a checkpoint).
440440
441441
Returns:
442-
A tuple of (trackable objects, object -> node id, slot variables)
442+
A tuple of (trackable objects, paths from root for each object,
443+
object -> node id, slot variables)
443444
"""
444445
trackable_objects, path_to_root = self._breadth_first_traversal()
445446
object_names = object_identity.ObjectIdentityDictionary()
@@ -452,6 +453,11 @@ def objects_ids_and_slot_variables(self):
452453
trackable_objects=trackable_objects,
453454
node_ids=node_ids,
454455
object_names=object_names)
456+
return trackable_objects, path_to_root, node_ids, slot_variables
457+
458+
def objects_ids_and_slot_variables(self):
459+
trackable_objects, _, node_ids, slot_variables = (
460+
self.objects_ids_and_slot_variables_and_paths())
455461
return trackable_objects, node_ids, slot_variables
456462

457463
def list_objects(self):

0 commit comments

Comments
 (0)