/
context.py
2769 lines (2200 loc) · 93.7 KB
/
context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""State management for eager execution."""
import collections
import contextlib
import copy
import gc
import os
import random
import threading
from absl import logging
import numpy as np
from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import coordination_config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python import tf2
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import executor
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import tfrt_utils
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
# Execution-mode identifiers. The process-wide default is eager when
# `tf2.enabled()` is true and graph mode otherwise.
GRAPH_MODE = 0
EAGER_MODE = 1
default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}
# Device spec parsed from the empty string; `Context.__init__` passes it as the
# initial `device_spec` of its thread-local data.
_starting_device_spec = pydev.DeviceSpec.from_string("")
# Largest signed 32-bit integer; upper bound of the seeds drawn by
# `Context._internal_operation_seed`.
_MAXINT32 = 2**31 - 1
# Device placement policies, re-exported from `pywrap_tfe`. The docstring of
# `Context.__init__` describes what each policy does.
DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
    pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
# Execution modes accepted by `Context(execution_mode=...)`.
SYNC = 0
ASYNC = 1
# Default `keep_alive_secs` used when a server def is set on a context.
_KEEP_ALIVE_SECS = 600
# Incremented by one each time a `Context` is constructed.
_python_eager_context_create_counter = monitoring.Counter(
    "/tensorflow/api/python/eager_context_create_counter",
    "Counter for number of eager contexts created in Python.")
# Re-exporting through context.
is_tfrt_enabled = tfrt_utils.enabled
# This flag and the associated environment var are transient and will eventually
# be removed, once this experiment is enabled by default.
# The flag starts out True only when TF_JIT_COMPILE_REWRITE is exactly "1".
_JIT_COMPILE_REWRITE_ENABLED = os.getenv("TF_JIT_COMPILE_REWRITE") == "1"
def run_eager_op_as_function_enabled():
  """Reports whether the run-eager-op-as-function option is on.

  Returns:
    Always `True`; there is no way to turn the option off from this function.
  """
  return True
# This function should only be called after the context has been initialized.
def enable_jit_compile_rewrite():
  """Turn on the rewrite pass for jit_compile functions.

  Once enabled, jit_compile functions are run through all of the multidevice
  function rewrite passes. The module-level flag is always set; the
  `jit_compile_rewrite` attribute of the context is set as well when
  `context_safe()` returns one.
  """
  global _JIT_COMPILE_REWRITE_ENABLED
  _JIT_COMPILE_REWRITE_ENABLED = True
  ctx = context_safe()
  if ctx is None:
    return
  ctx.jit_compile_rewrite = True
# This function should only be called after the context has been initialized.
def disable_jit_compile_rewrite():
  """Turn off the rewrite pass for jit_compile functions.

  The module-level flag is always cleared; the `jit_compile_rewrite` attribute
  of the context is cleared as well when `context_safe()` returns one.
  """
  global _JIT_COMPILE_REWRITE_ENABLED
  _JIT_COMPILE_REWRITE_ENABLED = False
  ctx = context_safe()
  if ctx is None:
    return
  ctx.jit_compile_rewrite = False
def jit_compile_rewrite_enabled():
  """Returns whether the jit_compile rewrite pass is currently enabled.

  Returns:
    The `jit_compile_rewrite` value of the context returned by
    `context_safe()`, or the module-level flag when no context is returned.
  """
  ctx = context_safe()
  if ctx is None:
    return _JIT_COMPILE_REWRITE_ENABLED
  return ctx.jit_compile_rewrite
# Expose `is_tfrt_enabled` as an internally public API under the name
# `__internal__.is_tfrt_enabled` (with no v1 name) for Keras use cases in
# b/171080602.
tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled)
class _EagerTensorCache(object):
  """Bounded tensor cache that evicts its oldest entry when over capacity.

  Entries live in an `OrderedDict`, so eviction follows insertion order (FIFO).
  Storing a value under a key that is already present replaces the value but
  keeps the key's original position in that order.
  """

  __slots__ = ["_data", "_max_items", "_max_tensor_size"]

  def __init__(self, max_items=256, max_tensor_size=10000):
    """Creates an empty cache.

    Args:
      max_items: number of entries the cache may hold before evicting.
      max_tensor_size: values whose `_num_elements()` exceeds this are never
        stored.
    """
    self._max_items = max_items
    self._max_tensor_size = max_tensor_size
    self._data = collections.OrderedDict()

  def put(self, key, value):
    """Stores `value` under `key` unless it has too many elements."""
    too_large = value._num_elements() > self._max_tensor_size  # pylint: disable=protected-access
    if not too_large:
      self._data[key] = value
      if len(self._data) > self._max_items:
        # Drop the entry that was inserted first.
        self._data.popitem(last=False)

  def get(self, key):
    """Returns the value cached under `key`, or `None` if there is none."""
    return self._data.get(key)

  def flush(self):
    """Removes every entry from the cache."""
    self._data.clear()
class FunctionCallOptions:
  """Options applied at call sites of eager functions.

  Eager functions are functions decorated with tf.contrib.eager.defun.
  """

  __slots__ = ["_config_proto_serialized", "_executor_type"]

  def __init__(self, executor_type=None, config_proto=None):
    """Constructor.

    Args:
      executor_type: (optional) name of the executor to be used to execute the
        eager function. If None or an empty string, the default Tensorflow
        executor will be used.
      config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized
        string of that proto. The config used by Grappler when optimizing the
        function graph. Each concrete function is optimized the first time it
        is called. Changing config_proto after the first call has no effect.
        If config_proto is None, an empty `ConfigProto` is serialized and used.

    Raises:
      ValueError: if `config_proto` is none of the accepted types.
    """
    # The config is validated first, so an invalid config raises before the
    # executor type is stored.
    self.config_proto_serialized = config_proto
    self.executor_type = executor_type

  @property
  def executor_type(self):
    """Name of the executor used to run the function, as it was given."""
    return self._executor_type

  @executor_type.setter
  def executor_type(self, executor_type):
    self._executor_type = executor_type

  @property
  def config_proto_serialized(self):
    """The config in serialized form."""
    return self._config_proto_serialized

  @config_proto_serialized.setter
  def config_proto_serialized(self, config):
    """Accepts a `ConfigProto`, an already serialized string, or `None`."""
    if isinstance(config, config_pb2.ConfigProto):
      serialized = config.SerializeToString(deterministic=True)
    elif isinstance(config, str):
      serialized = config
    elif config is None:
      serialized = config_pb2.ConfigProto().SerializeToString()
    else:
      raise ValueError("the rewriter config must be either a "
                       "config_pb2.ConfigProto, or a serialized string of that "
                       "proto or None. got: {}".format(type(config)))
    self._config_proto_serialized = serialized
# Map from context_id (an int) to _TensorCaches.
# An entry is added in `Context.__init__` and removed again by
# `_TensorCacheDeleter.__del__`.
# Dicts are thread safe in CPython.
# TODO(iga): Remove this once TensorCaches are moved to C++.
_tensor_caches_map = {}
class _TensorCaches(threading.local):
  """Thread local tensor caches.

  Each thread that touches an instance sees its own pair of lazily created
  `_EagerTensorCache` objects.

  NOTE: this class deliberately defines no `__slots__`. Slots declared on a
  `threading.local` subclass are stored on the object itself instead of in the
  per-thread dictionary, so they are shared by all threads. On top of that,
  `__init__` runs again the first time each new thread touches the object,
  which would reset the shared slots (and drop the caches) for every thread.
  """

  def __init__(self):
    # `threading.local` calls this once per thread, on first access.
    super().__init__()
    self._ones_rank_cache = None
    self._zeros_cache = None

  @property
  def ones_rank_cache(self):
    """The calling thread's ones cache, created on first access."""
    if self._ones_rank_cache is None:
      self._ones_rank_cache = _EagerTensorCache()
    return self._ones_rank_cache

  @property
  def zeros_cache(self):
    """The calling thread's zeros cache, created on first access."""
    if self._zeros_cache is None:
      self._zeros_cache = _EagerTensorCache()
    return self._zeros_cache
# Record of one context switch, as stored by `_ContextSwitchStack.push`:
# whether the entered context is building a function, the callable that
# performs the switch, and the device function stack (`None` for eager
# contexts).
ContextSwitch = collections.namedtuple(
    "ContextSwitch",
    ["is_building_function", "enter_context_fn", "device_stack"])
# `_ContextSwitchStack` is a `threading.local` to match the semantics of
# `DefaultGraphStack`, which is also a `threading.local`.
class _ContextSwitchStack(threading.local):
  """A thread-local stack of context switches."""

  def __init__(self, eager):
    """Creates the stack for the current thread.

    Args:
      eager: when true, the stack starts with one entry that enters the eager
        context; otherwise it starts empty.
    """
    super().__init__()
    self.stack = []
    if not eager:
      return
    # Seed the stack with a pointer to enter the eager context. This is what
    # propagates "eager execution is enabled" across threads, because
    # (1) `enable_eager_execution` modifies a process-level flag
    # (`default_execution_mode`) and (2) `__init__` is called each time a
    # threading.local object is used in a separate thread.
    self.push(
        is_building_function=False,
        enter_context_fn=eager_mode,
        device_stack=None)

  def push(self, is_building_function, enter_context_fn, device_stack):
    """Push metadata about a context switch onto the stack.

    A context switch can take any one of the two forms: installing a graph as
    the default graph, or entering the eager context. For each context switch,
    we record whether or not the entered context is building a function.

    Args:
      is_building_function: (bool.) Whether the context is building a function.
      enter_context_fn: (function.) A callable that executes the context switch.
        For example, `graph.as_default` or `eager_mode`.
      device_stack: If applicable, the device function stack for this graph.
        When breaking out of graphs in init_scope, the innermost nonempty device
        stack is used. Eager contexts put `None` here and the value is never
        used.
    """
    entry = ContextSwitch(is_building_function, enter_context_fn, device_stack)
    self.stack.append(entry)

  def pop(self):
    """Discards the most recently pushed context switch."""
    self.stack.pop()
@tf_export("config.LogicalDevice")
class LogicalDevice(
    collections.namedtuple("LogicalDevice", ["name", "device_type"])):
  """Abstraction for a logical device initialized by the runtime.

  A `tf.config.LogicalDevice` corresponds to an initialized logical device on a
  `tf.config.PhysicalDevice` or a remote device visible to the cluster. To
  place tensors and operations on a particular logical device, call `tf.device`
  with the desired `tf.config.LogicalDevice`.

  Fields:
    name: The fully qualified name of the device. Can be used for Op or function
      placement.
    device_type: String declaring the type of device such as "CPU" or "GPU".
  """
@tf_export("config.LogicalDeviceConfiguration",
           "config.experimental.VirtualDeviceConfiguration")
class LogicalDeviceConfiguration(
    collections.namedtuple("LogicalDeviceConfiguration", [
        "memory_limit", "experimental_priority", "experimental_device_ordinal"
    ])):
  """Configuration class for a logical device.

  The class specifies the parameters to configure a `tf.config.PhysicalDevice`
  as it is initialized to a `tf.config.LogicalDevice` during runtime
  initialization. Not all fields are valid for all device types.

  See `tf.config.get_logical_device_configuration` and
  `tf.config.set_logical_device_configuration` for usage examples.

  Fields:
    memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual
      device. Currently only supported for GPUs.
    experimental_priority: (optional) Priority to assign to a virtual device.
      Lower values have higher priorities and 0 is the default.
      Within a physical GPU, the GPU scheduler will prioritize ops on virtual
      devices with higher priority. Currently only supported for Nvidia GPUs.
    experimental_device_ordinal: (optional) Ordinal number to order the virtual
      device.
      LogicalDevice with lower ordinal number will receive a lower device id.
      Physical device id and location in the list is used to break ties.
      Currently only supported for Nvidia GPUs.
  """

  def __new__(cls,
              memory_limit=None,
              experimental_priority=None,
              experimental_device_ordinal=None):
    # Overridden only so that every field is optional and defaults to None.
    return super().__new__(
        cls,
        memory_limit=memory_limit,
        experimental_priority=experimental_priority,
        experimental_device_ordinal=experimental_device_ordinal)
@tf_export("config.PhysicalDevice")
class PhysicalDevice(
    collections.namedtuple("PhysicalDevice", ["name", "device_type"])):
  """Abstraction for a locally visible physical device.

  TensorFlow can utilize various devices such as the CPU or multiple GPUs
  for computation. Before initializing a local device for use, the user can
  customize certain properties of the device such as its visibility or memory
  configuration.

  Once a visible `tf.config.PhysicalDevice` is initialized one or more
  `tf.config.LogicalDevice` objects are created. Use
  `tf.config.set_visible_devices` to configure the visibility of a physical
  device and `tf.config.set_logical_device_configuration` to configure multiple
  `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is
  useful when separation between models is needed or to simulate a multi-device
  environment.

  Fields:
    name: Unique identifier for device.
    device_type: String declaring the type of device such as "CPU" or "GPU".
  """
class _AtomicCounter(object):
  """A counter whose increments are serialized by a lock."""

  __slots__ = ["_value", "_lock"]

  def __init__(self):
    self._lock = threading.Lock()
    self._value = 0

  def increment_and_get(self):
    """Adds one to the counter and returns the new value.

    Returns:
      The value after the increment; the first call returns 1.
    """
    with self._lock:
      updated = self._value + 1
      self._value = updated
    return updated
# Source of the `_id` given to each `Context` in `Context.__init__`.
_context_id_counter = _AtomicCounter()
class _TensorCacheDeleter(object):
  """Removes a context's entry from `_tensor_caches_map` when collected."""

  __slots__ = ["_context_id"]

  def __init__(self, context_id):
    self._context_id = context_id

  def __del__(self):
    caches_by_context = _tensor_caches_map
    # NOTE(review): the None check presumably covers interpreter shutdown,
    # when module globals may already have been cleared -- confirm.
    if caches_by_context is None:
      return
    caches_by_context.pop(self._context_id, None)
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context:
"""Environment in which eager operations execute."""
# TODO(agarwal): create and link in some documentation for `execution_mode`.
# pylint: disable=redefined-outer-name
def __init__(self,
             config=None,
             device_policy=None,
             execution_mode=None,
             server_def=None):
  """Creates a new Context.

  The runtime handle is not created here: `_context_handle` is left as None
  and `_initialized` as False until `ensure_initialized` runs.

  Args:
    config: (Optional.) A `ConfigProto` protocol buffer with configuration
      options for the Context. Note that a lot of these options may be
      currently unimplemented or irrelevant when eager execution is enabled.
    device_policy: (Optional.) What policy to use when trying to run an
      operation on a device with inputs which are not on that device. When set
      to None, an appropriate value will be picked automatically. The value
      picked may change between TensorFlow releases. Defaults to
      DEVICE_PLACEMENT_SILENT.
      Valid values:
      - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
        correct.
      - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right
        device but raises a warning.
      - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide
        performance problems.
      - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
        raising errors on the other ones.
    execution_mode: (Optional.) Policy controlling how operations dispatched
      are actually executed. When set to None, an appropriate value will be
      picked automatically (currently SYNC). The value picked may change
      between TensorFlow releases.
      Valid values:
      - SYNC: executes each operation synchronously.
      - ASYNC: executes each operation asynchronously. These operations may
        return "non-ready" handles.
    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution
      on remote devices. GrpcServers need to be started by creating an
      identical server_def to this, and setting the appropriate task_indexes,
      so that the servers can communicate. It will then be possible to execute
      operations on remote devices.

  Raises:
    ValueError: If execution_mode is not valid.
  """
  # This _id is used only to index the tensor caches.
  # TODO(iga): Remove this when tensor caches are moved to C++.
  self._id = _context_id_counter.increment_and_get()
  self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
  _tensor_caches_map[self._id] = _TensorCaches()

  self._config = config
  self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
      self,
      is_eager=lambda: default_execution_mode == EAGER_MODE,
      device_spec=_starting_device_spec)
  # Created after the thread-local data, in the original statement order.
  self._context_switches = _ContextSwitchStack(self.executing_eagerly())
  self._context_handle = None
  self._context_devices = None
  self._seed = None
  self._initialize_lock = threading.Lock()
  self._initialized = False
  if device_policy is None:
    device_policy = DEVICE_PLACEMENT_SILENT
  self._device_policy = device_policy
  self._mirroring_policy = None
  if execution_mode not in (None, SYNC, ASYNC):
    # The 1-tuple keeps `%` from unpacking a tuple-valued `execution_mode`,
    # which would raise TypeError instead of the documented ValueError.
    raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" %
                     (execution_mode,))
  if execution_mode is None:
    execution_mode = SYNC
  self._default_is_async = execution_mode == ASYNC
  self._use_tfrt = is_tfrt_enabled()
  self._use_tfrt_distributed_runtime = None
  self._jit_compile_rewrite = jit_compile_rewrite_enabled()
  self._server_def = server_def
  self._collective_ops_server_def = None
  self._collective_leader = None
  self._collective_scoped_allocator_enabled_ops = None
  self._collective_use_nccl_communication = None
  self._collective_device_filters = None
  self._coordination_service_config = None

  self._device_lock = threading.Lock()
  self._physical_devices = None
  self._physical_device_to_index = None
  self._pluggable_devices = None
  self._visible_device_list = []
  self._memory_growth_map = None
  self._virtual_device_map = {}

  # Values set after construction
  self._optimizer_jit = None
  self._intra_op_parallelism_threads = None
  self._inter_op_parallelism_threads = None
  self._soft_device_placement = None
  self._log_device_placement = None
  self._operation_timeout_in_ms = None
  self._enable_mlir_graph_optimization = None
  self._optimizer_experimental_options = {}

  _python_eager_context_create_counter.get_cell().increase_by(1)

  self._is_global_context = False
# pylint: enable=redefined-outer-name
def _set_global_seed(self, seed):
"""Set a global eager mode seed for random ops."""
self._seed = seed
# `random.Random(seed)` needs `seed` to be hashable, while values of type
# e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them
# to int.
try:
hash(seed)
except TypeError:
seed = int(np.array(seed))
self._rng = random.Random(seed)
# Also clear the kernel cache, to reset any existing seeds
if self._context_handle is not None:
pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
def _internal_operation_seed(self):
  """Returns a fake operation seed.

  In eager mode, users should not set or depend on an operation seed. The
  value is drawn from `self._rng`, the generator created by
  `_set_global_seed`, so successive operations get different seeds that all
  derive from the global seed.

  Returns:
    An integer in `[0, _MAXINT32]` drawn from the global-seed generator.
  """
  rng = self._rng
  return rng.randint(0, _MAXINT32)
def _initialize_logical_devices(self):
  """Rebuilds the device lists and GPU count from the context handle.

  Sets `self._logical_devices`, `self._context_devices` and
  `self._num_gpus`. The two lists are stored, and the device list obtained
  from the runtime is deleted, even if enumeration fails part-way.
  """
  logical, canonical = [], []
  devices = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
  try:
    self._num_gpus = 0
    active_def = self._server_def or self._collective_ops_server_def
    if active_def is None:
      local_job, local_task = None, None
    else:
      local_job, local_task = active_def.job_name, active_def.task_index
    for idx in range(pywrap_tfe.TF_DeviceListCount(devices)):
      raw_name = pywrap_tfe.TF_DeviceListName(devices, idx)
      canonical.append(pydev.canonical_name(raw_name))
      spec = pydev.DeviceSpec.from_string(raw_name)
      # If the job is localhost, we assume that the cluster has not yet been
      # configured and thus clear the job, replica & task.
      if spec.job == "localhost":
        spec = spec.replace(job=None, replica=None, task=None)
      logical.append(
          LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
      kind = pywrap_tfe.TF_DeviceListType(devices, idx)
      # Only GPUs whose (possibly cleared) job and task match the current
      # server def are counted.
      is_local = spec.job == local_job and spec.task == local_task
      if kind == "GPU" and is_local:
        self._num_gpus += 1
  finally:
    self._logical_devices = logical
    self._context_devices = canonical
    pywrap_tfe.TF_DeleteDeviceList(devices)
def ensure_initialized(self):
"""Initialize handle and devices if not already done so."""
if self._initialized:
return
with self._initialize_lock:
if self._initialized:
return
assert self._context_devices is None
opts = pywrap_tfe.TFE_NewContextOptions()
try:
config_str = self.config.SerializeToString()
pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
if self._device_policy is not None:
pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
opts, self._device_policy)
if self._mirroring_policy is not None:
pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
opts, self._mirroring_policy)
if self._default_is_async == ASYNC:
pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
if self._use_tfrt is not None:
pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
# pylint: disable=g-backslash-continuation
if self._use_tfrt is not None and \
self._use_tfrt_distributed_runtime is not None:
pywrap_tfe.TFE_ContextOptionsSetTfrtDistributedRuntime(
opts, self._use_tfrt_distributed_runtime)
pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True)
pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite(
opts, self._jit_compile_rewrite)
context_handle = pywrap_tfe.TFE_NewContext(opts)
finally:
pywrap_tfe.TFE_DeleteContextOptions(opts)
assert not (self._server_def and self._collective_ops_server_def), (
"Cannot enable remote execution as well as collective ops at the "
"moment. If this is important to you, please file an issue.")
if self._server_def is not None:
server_def_str = self._server_def.SerializeToString()
pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
server_def_str)
elif self._collective_ops_server_def is not None:
server_def_str = self._collective_ops_server_def.SerializeToString()
pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
self._context_handle = context_handle
self._initialize_logical_devices()
self._initialized = True
if self._is_global_context:
pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
def mark_as_global_context(self):
# If the context was already initialized, publish it. Otherwise wait with
# publication until it's initialized.
if self._initialized:
pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
self._is_global_context = True
def _clear_caches(self):
  """Flushes the ones-rank and zeros tensor caches and the scalar cache."""
  for cache in (self.ones_rank_cache(), self.zeros_cache()):
    cache.flush()
  pywrap_tfe.TFE_ClearScalarCache()
def get_server_def(self):
return self._server_def
def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
"""Allow setting a server_def on the context.
When a server def is replaced, it effectively clears a bunch of caches
within the context. If you attempt to use a tensor object that was pointing
to a tensor on the remote device, it will raise an error.
Args:
server_def: A tensorflow::ServerDef proto. Enables execution on remote
devices.
keep_alive_secs: Num. seconds after which the remote end will hang up. As
long as the client is still alive, the server state for the context will
be kept alive. If the client is killed (or there is some failure), the
server will clean up its context keep_alive_secs after the final RPC it
receives.
Raises:
ValueError: if server_def is None.
"""
if not server_def:
raise ValueError("server_def is None.")
self._server_def = server_def
if self._context_handle:
server_def_str = server_def.SerializeToString()
pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
server_def_str)
self._initialize_logical_devices()
# Clear all the caches in case there are remote tensors in them.
self._clear_caches()
def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
"""Update a server_def on the context.
Args:
server_def: A tensorflow::ServerDef proto. Enables execution on remote
devices.
keep_alive_secs: Num. seconds after which the remote end will hang up. As
long as the client is still alive, the server state for the context will
be kept alive. If the client is killed (or there is some failure), the
server will clean up its context keep_alive_secs after the final RPC it
receives.
Raises:
ValueError: if server_def is None.
"""
if not server_def:
raise ValueError("server_def is None.")
self._server_def = server_def
if self._context_handle:
server_def_str = server_def.SerializeToString()
pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
keep_alive_secs, server_def_str)
self._initialize_logical_devices()
self._clear_caches()
def check_alive(self, worker_name):
"""Checks whether a remote worker is alive or not.
Args:
worker_name: a string representing the remote worker. It must be a fully
specified name like "/job:worker/replica:0/task:0".
Returns:
a boolean indicating whether the remote worker is alive or not.
Raises:
ValueError: if context is not initialized.
"""
# TODO(yuefengz): support checking multiple workers.
if self._context_handle:
return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
else:
raise ValueError("Context is not initialized.")
def sync_executors(self):
"""Sync both local executors and the ones on remote workers.
In async execution mode, local function calls can return before the
corresponding remote op/function execution requests are completed. Calling
this method creates a synchronization barrier for remote executors. It only
returns when all remote pending nodes are finished, potentially with errors
if any remote executors are in error state.
Raises:
ValueError: if context is not initialized.
"""
if self._context_handle:
pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
else:
raise ValueError("Context is not initialized.")
def clear_executor_errors(self):
"""Clear errors in both local executors and remote workers.
After receiving errors from remote workers, additional requests on the fly
could further taint the status on the remote workers due to the async nature
of remote execution. Calling this method block on waiting for all pending
nodes in remote executors to finish and clear their error statuses.
Raises:
ValueError: if context is not initialized.
"""
if self._context_handle:
pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
else:
raise ValueError("Context is not initialized.")
def configure_coordination_service(self,
                                   service_type,
                                   service_leader="",
                                   enable_health_check=True,
                                   cluster_register_timeout_in_ms=0,
                                   heartbeat_timeout_in_ms=0,
                                   coordinated_jobs=None):
  """Enable distributed coordination service with specified configs.

  Builds a `CoordinationServiceConfig` from the arguments and stores it on
  the context. A warning is logged if the context handle already exists.

  Args:
    service_type: stored as the config's `service_type`.
    service_leader: when non-empty, stored in canonical device-name form as
      the config's `service_leader`.
    enable_health_check: stored as the config's `enable_health_check`.
    cluster_register_timeout_in_ms: stored as the config's field of the same
      name.
    heartbeat_timeout_in_ms: stored as the config's field of the same name.
    coordinated_jobs: `None`, or a list whose items are appended to the
      config's `coordinated_job_list`.

  Raises:
    ValueError: if `coordinated_jobs` is neither `None` nor a list.
  """
  if self._context_handle:
    logging.warning("Configuring coordination service type may not be "
                    "effective because the context is already initialized.")
  service_config = coordination_config_pb2.CoordinationServiceConfig()
  service_config.service_type = service_type
  if service_leader:
    service_config.service_leader = pydev.canonical_name(service_leader)
  service_config.enable_health_check = enable_health_check
  service_config.cluster_register_timeout_in_ms = (
      cluster_register_timeout_in_ms)
  service_config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms
  if coordinated_jobs is not None:
    if not isinstance(coordinated_jobs, list):
      raise ValueError("`coordinated_jobs` must be list[CoordinatedJob] or "
                       "None, but got: %s" % (coordinated_jobs,))
    service_config.coordinated_job_list.extend(coordinated_jobs)
  self._coordination_service_config = service_config
@property
def coordination_service(self):
return self._coordination_service_config
def set_config_key_value(self, key, value):
  """Stores `value` under `key` via `TFE_InsertConfigKeyValue`.

  Args:
    key: the key; passed through to `TFE_InsertConfigKeyValue`.
    value: the value; passed through to `TFE_InsertConfigKeyValue`.
  """
  # Initialize *this* context. A bare `ensure_initialized()` resolves to a
  # module-level name and would not necessarily create this instance's handle.
  self.ensure_initialized()
  pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value)
def get_config_key_value(self, key):
  """Reads the value stored under `key` via `TFE_GetConfigKeyValue`.

  Args:
    key: the key to look up; passed through to `TFE_GetConfigKeyValue`.

  Returns:
    The contents of the buffer filled by the runtime, decoded as UTF-8.
  """
  # Initialize *this* context. A bare `ensure_initialized()` resolves to a
  # module-level name and would not necessarily create this instance's handle.
  self.ensure_initialized()
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, buffer_)
    # Decode while the buffer is still open.
    value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
  return value
def delete_config_key_value(self, key):
  """Deletes the entry stored under `key` via `TFE_DeleteConfigKeyValue`.

  Args:
    key: the key to delete; passed through to `TFE_DeleteConfigKeyValue`.
  """
  # Initialize *this* context. A bare `ensure_initialized()` resolves to a
  # module-level name and would not necessarily create this instance's handle.
  self.ensure_initialized()
  pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key)
def report_error_to_cluster(self, error_code, error_message):
"""Report error to other members in a multi-client cluster.
Args:
error_code: a `tf.errors` error code.
error_message: a string. The error message.
"""
if self._context_handle:
pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code,
error_message)
else:
raise ValueError("Context is not initialized.")
def clear_kernel_cache(self):
"""Clear kernel cache and reset all stateful kernels."""
if self._context_handle is not None:
pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
def enable_collective_ops(self, server_def):
"""Enable distributed collective ops with an appropriate server_def.
Args:
server_def: A tensorflow::ServerDef proto. Enables execution on remote
devices.
Raises:
ValueError: if server_def is None.
RuntimeError: if this method is not called at program startup.
"""
if not server_def:
raise ValueError("server_def is None.")
self._collective_ops_server_def = server_def
# TODO(b/129298253): Allow creating datasets/tensors before enabling
# collective ops.
if self._context_handle is not None:
logging.warning("Enabling collective ops after program startup may cause "
"error when accessing previously created tensors.")
with self._initialize_lock:
assert self._initialized
server_def_str = self._collective_ops_server_def.SerializeToString()
pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str)
self._initialize_logical_devices()
self._clear_caches()
def configure_collective_ops(
    self,
    collective_leader="",
    scoped_allocator_enabled_ops=("CollectiveReduce",),
    use_nccl_communication=False,
    device_filters=None):
  """Configure collective ops.

  Collective group leader is necessary for collective ops to run, other
  configurations are mainly for the purpose of performance.

  Calling this again with exactly the same arguments is a no-op.

  Args:
    collective_leader: a device string for collective leader, e.g.
      "/job:worker/replica:0/task:0"; empty string means local execution of
      collective ops.
    scoped_allocator_enabled_ops: a tuple or a list of op names for scoped
      allocator to run with.
    use_nccl_communication: whether to use nccl communication for collective
      ops.
    device_filters: a tuple or a list of device strings. If set, corresponding
      task can only see the devices filtered by these device filters.

  Raises:
    ValueError: if collective ops were already configured with different
      arguments.
    RuntimeError: if this method is not called at program startup.
  """
  if self._collective_leader is None:
    # First-time configuration: only allowed before the context handle exists.
    if self._context_handle is not None:
      raise RuntimeError("Collective ops must be configured at program startup")

    self._collective_leader = collective_leader
    self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops
    self._collective_use_nccl_communication = use_nccl_communication
    self._collective_device_filters = device_filters
    return

  # Already configured: tolerate an identical request, reject anything else.
  differs_from_existing = (
      self._collective_leader != collective_leader or
      self._collective_scoped_allocator_enabled_ops !=
      scoped_allocator_enabled_ops or
      self._collective_use_nccl_communication != use_nccl_communication or
      self._collective_device_filters != device_filters)
  if differs_from_existing:
    raise ValueError("Collective ops are already configured.")
def abort_collective_ops(self, code, message):
  """Abort the collective ops.

  This is intended to be used when a peer failure is detected, which allows
  the user to handle the case instead of hanging. This aborts all on-going
  collectives. Afterwards, all subsequent collectives error immediately, and
  you need to reset_context() to use collectives again.

  Args:
    code: a `tf.errors` error code.
    message: a string. The error message.
  """
  self.ensure_initialized()
  handle = self._handle
  pywrap_tfe.TFE_AbortCollectiveOps(handle, code, message)
def check_collective_ops_peer_health(self, task, timeout_in_ms):
  """Check collective peer health.

  This probes each task to see if they're still alive. Note that restarted
  tasks are considered a different one, and they're considered not healthy.

  This should only be used in multi client multi worker training.

  Args:
    task: a task string, must be in the format of /job:xxx/replica:0/task:N.
    timeout_in_ms: an integer, the timeout. If zero, there's no timeout.

  Raises:
    tf.errors.UnavailableError: when a peer is down.
    tf.errors.FailedPreconditionError: when a peer is a different one from the
      one this task has talked to, e.g. the peer has restarted.
    tf.errors.InvalidArgumentError: when the task string is invalid.
  """
  self.ensure_initialized()
  handle = self._handle
  pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(handle, task, timeout_in_ms)
@property
def _handle(self):
  """Returns the context handle, raising if it has not been created yet.

  Raises:
    AssertionError: if `_context_handle` is still None.
  """
  handle = self._context_handle
  if handle is None:
    raise AssertionError("Context must be initialized first.")
  return handle
@property
def _devices(self):
  """Returns the context's device list, raising if it is not populated yet.

  Raises:
    AssertionError: if `_context_devices` is still None.
  """
  devices = self._context_devices
  if devices is None:
    raise AssertionError("Context must be initialized first.")
  return devices
def __str__(self):
  """Returns a human-readable summary of the context and its devices."""
  if self._context_handle is None:
    return "Eager TensorFlow Context. Devices currently uninitialized."

  devices = self._devices
  header = "Eager TensorFlow Context with %d devices" % (len(devices))
  # One line per device, numbered in list order.
  device_lines = [
      " Device %d: %s" % (index, device)
      for index, device in enumerate(devices)
  ]
  return "\n".join([header] + device_lines)
@tf_contextlib.contextmanager
def _mode(self, mode):
  """A context manager to allow setting the mode to EAGER/GRAPH.

  The previous per-thread `is_eager` flag is restored on exit, even if the
  managed body raises.

  Args:
    mode: the mode to switch to; compared against `EAGER_MODE`.
  """
  thread_state = self._thread_local_data
  previous_is_eager = thread_state.is_eager
  entering_eager = mode == EAGER_MODE
  thread_state.is_eager = entering_eager
  if entering_eager:
    # Entering graph mode does not provide us with sufficient information to
    # record a context switch; graph-based context switches are only logged
    # when a graph is registered as the default graph.
    self.context_switches.push(False, eager_mode, None)
  try:
    yield
  finally:
    thread_state.is_eager = previous_is_eager
    if entering_eager:
      # Undo the push performed on entry.
      self.context_switches.pop()
def executing_eagerly(self):
  """Returns True if current thread has eager executing enabled."""
  thread_state = self._thread_local_data
  return thread_state.is_eager
def ones_rank_cache(self):
  """Per-device cache for scalars.

  Returns:
    The `ones_rank_cache` stored in `_tensor_caches_map` under this
    context's id.
  """
  caches = _tensor_caches_map[self._id]
  return caches.ones_rank_cache
def zeros_cache(self):
  """Per-device cache for scalars.

  Returns:
    The `zeros_cache` stored in `_tensor_caches_map` under this context's id.
  """
  caches = _tensor_caches_map[self._id]
  return caches.zeros_cache
@property
def scope_name(self):
  """Returns scope name for the current thread."""
  thread_state = self._thread_local_data
  return thread_state.scope_name

@scope_name.setter
def scope_name(self, s):
  """Sets scope name for the current thread."""
  thread_state = self._thread_local_data
  thread_state.scope_name = s
@property
def device_name(self):
  """Returns the device name for the current thread."""
  thread_state = self._thread_local_data
  return thread_state.device_name
@property
def device_spec(self):
  """Returns the device spec for the current thread."""
  thread_state = self._thread_local_data
  return thread_state.device_spec
def _set_device(self, device_name, device_spec):
  """Stores the device name and device spec on the thread-local state.

  Args:
    device_name: value to store as the thread-local `device_name`.
    device_spec: value to store as the thread-local `device_spec`.
  """
  thread_state = self._thread_local_data
  thread_state.device_name = device_name
  thread_state.device_spec = device_spec
def device(self, name):
"""Context-manager to force placement of operations and Tensors on a device.
Args:
name: Name of the device or None to get default placement.
Returns:
Context manager that forces device placement.
Raises:
ValueError: If name is not a string or is an invalid device name.
RuntimeError: If device scopes are not properly nested.
"""
if isinstance(name, LogicalDevice):
name = name.name