/
loader_impl.py
507 lines (418 loc) · 19.8 KB
/
loader_impl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loader implementation for SavedModel with hermetic, language-neutral exports.
"""
import os
import sys
from google.protobuf import message
from google.protobuf import text_format
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# API label for SavedModel metrics. `SavedModelLoader.load` passes it to
# `metrics.IncrementReadApi` so reads are attributed to this loader API.
_LOADER_LABEL = "loader"
def parse_saved_model_with_debug_info(export_dir):
  """Parses a SavedModel together with its optional GraphDebugInfo.

  Args:
    export_dir: Directory containing the SavedModel and, optionally, a
      GraphDebugInfo file inside the directory returned by
      `saved_model_utils.get_debug_dir(export_dir)`.

  Returns:
    A `(SavedModel, GraphDebugInfo)` tuple of protocol buffers. The
    GraphDebugInfo proto is left empty when no debug info file exists.

  Raises:
    IOError: If the SavedModel file is missing or cannot be parsed, or if the
      debug info file exists but cannot be parsed. A missing debug info file
      is not an error.
  """
  saved_model = parse_saved_model(export_dir)
  debug_dir = saved_model_utils.get_debug_dir(export_dir)
  debug_info_path = file_io.join(debug_dir, constants.DEBUG_INFO_FILENAME_PB)

  debug_info = graph_debug_info_pb2.GraphDebugInfo()
  if not file_io.file_exists(debug_info_path):
    # Debug info is optional; hand back the empty proto.
    return (saved_model, debug_info)

  with file_io.FileIO(debug_info_path, "rb") as debug_file:
    serialized_debug_info = debug_file.read()
  try:
    debug_info.ParseFromString(serialized_debug_info)
  except message.DecodeError as e:
    raise IOError(f"Cannot parse file {debug_info_path}: {e}.")
  return (saved_model, debug_info)
@tf_export("__internal__.saved_model.parse_saved_model", v1=[])
def parse_saved_model(export_dir):
  """Parses the `SavedModel` protocol buffer stored in `export_dir`.

  The binary file (`constants.SAVED_MODEL_FILENAME_PB`) takes precedence. The
  text-format file (`constants.SAVED_MODEL_FILENAME_PBTXT`) is only read when
  the binary file does not exist.

  Args:
    export_dir: String or Pathlike, path to the directory containing the
      SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If neither file exists, or if the file that was found cannot be
      parsed.
  """
  # PathLike -> str -> bytes, so both candidate paths are joined as bytes.
  base_dir = compat.as_bytes(compat.path_to_str(export_dir))
  path_to_pbtxt = file_io.join(
      base_dir, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  path_to_pb = file_io.join(
      base_dir, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  saved_model = saved_model_pb2.SavedModel()

  if file_io.file_exists(path_to_pb):
    with file_io.FileIO(path_to_pb, "rb") as pb_file:
      serialized = pb_file.read()
    try:
      saved_model.ParseFromString(serialized)
    except message.DecodeError as e:
      raise IOError(f"Cannot parse file {path_to_pb}: {str(e)}.")
    return saved_model

  if file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as pbtxt_file:
      text_bytes = pbtxt_file.read()
    try:
      text_format.Merge(text_bytes.decode("utf-8"), saved_model)
    except text_format.ParseError as e:
      raise IOError(f"Cannot parse file {path_to_pbtxt}: {str(e)}.")
    return saved_model

  raise IOError(
      f"SavedModel file does not exist at: {export_dir}{os.path.sep}"
      f"{{{constants.SAVED_MODEL_FILENAME_PBTXT}|"
      f"{constants.SAVED_MODEL_FILENAME_PB}}}")
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Asset definitions are taken from the MetaGraphDef's `asset_file_def` field
  when it is non-empty. Otherwise they come from the `constants.ASSETS_KEY`
  collection, where each entry is an `AssetFileDef` packed in an `Any` proto.

  Args:
    export_dir: Directory where the SavedModel is located. May be `str`,
      `bytes` or an `os.PathLike` object (e.g. `pathlib.Path`), matching what
      `parse_saved_model` accepts.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
      '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
    value in the map corresponds to the absolute path of the asset file (as
    `bytes`).
  """
  collection_def = meta_graph_def_to_load.collection_def

  asset_protos = []
  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    for asset_any_proto in collection_def[constants.ASSETS_KEY].any_list.value:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)

  # `compat.as_bytes` rejects os.PathLike objects, so convert them to str first
  # (str/bytes pass through `path_to_str` unchanged). Without this, a loader
  # built from a pathlib.Path fails here while `parse_saved_model` accepted it.
  assets_directory = file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  asset_tensor_dict = {}
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = "%s/%s" % (import_scope, tensor_name)
    asset_tensor_dict[tensor_name] = file_io.join(
        compat.as_bytes(assets_directory),
        compat.as_bytes(asset_proto.filename))
  return asset_tensor_dict
def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Returns the op registered under `init_op_key`, or `None` if absent.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: name of the collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY

  Returns:
    The main op tensor, if it exists and `None` otherwise.

  Raises:
    RuntimeError: If the collection def corresponding to the main op key has
      other than exactly one tensor.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collection_def = meta_graph_def_to_load.collection_def
  if init_op_key not in collection_def:
    return None

  init_op_list = collection_def[init_op_key].node_list.value
  if len(init_op_list) != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       f"Found {len(init_op_list)}: {init_op_list}.")
  # The MetaGraphDef is only used to validate the entry count. The op object
  # itself is read via `ops.get_collection`, so this presumably has to run
  # after the meta graph was imported into the current default graph.
  return ops.get_collection(init_op_key)[0]
def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the single op stored in collection `op_key`, or `None`.

  Thin alias for `_get_main_op_tensor`, which validates that the collection in
  `meta_graph_def` holds exactly one entry.
  """
  return _get_main_op_tensor(
      meta_graph_def_to_load=meta_graph_def, init_op_key=op_key)
def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
"""Retrieve op stored in the imported meta graph's signature def."""
if op_signature_key in meta_graph_def.signature_def:
return signature_def_utils.load_op_from_signature_def(
meta_graph_def.signature_def[op_signature_key], op_signature_key,
import_scope)
else:
return None
def get_init_op(meta_graph_def, import_scope=None):
  """Returns the init op of `meta_graph_def`, or `None` if it defines none.

  Sources are tried in priority order, and the first one that yields an op
  wins:
    1. the signature stored under `constants.INIT_OP_SIGNATURE_KEY`;
    2. the `constants.MAIN_OP_KEY` collection;
    3. the deprecated `constants.LEGACY_INIT_OP_KEY` collection.

  Args:
    meta_graph_def: The MetaGraphDef being loaded.
    import_scope: Optional scope forwarded to the signature-def lookup.

  Returns:
    The init op, or `None`.
  """
  # Explicit `is None` checks (rather than `a or b or c`) avoid calling bool()
  # on the returned object: a graph Tensor raises TypeError when truth-tested.
  # This also matches `get_train_op`.
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op
def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op of `meta_graph_def`, or `None` if it defines none.

  The signature stored under `constants.TRAIN_OP_SIGNATURE_KEY` is preferred.
  Otherwise the `constants.TRAIN_OP_KEY` collection is consulted.

  Args:
    meta_graph_def: The MetaGraphDef being loaded.
    import_scope: Optional scope forwarded to the signature-def lookup.
  """
  signature_train_op = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if signature_train_op is not None:
    return signature_train_op
  return _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only file existence is checked; nothing is parsed or loaded. A `False` result
  means the directory definitely holds no SavedModel. A `True` result means a
  SavedModel file is present, but gives no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
      '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # Text format is probed first, then binary; stop at the first hit.
  candidate_file_names = (constants.SAVED_MODEL_FILENAME_PBTXT,
                          constants.SAVED_MODEL_FILENAME_PB)
  return any(
      file_io.file_exists(file_io.join(export_dir, file_name))
      for file_name in candidate_file_names)
@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only file existence is checked; nothing is parsed or loaded. A `False` result
  means the directory definitely holds no SavedModel. A `True` result means a
  SavedModel file is present, but gives no guarantee that it can be loaded.

  Args:
    export_dir: Absolute path to possible export location, as a string or an
      `os.PathLike` object. For example, '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # Unwrap PathLike objects (e.g. pathlib.Path) before delegating.
  normalized_dir = (
      os.fspath(export_dir) if isinstance(export_dir, os.PathLike)
      else export_dir)
  return maybe_saved_model_directory(normalized_dir)
@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(None, "Use `tf.saved_model.load` instead.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the MetaGraphDef matching `tags` from a SavedModel into `sess`.

  This is a convenience wrapper. It builds a `SavedModelLoader` for
  `export_dir` and calls its `load` method, which imports the graph, restores
  variables and runs the init op in `sess`.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
      correspond to the tags used when saving the variables using the
      SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
      to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
      followed by '/' to all loaded tensor names. This scope is applied to
      tensor instances loaded into the passed session, but it is *not* written
      through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.

  @compatibility(TF2)

  `tf.compat.v1.saved_model.load` or `tf.compat.v1.saved_model.loader.load` is
  not compatible with eager execution. Please use `tf.saved_model.load` instead
  to load your model. You can refer to the [SavedModel guide]
  (https://www.tensorflow.org/guide/saved_model) for more information as well as
  "Importing SavedModels from TensorFlow 1.x" in the [`tf.saved_model.load`]
  (https://www.tensorflow.org/api_docs/python/tf/saved_model/load) docstring.

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `sess`                | Not supported   | -                          |
  | `tags`                | `tags`          | -                          |
  | `export_dir`          | `export_dir`    | -                          |
  | `import_scope`        | Not supported   | Name scopes are not needed.
  :                       :                 : By default, variables are  :
  :                       :                 : associated with the loaded :
  :                       :                 : object and function names  :
  :                       :                 : are deduped.               :
  | `saver_kwargs`        | Not supported   | -                          |

  #### Before & After Usage Example

  Before:

  ```
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["foo-tag"], export_dir)
  ```

  After:

  ```
  model = tf.saved_model.load(export_dir, tags=["foo-tag"])
  ```
  @end_compatibility
  """
  return SavedModelLoader(export_dir).load(
      sess, tags, import_scope, **saver_kwargs)
class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`.

  The SavedModel proto is parsed once, in the constructor, and that parsed
  proto is reused by every method afterwards.
  """

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.

    Raises:
      IOError: Propagated from `parse_saved_model` if the SavedModel file is
        missing or cannot be parsed.
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Matching is by set equality, so tag order and duplicates are ignored.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef.

    Returns:
      MetaGraphDef with the same tags.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    # Build the requested tag set once. Rebuilding it per meta graph would
    # silently exhaust a one-shot iterable after the first comparison.
    requested_tags = set(tags)
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
      available_tags.append(meta_graph_tags)
      if meta_graph_tags == requested_tags:
        return meta_graph_def
    raise RuntimeError(
        f"MetaGraphDef associated with tags {str(tags).strip('[]')} "
        "could not be found in SavedModel, with available tags "
        f"'{available_tags}'. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`.")

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).

    Raises:
      RuntimeError: if no MetaGraphDef matches `tags`.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    if sys.byteorder == "big":
      # Function tensor contents are converted from little- to big-endian
      # before import. Swap a copy instead of the cached proto in
      # `self._saved_model`. The swap is presumably a plain byte reversal, so
      # an in-place swap would be undone by a second call (e.g. loading the
      # same model into another session). NOTE(review): confirm against
      # `saved_model_utils.swap_function_tensor_content`.
      swapped_meta_graph_def = meta_graph_pb2.MetaGraphDef()
      swapped_meta_graph_def.CopyFrom(meta_graph_def)
      saved_model_utils.swap_function_tensor_content(swapped_meta_graph_def,
                                                     "little", "big")
      meta_graph_def = swapped_meta_graph_def
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in graph. This may be the saver returned by the load_graph()
        function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      ValueError: if `saver` is not a `tf_saver.Saver` instance while the graph
        has saveable objects under `import_scope` (this includes
        `saver=None`).
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "`SavedModelLoader.restore_variables`. Since there are variables in"
            " the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Asset file paths are fed to the init op through `feed_dict`, keyed by the
    asset tensor names.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.
    """
    metrics.IncrementReadApi(_LOADER_LABEL)
    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)

    # Use the proto already parsed in __init__ instead of re-reading and
    # re-parsing the SavedModel file from disk just for this metric.
    meta_graphs = self._saved_model.meta_graphs
    if len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"):
      metrics.IncrementRead(write_version="2")
    else:
      metrics.IncrementRead(write_version="1")
    return meta_graph_def