/
method_name_updater.py
143 lines (123 loc) · 5.86 KB
/
method_name_updater.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef method name utility functions.
Utility functions for manipulating signature_def.method_names.
"""
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader_impl as loader
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
# TODO(jdchung): Consider integrating this into the saved_model_cli so that
# users could do this from the command line directly.
@tf_export(v1=["saved_model.signature_def_utils.MethodNameUpdater"])
class MethodNameUpdater(object):
  """Updates the method name(s) of the SavedModel stored in the given path.

  The `MethodNameUpdater` class provides the functionality to update the method
  name field in the signature_defs of the given SavedModel. For example, it
  can be used to replace the `predict` `method_name` to `regress`.

  Typical usages of the `MethodNameUpdater`
  ```python
  ...
  updater = tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater(
      export_dir)
  # Update the signature_def with key "foo" in every meta graph def that
  # defines it.
  updater.replace_method_name(signature_key="foo", method_name="regress")
  # Update a single signature_def with key "bar" in the meta graph def with
  # tags ["serve"]
  updater.replace_method_name(signature_key="bar", method_name="classify",
                              tags="serve")
  updater.save(new_export_dir)
  ```

  Note: This class is only available through the v1 compatibility library as
  `tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater`.
  """

  def __init__(self, export_dir):
    """Creates a MethodNameUpdater object.

    Args:
      export_dir: Directory containing the SavedModel files.

    Raises:
      IOError: If the saved model file does not exist, or cannot be successfully
        parsed.
    """
    self._export_dir = export_dir
    self._saved_model = loader.parse_saved_model(export_dir)

  def replace_method_name(self, signature_key, method_name, tags=None):
    """Replaces the method_name in the specified signature_def.

    If `tags` is None, the signature_def with key `signature_key` is updated in
    every `MetaGraph` that defines it; `MetaGraph`s without that key are left
    untouched. If `tags` is not None, only the signature_def(s) in the
    `MetaGraph`(s) whose tag set equals `tags` are updated, and each of those
    must define `signature_key`.

    All validation happens before any signature_def is modified, so a call that
    raises leaves the in-memory SavedModel unchanged.

    Args:
      signature_key: Key of the signature_def to be updated.
      method_name: new method_name to replace the existing one.
      tags: A single tag, or a sequence of tags (list, tuple, set or
        frozenset), identifying the `MetaGraph` to update. If None, all meta
        graphs are considered.

    Raises:
      ValueError: if signature_key or method_name are not defined or
        if no metagraphs were found with the associated tags or
        if no meta graph has a signature_def that matches signature_key.
    """
    if not signature_key:
      raise ValueError("`signature_key` must be defined.")
    if not method_name:
      raise ValueError("`method_name` must be defined.")

    if tags is not None:
      # Accept any common collection as a tag sequence (not only lists); any
      # other value, e.g. the string "serve", is a single tag.
      if isinstance(tags, (list, tuple, set, frozenset)):
        tags = list(tags)
      else:
        tags = [tags]

    if tags is None:
      candidates = list(self._saved_model.meta_graphs)
    else:
      wanted = set(tags)  # Hoisted: compared against every MetaGraph below.
      candidates = [
          meta_graph_def for meta_graph_def in self._saved_model.meta_graphs
          if set(meta_graph_def.meta_info_def.tags) == wanted
      ]

    if not candidates:
      raise ValueError(
          f"MetaGraphDef associated with tags {tags} could not be found in "
          "SavedModel. This means either you specified invalid tags or your "
          "SavedModel does not have a MetaGraphDef with the specified tags.")

    if tags is None:
      # Without tags, only MetaGraphs that actually define the key are updated;
      # it is an error only when none of them do.
      targets = [
          meta_graph_def for meta_graph_def in candidates
          if signature_key in meta_graph_def.signature_def
      ]
      key_missing = not targets
    else:
      # With explicit tags, every selected MetaGraph must define the key.
      targets = candidates
      key_missing = any(signature_key not in meta_graph_def.signature_def
                        for meta_graph_def in candidates)

    if key_missing:
      raise ValueError(
          f"MetaGraphDef associated with tags {tags} "
          f"does not have a signature_def with key: '{signature_key}'. "
          "This means either you specified the wrong signature key or "
          "forgot to put the signature_def with the corresponding key in "
          "your SavedModel.")

    # Mutate only after every check above has passed.
    for meta_graph_def in targets:
      meta_graph_def.signature_def[signature_key].method_name = method_name

  @staticmethod
  def _proto_file_path(directory, filename):
    """Joins `directory` and `filename` into a bytes path for `file_io`."""
    return file_io.join(
        compat.as_bytes(compat.path_to_str(directory)),
        compat.as_bytes(filename))

  def save(self, new_export_dir=None):
    """Saves the updated `SavedModel`.

    Only the SavedModel proto file is written, in the same format (binary or
    text) as the input; nothing else from the input export directory is copied
    to `new_export_dir`.

    Args:
      new_export_dir: Path where the updated `SavedModel` will be saved. If
        None, the input `SavedModel` will be overridden with the updates.

    Raises:
      errors.OpError: If there are errors during the file save operation.
    """
    # NOTE(review): `loader.parse_saved_model` is expected to read the binary
    # proto in preference to the text proto. Treat the input as text only when
    # no binary file exists, so the file written below is the one read back.
    is_input_text_proto = (
        file_io.file_exists(
            self._proto_file_path(self._export_dir,
                                  constants.SAVED_MODEL_FILENAME_PBTXT)) and
        not file_io.file_exists(
            self._proto_file_path(self._export_dir,
                                  constants.SAVED_MODEL_FILENAME_PB)))

    if not new_export_dir:
      new_export_dir = self._export_dir
    # Normalize os.PathLike inputs; compat.as_bytes alone rejects them.
    new_export_dir = compat.path_to_str(new_export_dir)
    # Allow saving into a directory that does not exist yet (no-op otherwise).
    file_io.recursive_create_dir(new_export_dir)

    if is_input_text_proto:
      filename = constants.SAVED_MODEL_FILENAME_PBTXT
      contents = str(self._saved_model)
    else:
      filename = constants.SAVED_MODEL_FILENAME_PB
      contents = self._saved_model.SerializeToString(deterministic=True)

    path = self._proto_file_path(new_export_dir, filename)
    file_io.write_string_to_file(path, contents)
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))