/
critical_section_ops.py
419 lines (349 loc) · 16 KB
/
critical_section_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""
import collections
import contextlib
import threading
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
__all__ = ["CriticalSection"]
# Graph Keys
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
class _ExecutionSignature(
collections.namedtuple("_ExecutionSignature",
("op", "handle",
"resources", "exclusive_resource_access"))):
"""A class storing an `ExecuteInCriticalResource` op and associated attrs."""
pass
def _identity(x):
  """Return an identity of `x`, dispatching on its type.

  Handles `TensorArray` (via `.identity()`), `Operation` (via a control-flow
  group), a `None` value in eager mode (passed through), and anything else
  via `array_ops.identity`.
  """
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  if isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  if x is None and context.executing_eagerly():
    return None
  return array_ops.identity(x)
def _get_device_or_colocation(op):
  """Return `op`'s device string, falling back to its colocation symbol."""
  device = op.device
  if device:
    return device
  return _get_colocation(op)
def _get_colocation(op):
"""Get colocation symbol from op, if any."""
try:
return op.get_attr("_class")
except (ValueError, AttributeError):
return None
_CRITICAL_SECTION_STACK = threading.local()
def _get_critical_section_stack():
try:
return _CRITICAL_SECTION_STACK.value
except AttributeError:
_CRITICAL_SECTION_STACK.value = []
return _CRITICAL_SECTION_STACK.value
@contextlib.contextmanager
def _push_critical_section_stack(signature):
"""Push a CriticalSection._signature to the thread-local stack.
If the signature is already on the stack, raise an error because it means
we're trying to execute inside the same locked CriticalSection, which
will create a deadlock.
Args:
signature: Tuple of the type `CriticalSection._signature`. Uniquely
identifies a CriticalSection by its `shared_name`, `container`,
and device.
Yields:
An empty value. The context is guaranteed to run without deadlock.
Raises:
ValueError: If the signature is already on the stack.
RuntimeError: If another thread or function modifies the current stack
entry during the yield.
"""
stack = _get_critical_section_stack()
if signature in stack:
raise ValueError(
f"Attempting to lock a CriticalSection (signature={signature}) in which"
" we are already running. This is illegal and may cause deadlocks.")
stack.append(signature)
try:
yield
finally:
received_signature = stack.pop()
if received_signature != signature:
raise RuntimeError(
"CriticalSection stack inconsistency: expected signature "
f"{signature} but received {received_signature}")
@tf_export("CriticalSection")
class CriticalSection:
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```

  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)

  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)

  ex1 = CriticalSection().execute(
      accumulate, 1.0, exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
      accumulate, 1.0, exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name for the critical section; it is also used as the
        default `shared_name` when `shared_name` is not given.
      shared_name: Optional shared name for the underlying mutex resource.
        Critical sections created with the same non-empty `shared_name`
        (and container/device) share a single lock.
      critical_section_def: Not supported; must be `None`.
      import_scope: Unused; retained for backwards interface compatibility.

    Raises:
      ValueError: If both `critical_section_def` and `name` are given, or if
        `critical_section_def` is given at all (it is unsupported).
    """
    context.ensure_initialized()
    if critical_section_def and name is not None:
      # BUG FIX: the message previously reported `shared_name`, but the
      # mutual-exclusion check is against `name`; report `name` instead.
      raise ValueError(f"Arguments critical_section_def={critical_section_def} "
                       f"and name={name} are mutually exclusive. "
                       "Please only specify one of them.")
    if critical_section_def:
      raise ValueError("Argument `critical_section_def` is not supported.")
    else:
      self._init_from_args(name, shared_name)

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        # The MutexV2 resource is the actual lock backing this section.
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)
        # Get a uniquely identifying signature for the handle.
        self._signature = (
            container,
            # If shared_name is empty, a unique CriticalSection is created.
            shared_name or id(self._handle),
            _get_device_or_colocation(self._handle))

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """Name of the underlying mutex op."""
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To add extra arguments to when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be
        raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):
      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed. This avoids certain types of deadlocks.
      with _push_critical_section_stack(self._signature):
        lock = gen_resource_variable_ops.mutex_lock(self._handle)

        if not context.executing_eagerly():
          # NOTE(ebrevdo): This is to ensure we don't pick up spurious
          # Operations created by other threads.
          with ops.get_default_graph()._lock:  # pylint: disable=protected-access
            existing_ops = ops.get_default_graph().get_operations()
            with ops.control_dependencies([lock]):
              r = fn()
            # TODO(ebrevdo): If creating critical sections in a python loop,
            # this makes graph creation time quadratic.  Revisit if this
            # becomes a problem.
            created_ops = (set(ops.get_default_graph().get_operations())
                           .difference(existing_ops))
        else:
          with ops.control_dependencies([lock]):
            r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError(
              "Attempting to lock a CriticalSection in which we are "
              f"already running (signature={self._signature}). This is illegal "
              "and may cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        # Record this execution so future executions in *other* critical
        # sections can detect conflicting resource access.
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op."""
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)

    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = dict((op._id, op) for op in all_args)

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode.  This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: "
            f"{list(resource_intersection)}. Either this lock "
            f"(CriticalSection: {self._handle}) or lock '{sg}' "
            f"(CriticalSection: {sg.handle}) requested exclusive resource "
            "access of this resource. Did you mean to call execute with "
            "keyword argument exclusive_resource_access=False?")