-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathresource.py
307 lines (247 loc) · 10.5 KB
/
resource.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definitions for resource-type trackable object classes."""
import contextlib
import copy
import weakref
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.trackable import base
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
# Resource trackers that are currently active. `resource_tracker_scope` pushes
# onto this list, and every `TrackableResource` constructed while a tracker is
# on it is reported to that tracker via `add_resource`.
_RESOURCE_TRACKER_STACK = list()
class ResourceTracker:
  """Records resources reported to it, in the order they were reported."""

  __slots__ = ["_resources"]

  def __init__(self):
    self._resources = list()

  @property
  def resources(self):
    """The internal list of recorded resources (returned as-is, not copied)."""
    return self._resources

  def add_resource(self, resource):
    """Appends `resource` to the end of the recorded list."""
    recorded = self._resources
    recorded.append(resource)
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of
  code. While the scope is active, `resource_tracker` is on the module-level
  tracker stack, so each `TrackableResource` constructed inside the block is
  reported to it (as well as to any enclosing trackers).

  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The passed in ResourceTracker object.

  Yields:
    A scope in which the resource_tracker is active.
  """
  global _RESOURCE_TRACKER_STACK
  # Snapshot the stack before pushing. On exit the module-level name is
  # rebound to this snapshot, so the previous state is restored exactly even
  # if the body raised or left extra entries behind. Note that this means the
  # list object bound to the name changes after the scope exits.
  saved_stack = _RESOURCE_TRACKER_STACK[:]
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    _RESOURCE_TRACKER_STACK = saved_stack
def _make_getter(captured_getter, captured_previous):
  """To avoid capturing loop variables.

  Building the wrapper inside this helper gives each call its own scope, so
  wrappers created in a loop do not all see the loop variables' final values.

  Args:
    captured_getter: Callable invoked as
      `captured_getter(captured_previous, *args, **kwargs)`.
    captured_previous: Value passed as the first positional argument.

  Returns:
    A callable that forwards all of its arguments to `captured_getter`,
    prefixed by `captured_previous`.
  """
  return lambda *args, **kwargs: captured_getter(
      captured_previous, *args, **kwargs)
class _ResourceMetaclass(type):
  """Metaclass for `CapturableResource`.

  Instantiating a class that uses this metaclass goes through the resource
  creators registered on the default graph under `cls._resource_type()`,
  rather than straight through `type.__call__`.
  """

  def __call__(cls, *args, **kwargs):

    def default_resource_creator(next_creator, *a, **kw):
      # Innermost creator: there is nothing left to delegate to, so build the
      # object directly with `__new__` followed by `__init__`.
      assert next_creator is None
      instance = cls.__new__(cls, *a, **kw)
      instance.__init__(*a, **kw)
      return instance

    def innermost_getter(*a, **kw):
      return default_resource_creator(None, *a, **kw)

    # Each registered creator wraps the getter built so far. The creator that
    # comes last in the stack therefore runs first, and it receives the
    # previous getter as its first argument. `_make_getter` pins the loop
    # values so each wrapper keeps its own pair.
    graph_creators = ops.get_default_graph()._resource_creator_stack  # pylint: disable=protected-access
    chained_getter = innermost_getter
    for creator in graph_creators[cls._resource_type()]:
      chained_getter = _make_getter(creator, chained_getter)
    return chained_getter(*args, **kwargs)
class CapturableResource(base.Trackable, metaclass=_ResourceMetaclass):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.

  Subclasses override `_create_resource` (required), and optionally
  `_initialize` and `_destroy_resource`.
  """

  def __init__(self, device=""):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # The handle is created lazily by the `resource_handle` property.
    self._resource_handle_value = None
    self._resource_device = device
    # Remember the execution context current at construction time (eager
    # mode, or the then-default graph). `__del__` enters this context before
    # calling `_destroy_resource`.
    self._self_destruction_context = (
        context.eager_mode if context.executing_eagerly()
        else ops.get_default_graph().as_default)

  @classmethod
  def _resource_type(cls):
    """Key under which resource creators for this class are looked up.

    `_ResourceMetaclass.__call__` uses it to index the default graph's
    resource creator stack.
    """
    return cls.__name__

  @property
  def _destruction_context(self):
    """Zero-argument callable returning the context used in `__del__`.

    Falls back to `contextlib.suppress`, which is a no-op context when called
    with no arguments, if `_self_destruction_context` was never set (for
    example if `__init__` did not complete).
    """
    return getattr(self, "_self_destruction_context",
                   # no-op context
                   contextlib.suppress)

  @_destruction_context.setter
  def _destruction_context(self, destruction_context):
    self._self_destruction_context = destruction_context

  def _create_resource(self):
    """A function that creates a resource handle.

    Subclasses must override this. It is called lazily by `resource_handle`
    (under `ops.device(self._resource_device)`) and again when exporting to a
    SavedModel graph.

    Raises:
      NotImplementedError: If the subclass does not override this method.
    """
    # Report the concrete class so the missing override is easy to locate.
    raise NotImplementedError(
        "{}._create_resource not implemented.".format(type(self).__name__))

  @property
  def _resource_handle(self):
    """The cached handle, or `None` if it has not been created yet."""
    return self._resource_handle_value

  @_resource_handle.setter
  def _resource_handle(self, value):
    # Tensor handles get a weak back-reference to this object, presumably so
    # the owning trackable can be found from a captured handle (TODO confirm).
    # A weakref avoids keeping this object alive through its own handle.
    if isinstance(value, (ops.Tensor, ops.EagerTensor)):
      value._parent_trackable = weakref.ref(self)  # pylint: disable=protected-access
    self._resource_handle_value = value

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  def _destroy_resource(self):
    """A function that destroys the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    The handle is created on first access via `_create_resource`, placed on
    `self._resource_device`, and cached for later accesses.
    """
    if self._resource_handle is None:
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _export_to_saved_model_graph(
      self, object_map, tensor_map, **unused_kwargs):
    """For implementing `Trackable`.

    Creates a shallow copy of this object with a freshly created resource
    handle. It then records `self -> copy` in `object_map` and
    `original handle -> new handle` in `tensor_map`.

    Returns:
      A one-element list containing this object's (original) resource handle.
      Note that reading `self.resource_handle` creates it if needed.
    """
    new_obj = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      new_resource = new_obj._create_resource()
    new_obj._resource_handle = new_resource
    # pylint: enable=protected-access
    object_map[self] = new_obj
    tensor_map[self.resource_handle] = new_resource
    return [self.resource_handle]

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the trackable children for `save_type`.

    For SavedModel saves, three zero-argument `tf.function`s are added under
    the names `_create_resource`, `_initialize` and `_destroy_resource`. They
    wrap the corresponding methods. `RestoredResource._deserialize_from_proto`
    reuses the saved `_create_resource` dependency when loading.
    """
    children = super()._trackable_children(save_type, **kwargs)
    if save_type == "savedmodel":
      @def_function.function(input_signature=[], autograph=False)
      def _creator():
        resource = self._create_resource()
        return resource

      @def_function.function(input_signature=[], autograph=False)
      def _initializer():
        self._initialize()
        return 1  # Dummy return

      @def_function.function(input_signature=[], autograph=False)
      def _destroyer():
        self._destroy_resource()
        return 1  # Dummy return

      children.update({
          "_create_resource": _creator,
          "_initialize": _initializer,
          "_destroy_resource": _destroyer,
      })
    return children

  def __del__(self):
    """Best-effort call to `_destroy_resource` in the construction context.

    Any exception raised while entering the context or destroying the
    resource is deliberately swallowed.
    """
    try:
      # Outer race condition: on program exit, the destruction context may be
      # deleted before this __del__ is called. At this point we can safely
      # exit without calling _destroy_resource() and let Python handle things.
      with self._destruction_context():
        # Inner race condition: possible between this and `ScopedTFFunction`
        # whereby if an entire garbage collection chain containing both
        # objects is moved to unreachable during the same garbage collection
        # cycle, the __del__ for `ScopedTFFunction` can be collected before
        # this method is called. In that case, we can't do much but
        # continue.
        self._destroy_resource()
    except Exception:  # pylint: disable=broad-except
      # Silence all error logs that occur when attempting to destroy this
      # resource.
      pass
@tf_export("saved_model.experimental.TrackableResource")
class TrackableResource(CapturableResource):
  """Holds a Tensor which a tf.function can capture.

  A TrackableResource is most useful for stateful Tensors that require
  initialization, such as `tf.lookup.StaticHashTable`. `TrackableResource`s
  are discovered by traversing the graph of object attributes, e.g. during
  `tf.saved_model.save`. In addition, each `TrackableResource` is reported to
  every resource tracker active when it is constructed (see
  `resource_tracker_scope`).

  A TrackableResource has three methods to override:

  * `_create_resource` should create the resource tensor handle.
  * `_initialize` should initialize the resource held at `self.resource_handle`.
  * `_destroy_resource` is called upon a `TrackableResource`'s destruction
    and should decrement the resource's ref count. For most resources, this
    should be done with a call to `tf.raw_ops.DestroyResourceOp`.

  Example usage:

  >>> class DemoResource(tf.saved_model.experimental.TrackableResource):
  ...   def __init__(self):
  ...     super().__init__()
  ...     self._initialize()
  ...   def _create_resource(self):
  ...     return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  ...   def _initialize(self):
  ...     tf.raw_ops.AssignVariableOp(
  ...         resource=self.resource_handle, value=tf.ones([2]))
  ...   def _destroy_resource(self):
  ...     tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
  >>> class DemoModule(tf.Module):
  ...   def __init__(self):
  ...     self.resource = DemoResource()
  ...   def increment(self, tensor):
  ...     return tensor + tf.raw_ops.ReadVariableOp(
  ...         resource=self.resource.resource_handle, dtype=tf.float32)
  >>> demo = DemoModule()
  >>> demo.increment([5, 1])
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>
  """

  def __init__(self, device=""):
    """Initialize the `TrackableResource`.

    Reports this object to every active resource tracker, then runs the
    `CapturableResource` initializer.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # The module-level stack is only read here, so no `global` statement is
    # needed. Registration happens before the base initializer runs.
    for tracker in _RESOURCE_TRACKER_STACK:
      tracker.add_resource(self)
    super().__init__(device=device)
# TODO(b/124205571,b/124092991): Solve destruction of resources.
class RestoredResource(TrackableResource):
  """Restored SavedResource.

  Built by `_deserialize_from_proto` from a serialized object proto. Its
  children are attached by name through `_add_trackable_child`.
  """

  def __init__(self, device=""):
    super().__init__(device=device)

  @classmethod
  def _deserialize_from_proto(cls, object_proto, dependencies, **unused_kwargs):
    """Creates an instance placed on `object_proto.resource.device`.

    Args:
      object_proto: Serialized object whose `resource.device` field gives the
        placement.
      dependencies: Dict-like mapping of restored dependencies by name. If it
        has a non-None `"_create_resource"` entry, that callable replaces the
        new instance's `_create_resource`.
      **unused_kwargs: Ignored.

    Returns:
      The newly constructed instance of `cls`.
    """
    restored = cls(device=object_proto.resource.device)
    saved_creator = dependencies.get("_create_resource")
    if saved_creator is not None:
      restored._create_resource = saved_creator  # pylint: disable=protected-access
    return restored

  def _add_trackable_child(self, name, value):
    """Sets attribute `name` to `value` and tracks it when appropriate.

    `value` is tracked as a dependency only if it is a `Trackable` and not a
    `def_function.Function`.
    """
    setattr(self, name, value)
    if not isinstance(value, base.Trackable):
      return
    if isinstance(value, def_function.Function):
      return
    self._track_trackable(value, name)