/
stateful_random_ops.py
1029 lines (860 loc) · 38.9 KB
/
stateful_random_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for generating random numbers."""
import six
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.stateless_random_ops import Algorithm
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Seeds for random ops (stateful and stateless alike) are always 1024 bits
# wide, and every bit is forwarded to the C++ code. A given C++ algorithm
# implementation may consume only a lower part of those bits.
UINT64_HALF_SPAN = 1 << 63
UINT64_SPAN = 1 << 64
MIN_INT64 = -UINT64_HALF_SPAN
MAX_INT64 = UINT64_HALF_SPAN - 1
# 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained in
# b/111604096 and cl/171681867), so a signed int is used here. int64 is chosen
# over int32 because `VarHandleOp` doesn't support int32 on GPU.
SEED_TYPE = "int64"
SEED_TYPE_BITS = 64
SEED_SIZE = 16  # in units of SEED_TYPE (16 * 64 bits = 1024 bits)
SEED_MIN, SEED_MAX = MIN_INT64, MAX_INT64
SEED_UINT_SPAN = UINT64_SPAN
# All-ones mask covering exactly one SEED_TYPE-sized word.
SEED_BIT_MASK = (1 << SEED_TYPE_BITS) - 1
# States and algorithm ids are stored with the same dtype as seeds.
STATE_TYPE = ALGORITHM_TYPE = SEED_TYPE
# Number of STATE_TYPE words in the state of each algorithm.
PHILOX_STATE_SIZE = 3
THREEFRY_STATE_SIZE = 2
# Ids of the supported RNG algorithms, read from the `.value` of the
# corresponding `Algorithm` members.
RNG_ALG_PHILOX = Algorithm.PHILOX.value
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value
# Algorithm used when a caller does not pick one explicitly.
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
def non_deterministic_ints(shape, dtype=dtypes.int64):
  """Produces integers whose values differ from one execution to the next.

  The underlying op may draw on a source of non-determinism supplied by the
  operating system (an RNG, for instance), so repeated executions do not
  reproduce earlier results.

  Args:
    shape: shape of the tensor to produce.
    dtype: (optional) dtype of the tensor to produce.

  Returns:
    A tensor whose element values are chosen non-deterministically.
  """
  generate = gen_stateful_random_ops.non_deterministic_ints
  return generate(shape=shape, dtype=dtype)
def _uint_to_int(n):
  """Shifts a Python int above SEED_MAX down by SEED_UINT_SPAN.

  Anything that is not a Python int, or that is not greater than SEED_MAX, is
  returned unchanged.
  """
  if not isinstance(n, int) or n <= SEED_MAX:
    return n
  return n - SEED_UINT_SPAN
def _make_1d_state(state_size, seed):
  """Builds a 1-D RNG state of length `state_size` out of `seed`.

  Args:
    state_size: an integer, the number of words in the state.
    seed: an integer or 1-D tensor.

  Returns:
    A 1-D tensor with shape [state_size] and dtype STATE_TYPE.
  """
  if isinstance(seed, six.integer_types):
    # Slice the arbitrary-precision Python integer into SEED_TYPE_BITS-wide
    # words, least-significant word first.
    seed = [(seed >> (SEED_TYPE_BITS * word_index)) & SEED_BIT_MASK
            for word_index in range(state_size)]
  # Map words above SEED_MAX into the signed range so that converting them to
  # a tensor doesn't raise an overflow error.
  signed_seed = nest.map_structure(_uint_to_int, seed)
  words = math_ops.cast(signed_seed, STATE_TYPE)
  words = array_ops.reshape(words, [-1])
  words = words[0:state_size]
  # A seed that is too short gets zeros added on the *left*. Adding them on
  # the right would turn a small seed into the "counter" while leaving the
  # "key" at zero (for counter-based RNG algorithms), since the counter comes
  # before the key in the current memory layout. Two RNGs built from two
  # different small seeds could then produce overlapping outputs.
  num_words = words.shape[0]
  if num_words is None:
    # Static length unknown; fall back to the dynamic shape.
    num_words = array_ops.shape(words)[0]
  num_missing = math_ops.maximum(state_size - num_words, 0)
  zeros = array_ops.zeros([num_missing], words.dtype)
  # `pad` can't be used because it doesn't support integer dtypes on GPU.
  state = array_ops.concat([zeros, words], axis=0)
  state.set_shape([state_size])
  return state
def _get_counter_size(alg):
  """Returns the counter length for algorithm id `alg`.

  Gives 2 for Philox and 1 for ThreeFry.

  Raises:
    ValueError: if `alg` is neither RNG_ALG_PHILOX nor RNG_ALG_THREEFRY.
  """
  if alg == RNG_ALG_PHILOX:
    return 2
  if alg == RNG_ALG_THREEFRY:
    return 1
  raise ValueError(
      f"Argument `alg` got unsupported value {alg}. Supported values are "
      f"{RNG_ALG_PHILOX} for the Philox algorithm and {RNG_ALG_THREEFRY} for "
      f"the ThreeFry algorithm.")
def _get_state_size(alg):
  """Returns the state length for algorithm id `alg`.

  Gives PHILOX_STATE_SIZE for Philox and THREEFRY_STATE_SIZE for ThreeFry.

  Raises:
    ValueError: if `alg` is neither RNG_ALG_PHILOX nor RNG_ALG_THREEFRY.
  """
  if alg == RNG_ALG_PHILOX:
    return PHILOX_STATE_SIZE
  if alg == RNG_ALG_THREEFRY:
    return THREEFRY_STATE_SIZE
  raise ValueError(
      f"Argument `alg` got unsupported value {alg}. Supported values are "
      f"{RNG_ALG_PHILOX} for the Philox algorithm and {RNG_ALG_THREEFRY} for "
      f"the ThreeFry algorithm.")
def _check_state_shape(shape, alg):
  """Asserts `shape` is compatible with the state shape of algorithm `alg`.

  The check is skipped when `alg` is a Tensor and eager execution is off,
  because `int(alg)` is only attempted otherwise.
  """
  alg_is_graph_tensor = (
      isinstance(alg, ops.Tensor) and not context.executing_eagerly())
  if not alg_is_graph_tensor:
    expected_shape = [_get_state_size(int(alg))]
    shape.assert_is_compatible_with(expected_shape)
def _make_state_from_seed(seed, alg):
  """Builds a state from `seed`, sized for algorithm id `alg`."""
  state_size = _get_state_size(alg)
  return _make_1d_state(state_size, seed)
@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
def create_rng_state(seed, alg):
  """Creates a RNG state from an integer or a vector.

  Example:

  >>> tf.random.create_rng_state(
  ...     1234, "philox")
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234, 0, 0])>
  >>> tf.random.create_rng_state(
  ...     [12, 34], "threefry")
  <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 34])>

  Args:
    seed: an integer or a 1-D vector of integers.
    alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

  Returns:
    A 1-D tensor (as shown in the example above) whose size depends on the
    algorithm.
  """
  alg_id = stateless_random_ops.convert_alg_to_int(alg)
  state_size = _get_state_size(alg_id)
  return _make_1d_state(state_size, seed)
def _shape_tensor(shape):
  """Converts `shape` to a tensor named "shape".

  An empty tuple or list is converted with dtype int64; for anything else the
  dtype is left for `ops.convert_to_tensor` to decide.
  """
  is_empty_sequence = isinstance(shape, (tuple, list)) and not shape
  dtype = dtypes.int64 if is_empty_sequence else None
  return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
def _convert_to_state_tensor(t):
  """Casts the (possibly nested) value `t` to a tensor of dtype STATE_TYPE."""
  # Python ints above SEED_MAX are first moved into the signed range, which
  # avoids an out-of-range error from ops.convert_to_tensor.
  in_signed_range = nest.map_structure(_uint_to_int, t)
  return math_ops.cast(in_signed_range, STATE_TYPE)
def get_replica_id():
  """Returns the current replica's id in its sync group, if there is one.

  Returns:
    `replica_id_in_sync_group` of the current replica context, or None when
    `ds_context.get_replica_context()` returns None.
  """
  replica_context = ds_context.get_replica_context()
  if replica_context is not None:
    return replica_context.replica_id_in_sync_group
  return None
@tf_export("random.Generator", "random.experimental.Generator")
class Generator(tracking.AutoTrackable):
"""Random-number generator.
Example:
Creating a generator from a seed:
>>> g = tf.random.Generator.from_seed(1234)
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.9356609 , 1.0854305 , -0.93788373],
[-0.5061547 , 1.3169702 , 0.7137579 ]], dtype=float32)>
Creating a generator from a non-deterministic state:
>>> g = tf.random.Generator.from_non_deterministic_state()
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
All the constructors allow explicitly choosing a Random-Number-Generation
(RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
example:
>>> g = tf.random.Generator.from_seed(123, alg="philox")
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.8673864 , -0.29899067, -0.9310337 ],
[-1.5828488 , 1.2481191 , -0.6770643 ]], dtype=float32)>
CPU, GPU and TPU with the same algorithm and seed will generate the same
integer random numbers. Floating-point results (such as the output of `normal`)
may have small numerical discrepancies between different devices.
This class uses a `tf.Variable` to manage its internal state. Every time
random numbers are generated, the state of the generator will change. For
example:
>>> g = tf.random.Generator.from_seed(1234)
>>> g.state
<tf.Variable ... numpy=array([1234, 0, 0])>
>>> g.normal(shape=(2, 3))
<...>
>>> g.state
<tf.Variable ... numpy=array([2770, 0, 0])>
The shape of the state is algorithm-specific.
There is also a global generator:
>>> g = tf.random.get_global_generator()
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
When creating a generator inside a `tf.distribute.Strategy` scope, each
replica will get a different stream of random numbers.
For example, in this code:
```
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
g = tf.random.Generator.from_seed(1)
def f():
return g.normal([])
results = strat.run(f).values
```
`results[0]` and `results[1]` will have different values.
If the generator is seeded (e.g. created via `Generator.from_seed`), the
random numbers will be determined by the seed, even though different replicas
get different numbers. One can think of a random number generated on a
replica as a hash of the replica ID and a "master" random number that may be
common to all replicas. Hence, the whole system is still deterministic.
(Note that the random numbers on different replicas are not correlated, even
if they are deterministically determined by the same seed. They are not
correlated in the sense that no matter what statistics one calculates on them,
there won't be any discernable correlation.)
Generators can be freely saved and restored using `tf.train.Checkpoint`. The
checkpoint can be restored in a distribution strategy with a different number
of replicas than the original strategy. If a replica ID is present in both the
original and the new distribution strategy, its state will be properly
restored (i.e. the random-number stream from the restored point will be the
same as that from the saving point) unless the replicas have already diverged
in their RNG call traces before saving (e.g. one replica has made one RNG call
while another has made two RNG calls). We don't have such guarantee if the
generator is saved in a strategy scope and restored outside of any strategy
scope, or vice versa.
When a generator is created within the scope of
`tf.distribute.experimental.ParameterServerStrategy`, the workers
will share the generator's state (placed on one of the parameter
servers). In this way the workers will still get different
random-number streams, as stated above. (This is similar to replicas
in a `tf.distribute.MirroredStrategy` sequentially accessing a
generator created outside the strategy.) Each RNG call on a worker
will incur a round-trip to a parameter server, which may have
performance impacts. When creating a
`tf.distribute.experimental.ParameterServerStrategy`, please make
sure that the `variable_partitioner` argument won't shard small
variables of shape `[2]` or `[3]` (because generator states must not
be sharded). Ways to avoid sharding small variables include setting
`variable_partitioner` to `None` or to
`tf.distribute.experimental.partitioners.MinSizePartitioner` with a
large enough `min_shard_bytes` (see
`tf.distribute.experimental.ParameterServerStrategy`'s documentation
for more details).
"""
@classmethod
def from_state(cls, state, alg):
  """Builds a generator directly from a state.

  Both arguments are forwarded to the constructor; see `__init__` for the
  description of `state` and `alg`.

  Args:
    state: the state the new generator starts from.
    alg: the RNG algorithm.

  Returns:
    The new generator.
  """
  return cls(state=state, alg=alg)
@classmethod
def from_seed(cls, seed, alg=None):
  """Builds a generator from a seed.

  A seed is a 1024-bit unsigned integer, given either as a Python integer or
  as a vector of integers. Seeds shorter than 1024 bits are padded. How the
  padding is done, how a seed is laid out internally and how it is turned
  into a state are all opaque (unspecified). The only semantic guarantee is
  that two different seeds are likely (though not guaranteed) to yield two
  independent generators.

  Args:
    seed: the seed for the RNG.
    alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
      `__init__` for its possible values.

  Returns:
    The new generator.
  """
  # TODO(b/170668986): more sophisticated algorithm selection
  requested_alg = DEFAULT_ALGORITHM if alg is None else alg
  alg_id = stateless_random_ops.convert_alg_to_int(requested_alg)
  initial_state = create_rng_state(seed, alg_id)
  return cls(state=initial_state, alg=alg_id)
@classmethod
def from_non_deterministic_state(cls, alg=None):
  """Builds a generator whose initial state is chosen non-deterministically.

  The source of the non-determinism will be platform- and time-dependent.

  Args:
    alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
      `__init__` for its possible values.

  Returns:
    The new generator.

  Raises:
    RuntimeError: if op determinism is enabled.
  """
  if config.is_op_determinism_enabled():
    message = ('"from_non_deterministic_state" cannot be called when '
               "determinism is enabled.")
    raise RuntimeError(message)  # pylint: disable=g-doc-exception
  # TODO(b/170668986): more sophisticated algorithm selection
  requested_alg = DEFAULT_ALGORITHM if alg is None else alg
  alg_id = stateless_random_ops.convert_alg_to_int(requested_alg)
  state_shape = [_get_state_size(alg_id)]
  initial_state = non_deterministic_ints(shape=state_shape, dtype=SEED_TYPE)
  return cls(state=initial_state, alg=alg_id)
@classmethod
def from_key_counter(cls, key, counter, alg):
  """Builds a generator from a key and a counter.

  This constructor only applies to counter-based algorithms. See method `key`
  for the meaning of "key" and "counter".

  Args:
    key: the key for the RNG, a scalar of type STATE_TYPE.
    counter: a vector of dtype STATE_TYPE holding the initial counter for the
      RNG. Its length is algorithm-specific (one less than the state size).
    alg: the RNG algorithm. See `__init__` for its possible values.

  Returns:
    The new generator.
  """
  counter_tensor = _convert_to_state_tensor(counter)
  key_tensor = _convert_to_state_tensor(key)
  alg_id = stateless_random_ops.convert_alg_to_int(alg)
  expected_counter_shape = [_get_state_size(alg_id) - 1]
  counter_tensor.shape.assert_is_compatible_with(expected_counter_shape)
  key_tensor.shape.assert_is_compatible_with([])
  # The state is the counter words followed by the key as a single word.
  key_as_vector = array_ops.reshape(key_tensor, [1])
  initial_state = array_ops.concat([counter_tensor, key_as_vector], 0)
  return cls(state=initial_state, alg=alg_id)
def __init__(self, copy_from=None, state=None, alg=None):
  """Creates a generator.

  The new generator will be initialized by one of the following ways, with
  decreasing precedence:
  (1) If `copy_from` is not None, the new generator is initialized by copying
      information from another generator. In this case `state` and `alg`
      must both be None.
  (2) If `state` and `alg` are not None (they must be set together), the new
      generator is initialized by a state.

  Args:
    copy_from: a generator to be copied from.
    state: a vector of dtype STATE_TYPE representing the initial state of the
      RNG, whose length and semantics are algorithm-specific. If it's a
      variable, the generator will reuse it instead of creating a new
      variable.
    alg: the RNG algorithm. Possible values are
      `tf.random.Algorithm.PHILOX` for the Philox algorithm and
      `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
      (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
      [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
      The string names `"philox"` and `"threefry"` can also be used.
      Note `PHILOX` guarantees the same numbers are produced (given
      the same random state) across all architectures (CPU, GPU, XLA etc).
  """
  # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
  # Remember the strategy in effect at construction time (None if there is
  # none); `skip` and `_preprocess_key` consult it later.
  if ds_context.has_strategy():
    self._distribution_strategy = ds_context.get_strategy()
  else:
    self._distribution_strategy = None
  if copy_from is not None:
    # All other arguments should be None. Each one is compared against None
    # explicitly: `(alg or state) is None` would evaluate the truthiness of
    # `alg`, letting a falsy non-None `alg` through and failing with an
    # unrelated error for values that have no well-defined truth value.
    assert alg is None and state is None, (
        "`alg` and `state` must both be None when `copy_from` is given.")
    self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
                                            trainable=False)
    self._alg = copy_from.algorithm
  else:
    assert alg is not None and state is not None
    alg = stateless_random_ops.convert_alg_to_int(alg)
    if isinstance(state, variables.Variable):
      # Reuse the caller's variable as-is instead of creating a new one.
      _check_state_shape(state.shape, alg)
      self._state_var = state
    else:
      state = _convert_to_state_tensor(state)
      _check_state_shape(state.shape, alg)
      self._state_var = self._create_variable(state, dtype=STATE_TYPE,
                                              trainable=False)
    self._alg = alg
def _create_variable(self, *args, **kwargs):
  """Creates the variable that holds the generator state.

  The variable is always named "StateVar" and is created under the
  "random_generator" name scope.

  Args:
    *args: positional arguments passed along to `variables.Variable`.
    **kwargs: keyword arguments passed along to `variables.Variable`.

  Returns:
    The created variable.

  Raises:
    ValueError: if the created variable is a `ShardedVariable`.
  """
  # Make sure we don't change this name since Keras was using this name
  # to filter out the state variable.
  kwargs["name"] = "StateVar"
  with ops.name_scope("random_generator"):
    state_var = variables.Variable(*args, **kwargs)
  if not isinstance(state_var, sharded_variable.ShardedVariable):
    return state_var
  # RNG state is an atomic entity representing a 128-bit or
  # 192-bit value, so it mustn't be sharded.
  raise ValueError(
      "tf.random.Generator state is sharded, which is not allowed. When "
      "creating a tf.distribute.experimental.ParameterServerStrategy, "
      "please make sure that the `variable_partitioner` "
      "argument won't shard a "
      "small variable of shape [2] or [3]. Ways to avoid sharding small "
      "variables include setting `variable_partitioner` to None or to "
      "tf.distribute.experimental.partitioners.MinSizePartitioner with a "
      "large enough `min_shard_bytes`.")
def reset(self, state):
  """Replaces the generator's state with a new one.

  See `__init__` for the meaning of "state".

  Args:
    state: the new state.
  """
  new_state = _convert_to_state_tensor(state)
  expected_shape = [_get_state_size(self.algorithm)]
  new_state.shape.assert_is_compatible_with(expected_shape)
  self._state_var.assign(new_state)
def reset_from_seed(self, seed):
  """Replaces the generator's state with one derived from a new seed.

  See `from_seed` for the meaning of "seed".

  Args:
    seed: the new seed.
  """
  self._state_var.assign(create_rng_state(seed, self.algorithm))
def reset_from_key_counter(self, key, counter):
  """Replaces the generator's state with a new key-counter pair.

  See `from_key_counter` for the meaning of "key" and "counter".

  Args:
    key: the new key.
    counter: the new counter.
  """
  counter_tensor = _convert_to_state_tensor(counter)
  key_tensor = _convert_to_state_tensor(key)
  expected_counter_shape = [_get_state_size(self.algorithm) - 1]
  counter_tensor.shape.assert_is_compatible_with(expected_counter_shape)
  key_tensor.shape.assert_is_compatible_with([])
  # The state is the counter words followed by the key as a single word.
  key_as_vector = array_ops.reshape(key_tensor, [1])
  new_state = array_ops.concat([counter_tensor, key_as_vector], 0)
  self._state_var.assign(new_state)
@property
def state(self):
  """The internal state of the RNG.

  Returns:
    The variable object held in `self._state_var` (the variable itself, not a
    copy of its value).
  """
  return self._state_var
@property
def algorithm(self):
  """The RNG algorithm id (a Python integer or scalar integer Tensor).

  Returns:
    The value stored in `self._alg` by the constructor.
  """
  return self._alg
def _standard_normal(self, shape, dtype):
  """Draws `stateless_random_normal_v2` samples, advancing the state."""
  key, counter = self._prepare_key_counter(shape)
  sample = gen_stateless_random_ops_v2.stateless_random_normal_v2
  return sample(
      shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
@property
def key(self):
  """The 'key' part of the state of a counter-based RNG.

  For a counter-based RNG algorithm such as Philox and ThreeFry (as
  described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
  [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
  the RNG state is made of two parts: counter and key. The output is
  produced by the formula output=hash(key, counter), i.e. a hashing of the
  counter parametrized by the key. Two RNGs with two different keys can be
  thought of as generating two independent random-number streams (a stream
  is formed by increasing the counter).

  Returns:
    A scalar which is the 'key' part of the state (its last element), if the
    RNG algorithm is counter-based.

  Raises:
    ValueError: if the algorithm is neither Philox nor ThreeFry.
  """
  alg = self.algorithm
  if not (alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY):
    raise ValueError(
        f"This generator uses an unsupported algorithm {alg}. Supported "
        f"values are {RNG_ALG_PHILOX} for the Philox algorithm and "
        f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.")
  return self._state_var[-1]
def _skip_single_var(self, var, delta):
  """Runs `rng_read_and_skip` on `var` with this generator's algorithm.

  The algorithm id is cast to int32 and `delta` to uint64 before the call.
  """
  handle = var.handle
  # TODO(wangpeng): Cache the cast algorithm instead of casting everytime.
  alg_int32 = math_ops.cast(self.algorithm, dtypes.int32)
  delta_uint64 = math_ops.cast(delta, dtypes.uint64)
  return gen_stateful_random_ops.rng_read_and_skip(
      handle, alg=alg_int32, delta=delta_uint64)
def skip(self, delta):
  """Advance the counter of a counter-based RNG.

  Args:
    delta: the amount of advancement. The state of the RNG after
      `skip(n)` will be the same as that after `normal([n])`
      (or any other distribution). The actual increment added to the
      counter is an unspecified implementation detail.

  Returns:
    A `Tensor` of type `int64`.
  """
  def advance(state_var):
    return self._skip_single_var(state_var, delta)

  strategy = self._distribution_strategy
  # TODO(b/170515001): Always call strategy.extended.update after calling it
  # from both replica context and cross-replica context is supported.
  #
  # When saving non-distributed, replica context with replica_id=0 is
  # assumed, since only the first replica is saved. With no strategy the
  # state variable is likewise updated directly.
  if values_util.is_saving_non_distributed() or strategy is None:
    return advance(self.state)
  with ds_context.enter_or_assert_strategy(strategy):
    in_cross_replica = ds_context.in_cross_replica_context()
    if in_cross_replica:
      # Code that operates on all replicas of a variable cannot be saved
      # without retracing.
      values_util.mark_as_unsaveable()
    uses_central_storage = "CentralStorage" in type(strategy).__name__
    if in_cross_replica or uses_central_storage:
      # In cross-replica context we need to use strategy.extended.update.
      # In CentralStorageStrategy we also need to use
      # strategy.extended.update (even for replica context),
      # because variable updates here must be within merge_call.
      return ds_context.get_strategy().extended.update(self.state, advance)
  return advance(self.state)
def _preprocess_key(self, key):
  """Mixes the replica id into `key` when running under a strategy.

  `key` is returned unchanged if the generator was created without a
  distribution strategy, or if no replica id is available.
  """
  strategy = self._distribution_strategy
  if strategy is None:
    return key
  with ds_context.enter_or_assert_strategy(strategy):
    replica_id = get_replica_id()
    if replica_id is None:
      return key
    # The replica id, followed by a zero, serves as the counter.
    replica_counter = array_ops.stack([replica_id, 0], axis=0)
    replica_counter = math_ops.cast(replica_counter, dtypes.uint64)
    # Conceptually: key = hash(key, replica_id)
    return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
        shape=[1], key=key, counter=replica_counter, dtype=dtypes.uint64,
        alg=self.algorithm)
def _prepare_key_counter(self, shape):
  """Advances the state for `shape` and returns the (key, counter) to use.

  The generator is skipped by the product of `shape`'s elements. The tensor
  returned by `skip` is split into its leading counter words and the single
  key word after them; both are bitcast to uint64, and the key is then
  passed through `_preprocess_key`.
  """
  num_elements = math_ops.reduce_prod(shape)
  skipped = self.skip(num_elements)
  key_start = _get_counter_size(self.algorithm)
  counter = array_ops.bitcast(skipped[:key_start], dtypes.uint64)
  raw_key = array_ops.bitcast(skipped[key_start:key_start + 1],
                              dtypes.uint64)
  return self._preprocess_key(raw_key), counter
# The following functions return a tensor and as a side effect update
# self._state_var.
def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
           name=None):
  """Outputs random values from a normal distribution.

  Standard-normal samples are scaled by `stddev` and shifted by `mean`.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output
      tensor.
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
      distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard
      deviation of the normal distribution.
    dtype: The type of the output.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random normal values.
  """
  with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as scope:
    shape_tensor = _shape_tensor(shape)
    mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    standard = self._standard_normal(shape_tensor, dtype=dtype)
    scaled = standard * stddev_tensor
    return math_ops.add(scaled, mean_tensor, name=scope)
def _truncated_normal(self, shape, dtype):
  """Draws `stateless_truncated_normal_v2` samples, advancing the state."""
  key, counter = self._prepare_key_counter(shape)
  sample = gen_stateless_random_ops_v2.stateless_truncated_normal_v2
  return sample(
      shape=shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
def truncated_normal(self, shape,
                     mean=0.0,
                     stddev=1.0,
                     dtype=dtypes.float32,
                     name=None):
  """Outputs random values from a truncated normal distribution.

  The generated values follow a normal distribution with the given mean and
  standard deviation, except that values whose magnitude is more than
  2 standard deviations from the mean are dropped and re-picked.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output
      tensor.
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
      truncated normal distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard
      deviation of the normal distribution, before truncation.
    dtype: The type of the output.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal
      values.
  """
  with ops.name_scope(
      name, "truncated_normal", [shape, mean, stddev]) as scope:
    shape_t = _shape_tensor(shape)
    mean_t = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_t = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    samples = self._truncated_normal(shape_t, dtype=dtype)
    return math_ops.add(samples * stddev_t, mean_t, name=scope)
def _uniform(self, shape, dtype):
  """Draws `stateless_random_uniform_v2` samples, advancing the state."""
  key, counter = self._prepare_key_counter(shape)
  sample = gen_stateless_random_ops_v2.stateless_random_uniform_v2
  return sample(
      shape=shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
def _uniform_full_int(self, shape, dtype, name=None):
  """Draws `stateless_random_uniform_full_int_v2` samples, advancing state."""
  key, counter = self._prepare_key_counter(shape)
  sample = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2
  return sample(
      shape=shape, key=key, counter=counter, dtype=dtype,
      alg=self.algorithm, name=name)
def uniform(self, shape, minval=0, maxval=None,
            dtype=dtypes.float32, name=None):
  """Outputs random values from a uniform distribution.

  The generated values follow a uniform distribution in the range
  `[minval, maxval)`. The lower bound `minval` is included in the range, while
  the upper bound `maxval` is excluded. (For float numbers, especially
  low-precision types like bfloat16, rounding may sometimes cause the result
  to include `maxval`.)

  For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
  be specified explicitly.

  In the integer case, the random integers are slightly biased unless
  `maxval - minval` is an exact power of two. The bias is small for values of
  `maxval - minval` significantly smaller than the range of the output (either
  `2**32` or `2**64`).

  For full-range random integers, pass `minval=None` and `maxval=None` with an
  integer `dtype` (for integer dtypes, `minval` and `maxval` must be both
  `None` or both not `None`).

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output
      tensor.
    minval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs
      to be a scalar). The lower bound (included) on the range of random
      values to generate. Pass `None` for full-range integers. Defaults to 0.
    maxval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs
      to be a scalar). The upper bound (excluded) on the range of random
      values to generate. Pass `None` for full-range integers. Defaults to 1
      if `dtype` is floating point.
    dtype: The type of the output.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random uniform values.

  Raises:
    ValueError: If `dtype` is integral and exactly one of `minval` and
      `maxval` is `None` (which includes leaving `maxval` unspecified while
      `minval` keeps its default of 0).
  """
  dtype = dtypes.as_dtype(dtype)
  wants_int = dtype.is_integer
  if wants_int:
    if (minval is None) != (maxval is None):
      raise ValueError(
          f"For integer dtype {dtype}, minval and maxval must be both "
          f"`None` or both non-`None`; got minval={minval} and "
          f"maxval={maxval}")
  elif maxval is None:
    maxval = 1
  with ops.name_scope(name, "stateful_uniform",
                      [shape, minval, maxval]) as scope:
    shape_t = _shape_tensor(shape)
    if wants_int and minval is None:
      # Both bounds are None here: sample over the dtype's entire range.
      return self._uniform_full_int(shape=shape_t, dtype=dtype, name=scope)
    low = ops.convert_to_tensor(minval, dtype=dtype, name="min")
    high = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
    if not wants_int:
      # Rescale the base uniform sample onto [minval, maxval).
      unit = self._uniform(shape=shape_t, dtype=dtype)
      return math_ops.add(unit * (high - low), low, name=scope)
    key, counter = self._prepare_key_counter(shape_t)
    return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
        shape=shape_t,
        key=key,
        counter=counter,
        minval=low,
        maxval=high,
        alg=self.algorithm,
        name=scope)
def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
  """Samples uniformly over the entire range of an integer type.

  Calling this method gives the same result as calling `uniform` with an
  integer `dtype` and both `minval` and `maxval` set to `None`.

  Args:
    shape: the shape of the output.
    dtype: (optional) the integer type of the output, uint64 by default.
    name: (optional) the name of the node.

  Returns:
    A tensor of random numbers of the required shape.
  """
  int_dtype = dtypes.as_dtype(dtype)
  scope = ops.name_scope(name, "stateful_uniform_full_int", [shape])
  with scope as scope_name:
    shape_tensor = _shape_tensor(shape)
    return self._uniform_full_int(
        shape=shape_tensor, dtype=int_dtype, name=scope_name)
def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
  """Draws random samples from a binomial distribution.

  Each generated value follows a binomial distribution parameterized by a
  number of trials (`counts`) and a probability of success (`probs`).

  Example:

  ```python
  counts = [10., 20.]
  # Probability of success.
  probs = [0.8]

  rng = tf.random.Generator.from_seed(seed=234)
  binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)


  counts = ... # Shape [3, 1, 2]
  probs = ... # Shape [1, 4, 2]
  shape = [3, 4, 3, 4, 2]
  rng = tf.random.Generator.from_seed(seed=1717)
  # Sample shape will be [3, 4, 3, 4, 2]
  binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output
      tensor.
    counts: Tensor. The counts of the binomial distribution. Must be
      broadcastable with `probs`, and broadcastable with the rightmost
      dimensions of `shape`.
    probs: Tensor. The probability of success for the binomial distribution.
      Must be broadcastable with `counts` and broadcastable with the
      rightmost dimensions of `shape`.
    dtype: The type of the output. Default: tf.int32
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random binomial
      values.  For each i, each samples[i, ...] is an independent draw from
      the binomial distribution on counts[i] trials with probability of
      success probs[i].
  """
  out_dtype = dtypes.as_dtype(dtype)
  scope = ops.name_scope(name, "binomial", [shape, counts, probs])
  with scope as op_name:
    counts_tensor = ops.convert_to_tensor(counts, name="counts")
    probs_tensor = ops.convert_to_tensor(probs, name="probs")
    output_shape = _shape_tensor(shape)
    # The op reads and updates this generator's state through its handle.
    samples = gen_stateful_random_ops.stateful_random_binomial(
        self.state.handle,
        self.algorithm,
        shape=output_shape,
        counts=counts_tensor,
        probs=probs_tensor,
        dtype=out_dtype,
        name=op_name)
    return samples
# TODO(wangpeng): implement other distributions
def _make_int64_keys(self, shape=()):
  """Returns a tensor of int64 keys of the given shape from this generator."""
  # For PhiloxRandom_64_128_128 and ThreeFry_64_64_64, a full-range int64
  # draw computes `new_key[i] = hash(old_key, counter+i)`, which is exactly
  # how new independent keys are meant to be derived.
  new_keys = self.uniform_full_int(shape=shape, dtype=dtypes.int64)
  return new_keys
def make_seeds(self, count=1):
  """Produces seeds that can be fed to stateless random ops.

  For example:

  ```python
  seeds = get_global_generator().make_seeds(count=10)
  for i in range(10):
    seed = seeds[:, i]
    numbers = stateless_random_normal(shape=[2, 3], seed=seed)
    ...
  ```

  Args:
    count: the number of seed pairs (note that stateless random ops need a
      pair of seeds to invoke).

  Returns:
    A tensor of shape [2, count] and dtype int64.

  Raises:
    ValueError: If this generator's algorithm is neither Philox nor ThreeFry.
  """
  alg = self.algorithm
  is_supported = alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY
  if not is_supported:
    raise ValueError(
        f"This generator uses an unsupported algorithm {alg}. Supported "
        f"values are {RNG_ALG_PHILOX} for the Philox algorithm and "
        f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.")

  first_seeds = self._make_int64_keys(shape=[count])
  # The two halves of a stateless seed carry no individual meaning and get
  # scrambled together, so it is fine for the second half to be all zeros.
  second_seeds = array_ops.zeros_like(first_seeds)
  return array_ops.stack([first_seeds, second_seeds])
def split(self, count=1):
  """Creates `count` new `Generator` objects independent of one another.

  Two generators are independent of each other in the sense that the
  random-number streams they generate don't have statistically detectable
  correlations. The new generators are also independent of the old one.
  The old generator's state will be changed (like other random-number
  generating methods), so two calls of `split` will return different
  new generators.

  For example:

  ```python
  gens = get_global_generator().split(count=10)
  for gen in gens:
    numbers = gen.normal(shape=[2, 3])
    # ...
  gens2 = get_global_generator().split(count=10)
  # gens2 will be different from gens
  ```

  The new generators will be put on the current device (possible different
  from the old generator's), for example:

  ```python
  with tf.device("/device:CPU:0"):
    gen = Generator(seed=1234)  # gen is on CPU
  with tf.device("/device:GPU:0"):
    gens = gen.split(count=10)  # gens are on GPU
  ```

  Args:
    count: the number of generators to return.

  Returns:
    A list (length `count`) of `Generator` objects independent of each other.
    The new generators have the same RNG algorithm as the old one.

  Raises:
    ValueError: If this generator's algorithm is neither Philox nor ThreeFry.
  """
  alg = self.algorithm
  is_supported = alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY
  if not is_supported:
    raise ValueError(
        f"This generator uses an unsupported algorithm {alg}. Supported "
        f"values are {RNG_ALG_PHILOX} for the Philox algorithm and "
        f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.")

  new_keys = self._make_int64_keys(shape=[count])
  children = []
  for key in array_ops.unstack(new_keys, num=count):
    # The key goes in the last state slot; the zeros padded on its left will
    # be the counter.
    child_state = [0] * (_get_state_size(alg) - 1) + [key]
    children.append(Generator(state=child_state, alg=alg))
  return children
# It's not safe to create TF ops before `init_google` is called, so this is
# initialized to None and gets a value the first time `get_global_generator`
# is called.
global_generator = None
@tf_export("random.get_global_generator",
           "random.experimental.get_global_generator")
def get_global_generator():
  """Returns the global generator, creating it on first use.

  The first call to this function creates the global generator, and the
  generator will be placed at the default device at that time, so one needs to
  be careful about when this function is first called. Using a generator
  placed on a less-ideal device will incur performance regression.

  Returns:
    The global `tf.random.Generator` object.

  Raises:
    RuntimeError: If the global generator has not been set yet and op
      determinism is enabled.
  """
  global global_generator
  if global_generator is not None:
    return global_generator

  # No generator exists yet: one has to be created from non-deterministic
  # state, which is refused while op determinism is enabled.
  if config.is_op_determinism_enabled():
    raise RuntimeError('"get_global_generator" cannot be called if '  # pylint: disable=g-doc-exception
                       "determinism is enabled, unless "
                       '"set_global_generator" has already been called. '
                       'Please call "set_global_generator" first.')
  with ops.init_scope():
    global_generator = Generator.from_non_deterministic_state()
  return global_generator
@tf_export("random.set_global_generator",
"random.experimental.set_global_generator")
def set_global_generator(generator):