# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to use variables as resources."""
# pylint: disable=g-bad-name
import contextlib
import functools
import weakref
from absl import logging
import numpy as np
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.core.function import trace_type
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.checkpoint import tensor_callable
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.compat import compat as forward_compat
from tensorflow.python.eager import context
from tensorflow.python.eager import record
from tensorflow.python.eager import tape
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import composite_tensor_gradient
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_module
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
acd.register_read_only_resource_op("ReadVariableOp")
acd.register_read_only_resource_op("VariableShape")
acd.register_read_only_resource_op("ResourceGather")
acd.register_read_only_resource_op("ResourceGatherNd")
acd.register_read_only_resource_op("_ReadVariablesOp")
# TODO(allenl): Remove this alias and migrate callers.
get_resource_handle_data = handle_data_util.get_resource_handle_data
def get_eager_safe_handle_data(handle):
"""Get the data handle from the Tensor `handle`."""
assert isinstance(handle, tensor_module.Tensor)
if isinstance(handle, ops.EagerTensor):
return handle._handle_data # pylint: disable=protected-access
else:
return get_resource_handle_data(handle)
def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
"""Sets the shape inference result HandleData on tensor.
Args:
tensor: A `Tensor` or `EagerTensor`.
handle_data: A `CppShapeInferenceResult.HandleData`.
graph_mode: A python bool.
"""
tensor._handle_data = handle_data # pylint: disable=protected-access
if not graph_mode:
return
# Not an EagerTensor, so a graph tensor.
shapes, types = zip(
*[(pair.shape, pair.dtype) for pair in handle_data.shape_and_type])
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [
[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
if not s.unknown_rank else None for s in shapes
]
with tensor._op.graph._c_graph.get() as c_graph: # pylint: disable=protected-access
pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
c_graph,
tensor._as_tf_output(), # pylint: disable=protected-access
shapes,
ranks,
types)
def _combine_handle_data(handle, initial_value):
"""Concats HandleData from tensors `handle` and `initial_value`.
Args:
handle: A `Tensor` of dtype `resource`.
initial_value: A `Tensor`.
Returns:
A `CppShapeInferenceResult.HandleData`. If `initial_value` has dtype
`variant`, the `HandleData` contains the concatenation of the shape_and_type
from both `handle` and `initial_value`.
Raises:
RuntimeError: If handle, which was returned by VarHandleOp, either has
no handle data, or its len(handle_data.shape_and_type) != 1.
"""
assert handle.dtype == dtypes.resource
variable_handle_data = get_eager_safe_handle_data(handle)
if initial_value.dtype != dtypes.variant:
return variable_handle_data
extra_handle_data = get_eager_safe_handle_data(initial_value)
if extra_handle_data is not None and extra_handle_data.is_set:
if (variable_handle_data is None or not variable_handle_data.is_set or
len(variable_handle_data.shape_and_type) != 1):
raise RuntimeError(
"Expected VarHandleOp to return a length==1 shape_and_type, "
f"but saw: '{variable_handle_data}'")
variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
return variable_handle_data
def _variable_handle_from_shape_and_dtype(shape,
dtype,
shared_name,
name,
graph_mode,
initial_value=None):
"""Create a variable handle, copying in handle data from `initial_value`."""
container = ops.get_default_graph()._container # pylint: disable=protected-access
if container is None:
container = ""
shape = tensor_shape.as_shape(shape)
dtype = dtypes.as_dtype(dtype)
if not graph_mode:
if shared_name is not None:
raise errors.InternalError(
node_def=None,
op=None,
message="Using an explicit shared_name is "
"not allowed when executing eagerly.")
shared_name = context.anonymous_name()
handle = gen_resource_variable_ops.var_handle_op(
shape=shape,
dtype=dtype,
shared_name=shared_name,
debug_name=name,
name=name,
container=container)
if initial_value is None:
initial_value = handle
if graph_mode:
full_handle_data = _combine_handle_data(handle, initial_value)
_set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
return handle
else:
handle_data = handle_data_util.create_handle_data(shape, dtype)
if initial_value is not None and initial_value.dtype == dtypes.variant:
extra_handle_data = get_eager_safe_handle_data(initial_value)
if extra_handle_data is not None and extra_handle_data.is_set:
if (not handle_data.is_set or len(handle_data.shape_and_type) != 1):
raise RuntimeError(
"Expected VarHandleOp to return a length==1 shape_and_type, "
f"but saw: '{handle_data}'")
handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
_set_handle_shapes_and_types(handle, handle_data, graph_mode)
return handle
def eager_safe_variable_handle(initial_value, shape, shared_name, name,
graph_mode):
"""Creates a variable handle with information to do shape inference.
The dtype is read from `initial_value` and stored in the returned
resource tensor's handle data.
If `initial_value.dtype == tf.variant`, we additionally extract the handle
data (if any) from `initial_value` and append it to the `handle_data`.
In this case, the returned tensor's handle data is in the form
```
is_set: true
shape_and_type {
shape {
// initial_value.shape
}
dtype: DT_VARIANT
}
shape_and_type {
// handle_data(initial_value).shape_and_type[0]
}
shape_and_type {
// handle_data(initial_value).shape_and_type[1]
}
...
```
Ops that read from this tensor, such as `ReadVariableOp` and
`AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
correspond to the handle data of the variant(s) stored in the Variable.
Args:
initial_value: A `Tensor`.
shape: The shape of the handle data. Can be `TensorShape(None)` (i.e.
unknown shape).
shared_name: A string.
name: A string.
graph_mode: A python bool.
Returns:
The handle, a `Tensor` of type `resource`.
"""
dtype = initial_value.dtype.base_dtype
return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name,
graph_mode, initial_value)
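# Note on `eager_safe_variable_handle` (illustrative sketch, not part of the
# library): every `tf.Variable` is backed by a DT_RESOURCE handle tensor that
# flows through the path above. Assuming TensorFlow is imported as `tf` and
# eager execution is enabled:
#   v = tf.Variable([1.0, 2.0])
#   v.handle.dtype                                   # tf.resource
#   tf.raw_ops.ReadVariableOp(resource=v.handle,
#                             dtype=tf.float32)      # [1.0, 2.0]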
@contextlib.contextmanager
def _handle_graph(handle):
# Note: might have an eager tensor but not be executing eagerly when building
# functions.
if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or
ops.has_default_graph()):
yield
else:
with handle.graph.as_default():
yield
class EagerResourceDeleter:
"""An object which cleans up a resource handle.
An alternative to defining a __del__ method on an object. The intended use is
that ResourceVariables or other objects with resource handles will maintain a
single reference to this object. When the parent object is collected, this
object will be too. Even if the parent object is part of a reference cycle,
the cycle will be collectable.
"""
__slots__ = ["_handle", "_handle_device", "_context"]
def __init__(self, handle, handle_device):
if not isinstance(handle, tensor_module.Tensor):
raise ValueError(
(f"Passed handle={handle} to EagerResourceDeleter. Was expecting "
f"the handle to be a `tf.Tensor`."))
self._handle = handle
self._handle_device = handle_device
# This is held since the __del__ function runs an op, and if the context()
# is collected before this object, there will be a segfault when running the
# op.
self._context = context.context()
def __del__(self):
# Resources follow object-identity when executing eagerly, so it is safe to
# delete the resource we have a handle to.
try:
# A packed EagerTensor doesn't own any resource.
if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed:
return
# This resource was created in eager mode. However, this destructor may be
# running in graph mode (especially during unit tests). To clean up
# successfully, we switch back into eager mode temporarily.
with context.eager_mode():
with ops.device(self._handle_device):
gen_resource_variable_ops.destroy_resource_op(
self._handle, ignore_lookup_error=True)
except TypeError:
# Suppress some exceptions, mainly for the case when we're running on
# module deletion. Things that can go wrong include the context module
# already being unloaded, self._handle._handle_data no longer being
# valid, and so on. Printing warnings in these cases is silly
# (exceptions raised from __del__ are printed as warnings to stderr).
pass # 'NoneType' object is not callable when the handle has been
# partially unloaded.
except AttributeError:
pass # 'NoneType' object has no attribute 'eager_mode' when context has
# been unloaded. Will catch other module unloads as well.
def shape_safe_assign_variable_handle(handle, shape, value, name=None):
"""Helper that checks shape compatibility and assigns variable."""
with _handle_graph(handle):
value_tensor = ops.convert_to_tensor(value)
shape.assert_is_compatible_with(value_tensor.shape)
return gen_resource_variable_ops.assign_variable_op(
handle, value_tensor, name=name)
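# Note on `shape_safe_assign_variable_handle` (illustrative sketch, not part
# of the library): this is the shape check behind `Variable.assign` when
# `validate_shape=True`. Assuming eager execution:
#   v = tf.Variable([1.0, 2.0])
#   v.assign([3.0, 4.0])        # OK: shapes are compatible
#   v.assign([1.0, 2.0, 3.0])   # raises ValueError: incompatible shapes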
def _maybe_set_handle_data(dtype, handle, tensor):
if dtype == dtypes.variant:
# For DT_VARIANT types, the handle's shape_and_type[1:] stores the
# variant's handle data. Extract it.
handle_data = get_eager_safe_handle_data(handle)
if handle_data.is_set and len(handle_data.shape_and_type) > 1:
tensor._handle_data = ( # pylint: disable=protected-access
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
def variable_accessed(variable):
"""Records that `variable` was accessed for the tape and FuncGraph."""
if hasattr(ops.get_default_graph(), "watch_variable"):
ops.get_default_graph().watch_variable(variable)
if variable.trainable:
tape.variable_accessed(variable)
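# Note on `variable_accessed` (illustrative sketch, not part of the library):
# this is what makes trainable variables show up on a GradientTape without an
# explicit `tape.watch`, e.g.:
#   v = tf.Variable(2.0)
#   with tf.GradientTape() as g:
#       y = v * v            # reading v records it on the tape
#   g.gradient(y, v)         # -> 4.0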
def default_variable_creator_v2(next_creator=None, **kwargs):
"""Default variable creator."""
assert next_creator is None
initial_value = kwargs.get("initial_value", None)
trainable = kwargs.get("trainable", None)
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
name = kwargs.get("name", None)
variable_def = kwargs.get("variable_def", None)
dtype = kwargs.get("dtype", None)
import_scope = kwargs.get("import_scope", None)
constraint = kwargs.get("constraint", None)
distribute_strategy = kwargs.get("distribute_strategy", None)
synchronization = kwargs.get("synchronization", None)
aggregation = kwargs.get("aggregation", None)
shape = kwargs.get("shape", None)
experimental_enable_variable_lifting = kwargs.get(
"experimental_enable_variable_lifting", None)
return ResourceVariable(
initial_value=initial_value,
trainable=trainable,
validate_shape=validate_shape,
caching_device=caching_device,
name=name,
dtype=dtype,
constraint=constraint,
variable_def=variable_def,
import_scope=import_scope,
distribute_strategy=distribute_strategy,
synchronization=synchronization,
aggregation=aggregation,
shape=shape,
experimental_enable_variable_lifting=experimental_enable_variable_lifting,
)
class BaseResourceVariable(variables.Variable, core.Tensor):
"""A python variable from an existing handle."""
# TODO(wangpeng): Deprecate `constraint` when callers no longer pass it in.
def __init__( # pylint: disable=super-init-not-called
self,
trainable=None,
shape=None,
dtype=None,
handle=None,
constraint=None,
synchronization=None,
aggregation=None,
distribute_strategy=None,
name=None,
unique_id=None,
handle_name=None,
graph_element=None,
initial_value=None,
initializer_op=None,
is_initialized_op=None,
cached_value=None,
save_slice_info=None,
caching_device=None,
in_graph_mode=None,
validate_shape=True,
**unused_kwargs):
"""Creates a variable from a handle.
Args:
trainable: If `True`, GradientTapes automatically watch uses of this
Variable.
shape: The variable's shape. This shape can be set to tf.TensorShape(None)
in order to assign values of different shapes to this variable.
Otherwise (i.e. if the shape is fully determined), it will trigger run
time checks to ensure that each assignment is of the same shape.
dtype: The variable's dtype.
handle: The variable's handle
constraint: An optional projection function to be applied to the variable
after being updated by an `Optimizer` (e.g. used to implement norm
constraints or value constraints for layer weights). The function must
take as input the unprojected Tensor representing the value of the
variable and return the Tensor for the projected value (which must have
the same shape). Constraints are not safe to use when doing asynchronous
distributed training.
synchronization: Indicates when a distributed variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set to
`AUTO` and the current `DistributionStrategy` chooses when to
synchronize.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
`tf.VariableAggregation`.
distribute_strategy: The distribution strategy this variable was created
under.
name: The name for this variable.
unique_id: Internal. Unique ID for this variable's handle.
handle_name: The name for the variable's handle.
graph_element: Optional, required only in session.run-mode. Pre-created
tensor which reads this variable's value.
initial_value: Optional. Variable's initial value.
initializer_op: Operation which assigns the variable's initial value.
is_initialized_op: Pre-created operation to check whether this variable is
initialized.
cached_value: Pre-created operation to read this variable in a specific
device.
save_slice_info: Metadata for variable partitioning.
caching_device: Optional device string or function describing where the
Variable should be cached for reading. Defaults to the Variable's
device. If not `None`, caches on another device. Typical use is to
cache on the device where the Ops using the Variable reside, to
deduplicate copying through `Switch` and other conditional statements.
in_graph_mode: whether we are executing in TF1 graph mode. If None, will
detect within the function. This is to avoid repeated init_scope()
context entrances, which can add up.
validate_shape: If `False`, allows the variable to be initialized with a
value of unknown shape. If `True`, the default, the shape of
`initial_value` must be known.
"""
if in_graph_mode is None:
with ops.init_scope():
self._in_graph_mode = not context.executing_eagerly()
else:
self._in_graph_mode = in_graph_mode
synchronization, aggregation, trainable = (
variables.validate_synchronization_aggregation_trainable(
synchronization, aggregation, trainable, name))
self._trainable = trainable
self._synchronization = synchronization
self._aggregation = aggregation
self._save_slice_info = save_slice_info
self._initial_value = initial_value
self._initializer_op = initializer_op
self._is_initialized_op = is_initialized_op
self._graph_element = graph_element
self._caching_device = caching_device
self._cached_value = cached_value
self._distribute_strategy = distribute_strategy
# Store the graph key so optimizers know how to only retrieve variables from
# this graph. Guaranteed to be the same as the eager graph_key.
self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
self._shape = tensor_shape.as_shape(shape)
self._dtype = dtypes.as_dtype(dtype)
self._handle = handle
self._unique_id = unique_id
if handle_name is None:
self._handle_name = "Variable:0"
else:
self._handle_name = handle_name + ":0"
self._constraint = constraint
self._cached_shape_as_list = None
self._validate_shape = validate_shape
self._xla_sharding = None
self._variable_read = False
def _get_xla_sharding(self):
return self._xla_sharding
def _set_xla_sharding(self, xla_sharding):
"""Annotates this `ResourceVariable` with `xla_sharding`.
`xla_sharding` will be used to create an `XlaShardingOp` whenever a
`ReadVariableOp` is created.
Args:
xla_sharding: The xla.OpSharding proto to annotate this ResourceVariable
with.
"""
if self._variable_read and not context.executing_eagerly():
logging.warning(
"This variable (%s) has already been read (ie. a ReadVariableOp has"
" already been generated) and a new XlaShardingOp using this sharding"
" will not be created unless it is read again. If that's not possible"
", please set the XLA sharding before reading the variable.",
self.name,
)
self._xla_sharding = xla_sharding
def __repr__(self):
if context.executing_eagerly() and not self._in_graph_mode:
# If we cannot read the value for any reason (e.g. variable uninitialized
# during tf.function tracing), still produce a __repr__. Note that for
# async eager, errors due to uninitialized variables will raise in
# ops.value_text when the handle is resolved, so we need to keep that
# under the try...except if we want to suppress them.
try:
with ops.device(self.device):
value_text = ops.value_text(self.read_value(), is_repr=True)
except: # pylint: disable=bare-except
value_text = "numpy=<unavailable>"
return "<tf.Variable '%s' shape=%s dtype=%s, %s>" % (
self.name, self.get_shape(), self.dtype.name, value_text)
else:
return "<tf.Variable '%s' shape=%s dtype=%s>" % (
self.name, self.get_shape(), self.dtype.name)
def __tf_tracing_type__(self, signature_context):
alias_id = signature_context.alias_global_id(self._handle._id) # pylint:disable=protected-access
# TODO(xjun): Create variable placeholders directly from VariableSpec
# without using original values.
signature_context.add_placeholder(alias_id, self)
return VariableSpec(shape=self.shape,
dtype=self.dtype,
trainable=self.trainable,
alias_id=alias_id)
@contextlib.contextmanager
def _assign_dependencies(self):
"""Makes assignments depend on the cached value, if any.
This prevents undefined behavior with reads not ordered wrt writes.
Yields:
None.
"""
if self._cached_value is not None:
with ops.control_dependencies([self._cached_value]):
yield
else:
yield
def __array__(self, dtype=None):
"""Allows direct conversion to a numpy array.
>>> np.array(tf.Variable([1.0]))
array([1.], dtype=float32)
Returns:
The variable value as a numpy array.
"""
# You can't return `self.numpy()` here because for scalars
# that raises:
# ValueError: object __array__ method not producing an array
# Even `self.read_value().__array__()` and `self.read_value()._numpy()` give
# the same error. The `EagerTensor` class must be doing something behind the
# scenes to make `np.array(tf.constant(1))` work.
return np.asarray(self.numpy(), dtype=dtype)
def __nonzero__(self):
return self.__bool__()
def __bool__(self):
return bool(self.read_value())
def __copy__(self):
return self
def __deepcopy__(self, memo):
if not context.executing_eagerly():
raise NotImplementedError(
"__deepcopy__() is only available when eager execution is enabled.")
copied_variable = ResourceVariable(
initial_value=self.read_value(),
trainable=self._trainable,
constraint=self._constraint,
dtype=self._dtype,
name=self._shared_name,
distribute_strategy=self._distribute_strategy,
synchronization=self.synchronization,
aggregation=self.aggregation)
memo[self._unique_id] = copied_variable
return copied_variable
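# Note on `__deepcopy__` (illustrative sketch, not part of the library):
# deep-copying is only supported under eager execution and produces an
# independent variable with the same value, e.g.:
#   import copy
#   v = tf.Variable([1.0, 2.0])
#   w = copy.deepcopy(v)     # new ResourceVariable initialized from v's value
#   w.assign([3.0, 4.0])     # does not affect v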
@property
def dtype(self):
"""The dtype of this variable."""
return self._dtype
@property
def device(self):
"""The device this variable is on."""
return self.handle.device
@property
def graph(self):
"""The `Graph` of this variable."""
return self.handle.graph
@property
def name(self):
"""The name of the handle for this variable."""
return self._handle_name
@property
def shape(self):
"""The shape of this variable."""
return self._shape
def set_shape(self, shape):
self._shape = self._shape.merge_with(shape)
def _shape_as_list(self):
if self.shape.ndims is None:
return None
return [dim.value for dim in self.shape.dims]
def _shape_tuple(self):
shape = self._shape_as_list()
if shape is None:
return None
return tuple(shape)
@property
def create(self):
"""The op responsible for initializing this variable."""
if not self._in_graph_mode:
raise RuntimeError("This operation is not supported "
"when eager execution is enabled.")
return self._initializer_op
@property
def handle(self):
"""The handle by which this variable can be accessed."""
return self._handle
def value(self):
"""A cached operation which reads the value of this variable."""
if self._cached_value is not None:
return self._cached_value
with ops.colocate_with(None, ignore_existing=True):
return self._read_variable_op()
def _as_graph_element(self):
"""Conversion function for Graph.as_graph_element()."""
return self._graph_element
@property
def initializer(self):
"""The op responsible for initializing this variable."""
return self._initializer_op
@property
def initial_value(self):
"""Returns the Tensor used as the initial value for the variable."""
if context.executing_eagerly():
raise RuntimeError("This property is not supported "
"when eager execution is enabled.")
return self._initial_value
@property
def constraint(self):
"""Returns the constraint function associated with this variable.
Returns:
The constraint function that was passed to the variable constructor.
Can be `None` if no constraint was passed.
"""
return self._constraint
@property
def op(self) -> ops.Operation:
"""The op for this variable."""
return self.handle.op
@property
def trainable(self):
return self._trainable
@property
def synchronization(self):
return self._synchronization
@property
def aggregation(self):
return self._aggregation
def eval(self, session=None):
"""Evaluates and returns the value of this variable."""
if context.executing_eagerly():
raise RuntimeError("This operation is not supported "
"when eager execution is enabled.")
return self._graph_element.eval(session=session)
def numpy(self):
if context.executing_eagerly():
return self.read_value().numpy()
raise NotImplementedError(
"numpy() is only available when eager execution is enabled.")
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(self, limit):
"""Increments this variable until it reaches `limit`.
When that Op is run it tries to increment the variable by `1`. If
incrementing the variable would bring it above `limit` then the Op raises
the exception `OutOfRangeError`.
If no error is raised, the Op outputs the value of the variable before
the increment.
This is essentially a shortcut for `count_up_to(self, limit)`.
Args:
limit: value at which incrementing the variable raises an error.
Returns:
A `Tensor` that will hold the variable value before the increment. If no
other Op modifies this variable, the values produced will all be
distinct.
"""
return gen_state_ops.resource_count_up_to(
self.handle, limit=limit, T=self.dtype)
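# Note on `count_up_to` (illustrative sketch, not part of the library): the
# method is deprecated in favor of iterating a tf.data pipeline, e.g.:
#   for i in tf.data.Dataset.range(3):
#       print(int(i))        # prints 0, 1, 2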
def _copy_trackable_to_cpu(self, object_map):
"""For implementing `Trackable`."""
if self not in object_map:
# If not populated, initialize the cpu copy first.
op_device = pydev.DeviceSpec.from_string(self.device).replace(
device_type="CPU", device_index=0).to_string()
with ops.device(op_device):
# Use `op_device` to prevent cross-device communication for variables
# like `ShardedVariable`
new_var = UninitializedVariable(
trainable=self.trainable,
shape=self.shape,
dtype=self.dtype,
name=self._shared_name) # pylint: disable=protected-access
object_map[self] = new_var
# Then copy value of self to the copy.
destination_var = object_map[self]
with ops.device(destination_var.device):
# Use `op_device` to prevent cross-device communication for variables
# like `ShardedVariable`
destination_var.assign(self.read_value())
def _export_to_saved_model_graph(self, object_map=None, tensor_map=None,
options=None, **kwargs):
"""For implementing `Trackable`."""
new_variable = None
if options.experimental_variable_policy._save_variable_devices(): # pylint:disable=protected-access
with ops.device(self.device):
new_variable = copy_to_graph_uninitialized(self)
else:
new_variable = copy_to_graph_uninitialized(self)
object_map[self] = new_variable
tensor_map[self.handle] = new_variable.handle
return [self.handle]
def _serialize_to_tensors(self):
"""Implements Trackable._serialize_to_tensors."""
def _read_variable_closure():
v = self
with ops.device(v.device):
if context.executing_eagerly() and not v.is_initialized():
# A SaveSpec tensor value of `None` indicates that the variable is
# uninitialized.
return None
# Read the variable without making a copy to limit memory usage.
x = v.read_value_no_copy()
# To allow variables placed on non-CPU devices to be checkpointed,
# we copy them to CPU on the same machine first.
with ops.device("/device:CPU:0"):
return array_ops.identity(x)
return {
trackable.VARIABLE_VALUE_KEY:
tensor_callable.Callable(
_read_variable_closure, dtype=self.dtype, device=self.device)
}
def _restore_from_tensors(self, restored_tensors):
"""Implements Trackable._restore_from_tensors."""
with ops.device(self.device):
restored_tensor = array_ops.identity(
restored_tensors[trackable.VARIABLE_VALUE_KEY])
try:
assigned_variable = shape_safe_assign_variable_handle(
self.handle, self.shape, restored_tensor)
except ValueError as e:
raise ValueError(
f"Received incompatible tensor with shape {restored_tensor.shape} "
f"when attempting to restore variable with shape {self.shape} "
f"and name {self.name}.") from e
return assigned_variable
def _read_variable_op(self, no_copy=False):
"""Reads the value of the variable.
If the variable is in copy-on-read mode and `no_copy` is True, the variable
is converted to copy-on-write mode before it is read.
Args:
no_copy: Whether to prevent a copy of the variable.
Returns:
The value of the variable.
"""
variable_accessed(self)
self._variable_read = True
def read_and_set_handle(no_copy):
if no_copy and forward_compat.forward_compatible(2022, 5, 3):
gen_resource_variable_ops.disable_copy_on_read(self.handle)
result = gen_resource_variable_ops.read_variable_op(
self.handle, self._dtype)
_maybe_set_handle_data(self._dtype, self.handle, result)
return result
if getattr(self, "_caching_device", None) is not None:
with ops.colocate_with(None, ignore_existing=True):
with ops.device(self._caching_device):
result = read_and_set_handle(no_copy)
else:
result = read_and_set_handle(no_copy)
if not context.executing_eagerly():
# Note that if a control flow context is active the input of the read op
# might not actually be the handle. This line bypasses it.
record.record_operation(
"ReadVariableOp", [result], [self.handle],
backward_function=lambda x: [x],
forward_function=lambda x: [x])
# Create an XlaShardingOp if this ResourceVariable is annotated with an XLA
# sharding i.e. the _xla_sharding field is set. Please see the design at
# http://shortn/_RGoruJpzrv for more details.
if (
context.xla_sharding_for_resource_variables_enabled()
and not context.executing_eagerly()
and self._xla_sharding is not None
):
sharding_string = self._xla_sharding.SerializeToString()
result = gen_xla_ops.xla_sharding(result, sharding=sharding_string)
# pylint: disable=protected-access
result.op._set_attr(
"_XlaSharding",
attr_value_pb2.AttrValue(s=sharding_string),
)
return result
def read_value(self):
"""Constructs an op which reads the value of this variable.
Should be used when there are multiple reads, or when it is desirable to
read the value only after some condition is true.
Returns:
The value of the variable.
"""
with ops.name_scope("Read"):
value = self._read_variable_op()
# Return an identity so it can get placed on whatever device the context
# specifies instead of the device where the variable is.
return array_ops.identity(value)
def read_value_no_copy(self):
"""Constructs an op which reads the value of this variable without copy.
The variable is read without making a copy even when it has been sparsely
accessed. Variables in copy-on-read mode will be converted to copy-on-write
mode.
Returns:
The value of the variable.
"""
with ops.name_scope("Read"):
value = self._read_variable_op(no_copy=True)
# Return an identity so it can get placed on whatever device the context
# specifies instead of the device where the variable is.
return array_ops.identity(value)
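# Note on `read_value` / `value` (illustrative sketch, not part of the
# library): `read_value()` always builds a fresh read, while `value()` returns
# the cached snapshot when a `caching_device` is set. Assuming eager execution:
#   v = tf.Variable(3.0)
#   a = v.read_value()       # explicit ReadVariableOp wrapped in an Identity
#   b = v.value()            # cached read if one exists, otherwise a new read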
def sparse_read(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
variable_accessed(self)
value = gen_resource_variable_ops.resource_gather(
self.handle, indices, dtype=self._dtype, name=name)
if self._dtype == dtypes.variant:
# For DT_VARIANT types, the handle's shape_and_type[1:] stores the
# variant's handle data. Extract it.
handle_data = get_eager_safe_handle_data(self.handle)
if handle_data.is_set and len(handle_data.shape_and_type) > 1:
value._handle_data = ( # pylint: disable=protected-access
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
return array_ops.identity(value)
return value
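# Note on `sparse_read` (illustrative sketch, not part of the library): it
# reads selected rows through a ResourceGather op, matching `tf.gather` on the
# variable's value:
#   v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
#   v.sparse_read([1])       # [[3.0, 4.0]]
#   tf.gather(v, [1])        # same values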
def gather_nd(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather_nd`."""
with ops.name_scope("GatherNd" if name is None else name) as name:
if self.trainable:
variable_accessed(self)
value = gen_resource_variable_ops.resource_gather_nd(
self.handle, indices, dtype=self._dtype, name=name)
return array_ops.identity(value)
def to_proto(self, export_scope=None):
"""Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
Args:
export_scope: Optional `string`. Name scope to remove.
Raises:
RuntimeError: If run in EAGER mode.
Returns:
A `VariableDef` protocol buffer, or `None` if the `Variable` is not
in the specified name scope.
"""
if context.executing_eagerly():
raise RuntimeError("This operation is not supported "
"when eager execution is enabled.")
if export_scope is None or self.handle.name.startswith(export_scope):
var_def = variable_pb2.VariableDef()
var_def.variable_name = ops.strip_name_scope(self.handle.name,
export_scope)
if self._initial_value is not None:
# This is inside an if-statement for backwards compatibility, since
# self._initial_value might be None for variables constructed from old
# protos.
var_def.initial_value_name = ops.strip_name_scope(
self._initial_value.name, export_scope)
var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
export_scope)
if self._cached_value is not None:
var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
export_scope)
else:
# Store the graph_element here
var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
export_scope)
var_def.is_resource = True
var_def.trainable = self.trainable
var_def.synchronization = self.synchronization.value
var_def.aggregation = self.aggregation.value
if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(
self._save_slice_info.to_proto(export_scope=export_scope))
return var_def
else:
return None
@staticmethod
def from_proto(variable_def, import_scope=None):
if context.executing_eagerly():
raise RuntimeError("This operation is not supported "
"when eager execution is enabled.")
return ResourceVariable(
variable_def=variable_def, import_scope=import_scope)
__array_priority__ = 100
def is_initialized(self, name=None):
"""Checks whether a resource variable has been initialized.
Outputs a boolean scalar indicating whether the tensor has been initialized.
Args:
name: A name for the operation (optional).
Returns:
A `Tensor` of type `bool`.
"""
return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
"""Subtracts a value from this variable.
Args: