/
function_deserialization.py
416 lines (345 loc) · 16.2 KB
/
function_deserialization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.core.framework import function_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
def _is_tensor(t):
  """Returns True if `t` is an `ops.Tensor` or a `BaseResourceVariable`."""
  tensor_like_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
  return isinstance(t, tensor_like_types)
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  Unlike `function.__call__`, both inputs and outputs are structured, and
  inputs are converted to tensors where the signature expects one.

  Note: this does not check that non-tensor inputs match. That should be
  done beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None when the flat call returned an
    `Operation`.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs)
  flat_specs = nest.flatten(signature)
  # Only positions described by a TensorSpec are forwarded to the flat call;
  # each of them is converted using the dtype of its spec as a hint.
  tensor_inputs = [
      ops.convert_to_tensor(value, dtype_hint=spec.dtype)
      for value, spec in zip(flat_args, flat_specs)
      if isinstance(spec, tensor_spec.TensorSpec)
  ]
  # pylint: disable=protected-access
  outputs = function._call_flat(tensor_inputs, function._captured_inputs)
  # pylint: enable=protected-access
  return None if isinstance(outputs, ops.Operation) else outputs
def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns the TensorSpec of `arg` converted to a tensor, or None.

  Args:
    arg: value to attempt to convert with `ops.convert_to_tensor`.
    dtype_hint: dtype hint forwarded to the conversion.

  Returns:
    A `TensorSpec` carrying the shape and dtype of the converted value, or
    None if the attempt raised `TypeError` or `ValueError`.
  """
  try:
    # Note: the conversion is attempted inside its own FuncGraph to avoid
    # polluting the current context.
    scratch_graph = func_graph_lib.FuncGraph(name="guess_conversion")
    with scratch_graph.as_default():
      converted = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(
          shape=converted.shape, dtype=converted.dtype)
  except (TypeError, ValueError):
    return None
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Args:
    function: ConcreteFunction whose `graph.structured_input_signature` is
      the reference the inputs are checked against.
    inputs: structured inputs to test.
    allow_conversion: if True, every argument at a TensorSpec position is
      first replaced by the result of `_try_convert_to_tensor_spec`.

  Returns:
    True if `inputs` has the structure of the signature, every TensorSpec
    position holds a tensor-like value with the expected dtype and a
    compatible shape, and every other position compares equal to the value
    recorded in the signature. False otherwise.
  """
  signature = function.graph.structured_input_signature
  try:
    flat_args = nest.flatten_up_to(signature, inputs)
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(signature, flat_args)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for candidate, expected in zip(flat_args, nest.flatten(signature)):
    if not isinstance(expected, tensor_spec.TensorSpec):
      # Non-tensor positions must match the recorded Python value exactly.
      if candidate != expected:
        return False
      continue
    if allow_conversion:
      candidate = _try_convert_to_tensor_spec(
          candidate, dtype_hint=expected.dtype)
    if not (_is_tensor(candidate) or
            isinstance(candidate, tensor_spec.TensorSpec)):
      return False
    if candidate.dtype != expected.dtype:
      return False
    if not expected.shape.is_compatible_with(candidate.shape):
      return False
  return True
def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
  """Deserializes a FunctionSpec proto into a non-method FunctionSpec.

  Args:
    function_spec_proto: proto providing the serialized `fullargspec`, the
      `is_method` flag and the serialized `input_signature`.
    coder: object whose `decode_proto` decodes the nested structures.

  Returns:
    A `function_lib.FunctionSpec` built with `is_method=False`. When the proto
    is flagged as a method, its first positional argument name is dropped.

  Raises:
    NotImplementedError: if the proto is flagged as a method but has no
      positional argument names.
  """
  decoded_argspec = coder.decode_proto(function_spec_proto.fullargspec)

  # Convert a method function into a non method.
  positional_names = decoded_argspec.args
  if function_spec_proto.is_method:
    if not positional_names:
      raise NotImplementedError(
          "Missing support to deserialize a method function without a named "
          "'self' argument.")
    positional_names = positional_names[1:]

  nonmethod_argspec = tf_inspect.FullArgSpec(
      args=positional_names,
      varargs=decoded_argspec.varargs,
      varkw=decoded_argspec.varkw,
      defaults=decoded_argspec.defaults,
      kwonlyargs=decoded_argspec.kwonlyargs,
      kwonlydefaults=decoded_argspec.kwonlydefaults,
      annotations=decoded_argspec.annotations)
  decoded_signature = coder.decode_proto(function_spec_proto.input_signature)
  return function_lib.FunctionSpec(
      fullargspec=nonmethod_argspec,
      is_method=False,
      args_to_prepend=[],
      kwargs_to_include={},
      input_signature=decoded_signature)
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable.

  Args:
    saved_bare_concrete_function: saved object providing
      `concrete_function_name`, `argument_keywords` and
      `allowed_positional_arguments`.
    concrete_functions: map from function name to concrete function.

  Returns:
    The concrete function registered under the saved name, after its calling
    convention has been set and `add_to_graph()` has been called on it.
  """
  # Bare concrete functions accept only flat lists of Tensors with unique
  # names.
  saved = saved_bare_concrete_function
  restored = concrete_functions[saved.concrete_function_name]
  # pylint: disable=protected-access
  restored._arg_keywords = saved.argument_keywords
  restored._num_positional_args = saved.allowed_positional_arguments
  # pylint: enable=protected-access
  restored.add_to_graph()
  return restored
class RestoredFunction(def_function.Function):
  """A `def_function.Function` that has been restored from saved state.

  Keeps the restored concrete functions and the deserialized function spec
  alongside the regular `Function` state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    """Initializes the restored function.

    Args:
      python_function: Python callable forwarded to the base constructor.
      name: name forwarded to the base constructor.
      function_spec: function spec stored on this object and copied onto
        every function returned by `_defun_with_scope`.
      concrete_functions: concrete functions backing this function; they are
        the ones listed for serialization.
    """
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function, name, autograph=False)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

  def _list_all_concrete_functions_for_serialization(self):
    """Returns the concrete functions this object was constructed with."""
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    """Defers to the base class, then attaches the stored function spec."""
    scoped_function = super(RestoredFunction, self)._defun_with_scope(scope)
    # pylint: disable=protected-access
    scoped_function._function_spec = self._function_spec
    # pylint: enable=protected-access
    return scoped_function
def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  The returned function dispatches each call to the first concrete function
  listed in `saved_function` that accepts the given arguments.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.
  coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally since restored functions do
  # not behave as methods i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec, coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function."""
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in (False, True):
      for candidate_name in saved_function.concrete_functions:
        candidate = concrete_functions[candidate_name]
        if _concrete_function_callable_with(candidate, inputs,
                                            allow_conversion):
          return _call_concrete_function(candidate, inputs)

    # No candidate matched: describe every available signature in the error.
    def _pretty_format_positional(positional):
      total = len(positional)
      bullets = "\n * ".join([str(a) for a in positional])
      return "Positional arguments ({} total):\n * {}".format(total, bullets)

    signature_descriptions = []
    for option_number, candidate_name in enumerate(
        saved_function.concrete_functions, 1):
      positional, keyword = (
          concrete_functions[candidate_name].structured_input_signature)
      signature_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}"
          .format(option_number, _pretty_format_positional(positional),
                  keyword))
    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n {}\n Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(signature_descriptions)))

  concrete_function_objects = [
      concrete_functions[concrete_function_name]
      for concrete_function_name in saved_function.concrete_functions
  ]

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  loaded = {}
  load_shared_name_suffix = "_load_{}".format(ops.uid())
  # Functions are visited dependencies-first, so every dependency is already
  # present in `loaded` when a function referring to it is processed.
  for original_fdef in _sort_function_defs(library):
    fixed_fdef = _fix_fdef(original_fdef, loaded, load_shared_name_suffix)

    # There is no need to copy functions into the function def graph.
    # It leads to a O(n^2) increase of memory when importing functions
    # and the extra function definitions are a no-op since they were already
    # imported as a function before (due to the topological sort import).
    graph = function_def_lib.function_def_to_graph(
        fixed_fdef, copy_functions=False)

    for dependency_name in _list_function_deps(original_fdef):
      loaded[dependency_name].add_to_graph(graph)
    concrete = function_lib.ConcreteFunction(graph)
    concrete.add_to_graph()

    loaded[original_fdef.signature.name] = concrete

    # Also register the gradients in the current root context.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access

  return loaded
def _sort_function_defs(library):
"""Return a topologic sort of FunctionDefs in a library."""
edges = collections.defaultdict(list)
in_count = collections.defaultdict(lambda: 0)
for fdef in library.function:
for dep in _list_function_deps(fdef):
edges[dep].append(fdef.signature.name)
in_count[fdef.signature.name] += 1
ready = [
fdef.signature.name
for fdef in library.function
if in_count[fdef.signature.name] == 0
]
output = []
while ready:
node = ready.pop()
output.append(node)
for dest in edges[node]:
in_count[dest] -= 1
if not in_count[dest]:
ready.append(dest)
if len(output) != len(library.function):
failed_to_resolve = sorted(set(in_count.keys()) - set(output))
raise ValueError("There is a cyclic-dependency between functions. ",
"Could not resolve %r." % (failed_to_resolve,))
reverse = {fdef.signature.name: fdef for fdef in library.function}
return [reverse[x] for x in output]
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
  fixed = function_pb2.FunctionDef()
  fixed.CopyFrom(orig_fdef)
  function_call_ops = ["StatefulPartitionedCall", "PartitionedCall"]

  for node in fixed.node_def:
    has_gradient_type = "_gradient_op_type" in node.attr
    if has_gradient_type and node.op in function_call_ops:
      # TODO(andresp): This code assumes that the gradient registered for this
      # function call is the default gradient for the function and not a
      # custom one.
      called_name = node.attr["f"].func.name
      # pylint: disable=protected-access
      node.attr["_gradient_op_type"].s = compat.as_bytes(
          functions[called_name]._gradient_name)
      # pylint: enable=protected-access
    elif has_gradient_type:
      logging.warning("Importing a function (%s) with ops with custom "
                      "gradients. Will likely fail if a gradient is "
                      "requested.", fixed.signature.name)

    # Point every function-valued attribute at the name of the loaded
    # function.
    for _, attr_value in node.attr.items():
      referenced_name = attr_value.func.name
      if referenced_name:
        attr_value.func.name = functions[referenced_name].name

    # TODO(b/124205571): Avoid accidental sharing and destruction of restored
    # resources. For now uniquify "shared_name" when loading functions to avoid
    # sharing.
    if "shared_name" in node.attr:
      node.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)

  fixed.signature.name = _clean_function_name(fixed.signature.name)
  return fixed
def _list_function_deps(fdef):
# TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
# when listing deps and when fixing them. `function_def_to_graph` also
# requires fixes.
deps = set()
for node_def in fdef.node_def:
for _, attr_value in node_def.attr.items():
if attr_value.WhichOneof("value") == "func":
deps.add(attr_value.func.name)
return deps
def _clean_function_name(name):
"""Vanity function to keep the function names comprehensible."""
# Note: each time a function is wrapped into `function_lib.ConcreteFunction`
# its name becomes "__inference_<orig>_xyz".
match = re.search(r"^__inference_(.*)_\d+$", name)
if match:
return match.group(1)
else:
return name