# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stateless random ops which take seed as a tensor input."""
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import gen_random_index_shuffle_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops_util
from tensorflow.python.ops import shape_util
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable("StatelessMultinomial")
ops.NotDifferentiable("StatelessRandomBinomial")
ops.NotDifferentiable("StatelessRandomNormal")
ops.NotDifferentiable("StatelessRandomPoisson")
ops.NotDifferentiable("StatelessRandomUniform")
ops.NotDifferentiable("StatelessRandomUniformInt")
ops.NotDifferentiable("StatelessRandomUniformFullInt")
ops.NotDifferentiable("StatelessTruncatedNormal")
ops.NotDifferentiable("StatelessRandomNormalV2")
ops.NotDifferentiable("StatelessRandomUniformV2")
ops.NotDifferentiable("StatelessRandomUniformIntV2")
ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
ops.NotDifferentiable("StatelessTruncatedNormalV2")
ops.NotDifferentiable("StatelessRandomShuffle")
ops.NotDifferentiable("RandomIndexShuffle")
@tf_export("random.split", "random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2, alg="auto_select"):
"""Splits an RNG seed into `num` new seeds by adding a leading axis.
Example:
>>> seed = [1, 2]
>>> new_seeds = tf.random.split(seed, num=3)
>>> print(new_seeds)
tf.Tensor(
[[1105988140 1738052849]
[-335576002 370444179]
[ 10670227 -246211131]], shape=(3, 2), dtype=int32)
>>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :])
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.59835213, -0.9578608 ,
0.9002807 ], dtype=float32)>
Args:
seed: an RNG seed (a tensor with shape [2] and dtype `int32` or `int64`).
(When using XLA, only `int32` is allowed.)
num: optional, a positive integer or scalar tensor indicating the number of
seeds to produce (default 2).
alg: The RNG algorithm used to generate the random numbers. See
`tf.random.stateless_uniform` for a detailed explanation.
Returns:
A tensor with shape [num, 2] representing `num` new seeds. It will have the
same dtype as `seed` (if `seed` doesn't have an explicit dtype, the dtype
will be determined by `tf.convert_to_tensor`).
"""
seed = ops.convert_to_tensor(seed)
return stateless_random_uniform(
shape=[num, 2],
seed=seed,
dtype=seed.dtype,
minval=None,
maxval=None,
alg=alg,
)
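
# A minimal usage sketch for `split` (assumptions: eager execution and
# `import tensorflow as tf`): fan one root seed out into independent
# per-step seeds without any mutable RNG state.
#
#   root_seed = tf.constant([1, 2], dtype=tf.int32)
#   step_seeds = tf.random.split(root_seed, num=100)   # shape [100, 2]
#   noise = tf.random.stateless_normal([8], seed=step_seeds[0])
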
@tf_export("random.fold_in", "random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data, alg="auto_select"):
"""Folds in data to an RNG seed to form a new RNG seed.
For example, in a distributed-training setting, suppose we have a master seed
and a replica ID. We want to fold the replica ID into the master seed to
form a "replica seed" to be used by that replica later on, so that different
replicas will generate different random numbers but the reproducibility of the
whole system can still be controlled by the master seed:
>>> master_seed = [1, 2]
>>> replica_id = 3
>>> replica_seed = tf.random.experimental.stateless_fold_in(
... master_seed, replica_id)
>>> print(replica_seed)
tf.Tensor([1105988140 3], shape=(2,), dtype=int32)
>>> tf.random.stateless_normal(shape=[3], seed=replica_seed)
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03197195, 0.8979765 ,
0.13253039], dtype=float32)>
Args:
seed: an RNG seed (a tensor with shape [2] and dtype `int32` or `int64`).
(When using XLA, only `int32` is allowed.)
data: an `int32` or `int64` scalar representing data to be folded in to the
seed.
alg: The RNG algorithm used to generate the random numbers. See
`tf.random.stateless_uniform` for a detailed explanation.
Returns:
A new RNG seed that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values. It
will have the same dtype as `data` (if `data` doesn't have an explicit dtype,
the dtype will be determined by `tf.convert_to_tensor`).
"""
data = ops.convert_to_tensor(data)
seed1 = stateless_random_uniform(
shape=[], seed=seed, dtype=data.dtype, minval=None, maxval=None, alg=alg
)
return array_ops_stack.stack([seed1, data])
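
# A minimal sketch of the distributed-training pattern from the `fold_in`
# docstring (assumptions: eager execution; `num_replicas` is a hypothetical
# variable):
#
#   master_seed = tf.constant([1, 2], dtype=tf.int32)
#   for replica_id in range(num_replicas):
#       replica_seed = tf.random.fold_in(master_seed, replica_id)
#       draws = tf.random.stateless_uniform([4], seed=replica_seed)
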
@tf_export("random.experimental.index_shuffle")
@dispatch.add_dispatch_support
def index_shuffle(index, seed, max_index):
"""Outputs the position of `index` in a permutation of `[0, ..., max_index]`.
For each possible `seed` and `max_index` there is one pseudorandom
permutation of the sequence `S=[0, ..., max_index]`. Instead of
materializing the full array we can compute the new position of any
integer `i` (`0 <= i <= max_index`) in `S`. This can be useful for
very large `max_index` values by avoiding allocating large chunks of
memory.
In the simplest case, `index` and `max_index` are scalars, and
`seed` is a length-2 vector (as typical for stateless RNGs). But
you can add a leading batch dimension to all of them. If some of
them don't have the batch dimension while others do, `index_shuffle`
will add a batch dimension to the former by broadcasting.
The input `index` and output can be used as indices to shuffle a
vector. For example:
>>> vector = tf.constant(['e0', 'e1', 'e2', 'e3'])
>>> indices = tf.random.experimental.index_shuffle(
... index=tf.range(4), seed=[5, 9], max_index=3)
>>> print(indices)
tf.Tensor([2 0 1 3], shape=(4,), dtype=int32)
>>> shuffled_vector = tf.gather(vector, indices)
>>> print(shuffled_vector)
tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string)
More usefully, it can be used in a streaming (aka online) scenario such as
`tf.data`, where each element of `vector` is processed individually and the
whole `vector` is never materialized in memory.
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.map(
... lambda idx: tf.random.experimental.index_shuffle(idx, [5, 8], 9))
>>> print(list(dataset.as_numpy_iterator()))
[3, 8, 0, 1, 2, 7, 6, 9, 4, 5]
This operation is stateless (like the `tf.random.stateless_*`
functions), meaning the output is fully determined by the `seed`
(other inputs being equal). Each `seed` choice corresponds to one
permutation, so when calling this function multiple times for the
same shuffling, please make sure to use the same `seed`. For
example:
>>> seed = [5, 9]
>>> idx0 = tf.random.experimental.index_shuffle(0, seed, 3)
>>> idx1 = tf.random.experimental.index_shuffle(1, seed, 3)
>>> idx2 = tf.random.experimental.index_shuffle(2, seed, 3)
>>> idx3 = tf.random.experimental.index_shuffle(3, seed, 3)
>>> shuffled_vector = tf.gather(vector, [idx0, idx1, idx2, idx3])
>>> print(shuffled_vector)
tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string)
Args:
index: An integer scalar tensor or vector with values in `[0, max_index]`.
It can be seen as either a value `v` in the sequence `S=[0, ...,
max_index]` to be permuted, or as an index of an element `e` in a
shuffled vector.
seed: A tensor of shape [2] or [n, 2] with dtype `int32`, `uint32`, `int64`
or `uint64`. The RNG seed. If the rank is unknown at graph-building
time, it must be 1 at runtime.
max_index: A non-negative tensor with the same shape and dtype as `index`.
The upper bound (inclusive).
Returns:
If all inputs were scalar (shape [2] for `seed`), the output will
be a scalar with the same dtype as `index`. The output can be seen
as the new position of `v` in `S`, or as the index of `e` in the
vector before shuffling. If one or multiple inputs were vectors
(shape [n, 2] for `seed`), then the output will be a vector of the
same size with each element shuffled independently. Scalar values
are broadcasted in this case.
"""
# We expect users to pass a seed with shape [2] to be consistent with other
# stateless_* ops, but the raw op expects shape [3].
seed = ops.convert_to_tensor(seed)
# Pad the first dimension with an arbitrary value to extend the seed to
# shape [3].
if seed.shape.rank is None:
paddings = [[1, 0]]
else:
paddings = [[1, 0]] + (seed.shape.rank - 1) * [[0, 0]]
seed = array_ops.pad(seed, paddings, constant_values=498247692)
return gen_random_index_shuffle_ops.random_index_shuffle(
index, seed=seed, max_index=max_index, rounds=4
)
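
# A sketch of the streaming-shuffle pattern `index_shuffle` enables
# (assumption: eager or `tf.data` graph context): the permutation is never
# materialized, only evaluated pointwise.
#
#   dataset = tf.data.Dataset.range(1_000_000)
#   dataset = dataset.map(
#       lambda idx: tf.random.experimental.index_shuffle(
#           idx, seed=[5, 8], max_index=999_999))
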
@tf_export("random.experimental.stateless_shuffle")
@dispatch.add_dispatch_support
def stateless_shuffle(value, seed, alg="auto_select", name=None):
"""Randomly and deterministically shuffles a tensor along its first dimension.
The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
to one and only one `output[i]`. For example, a mapping that might occur for a
3x2 tensor is:
```python
[[1, 2], [[5, 6],
[3, 4], ==> [1, 2],
[5, 6]] [3, 4]]
```
>>> v = tf.constant([[1, 2], [3, 4], [5, 6]])
>>> shuffled = tf.random.experimental.stateless_shuffle(v, seed=[8, 9])
>>> print(shuffled)
tf.Tensor(
[[5 6]
[1 2]
[3 4]], shape=(3, 2), dtype=int32)
This is a stateless version of `tf.random.shuffle`: if run twice with the
same `value` and `seed`, it will produce the same result. The
output is consistent across multiple runs on the same hardware (and between
CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
Args:
value: A Tensor to be shuffled.
seed: A shape [2] Tensor. The seed to the random number generator. Must have
dtype `int32` or `int64`.
alg: The RNG algorithm used to generate the random numbers. See
`tf.random.stateless_uniform` for a detailed explanation.
name: A name for the operation.
Returns:
A tensor of same shape and type as `value`, shuffled along its first
dimension.
"""
with ops.name_scope(name, "stateless_shuffle", [value, seed]) as name:
key, counter, alg = random_ops_util.get_key_counter_alg(seed, alg)
return gen_stateless_random_ops_v2.stateless_shuffle(
value, key=key, counter=counter, alg=alg
)
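
# A minimal determinism sketch for `stateless_shuffle`: the same `value` and
# `seed` always yield the same permutation (assumption: same hardware).
#
#   v = tf.range(10)
#   a = tf.random.experimental.stateless_shuffle(v, seed=[8, 9])
#   b = tf.random.experimental.stateless_shuffle(v, seed=[8, 9])
#   # a and b are element-wise equal.
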
@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(
shape,
seed,
minval=0,
maxval=None,
dtype=dtypes.float32,
name=None,
alg="auto_select",
):
"""Outputs deterministic pseudorandom values from a uniform distribution.
This is a stateless version of `tf.random.uniform`: if run twice with the
same seeds and shapes, it will produce the same pseudorandom numbers. The
output is consistent across multiple runs on the same hardware (and between
CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
The generated values follow a uniform distribution in the range
`[minval, maxval)`. The lower bound `minval` is included in the range, while
the upper bound `maxval` is excluded.
For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
be specified explicitly.
In the integer case, the random integers are slightly biased unless
`maxval - minval` is an exact power of two. The bias is small for values of
`maxval - minval` significantly smaller than the range of the output (either
`2**32` or `2**64`).
For full-range (i.e. inclusive of both max and min) random integers, pass
`minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
either both `minval` and `maxval` must be `None` or neither may be `None`. For
example:
```python
ints = tf.random.stateless_uniform(
[10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32)
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
minval: A Tensor or Python value of type `dtype`, broadcastable with `shape`
(for integer types, broadcasting is not supported, so it needs to be a
scalar). The lower bound on the range of random values to generate. Pass
`None` for full-range integers. Defaults to 0.
maxval: A Tensor or Python value of type `dtype`, broadcastable with `shape`
(for integer types, broadcasting is not supported, so it needs to be a
scalar). The upper bound on the range of random values to generate.
Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
integers.
dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`,
`int32`, or `int64`. For unbounded uniform ints (`minval`, `maxval` both
`None`), `uint32` and `uint64` may be used. Defaults to `float32`.
name: A name for the operation (optional).
alg: The RNG algorithm used to generate the random numbers. Valid choices
are `"philox"` for [the Philox
algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
`"threefry"` for [the ThreeFry
algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
and `"auto_select"` (default) for the system to automatically select an
algorithm based on the device type. Values of `tf.random.Algorithm` can also
be used. Note that with `"auto_select"`, the outputs of this function may
change when it is running on a different device.
Returns:
A tensor of the specified shape filled with random uniform values.
Raises:
ValueError: If `dtype` is integral and only one of `minval` or `maxval` is
specified.
"""
dtype = dtypes.as_dtype(dtype)
accepted_dtypes = (
dtypes.float16,
dtypes.bfloat16,
dtypes.float32,
dtypes.float64,
dtypes.int32,
dtypes.int64,
dtypes.uint32,
dtypes.uint64,
)
if dtype not in accepted_dtypes:
raise ValueError(
f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are "
f"{accepted_dtypes}."
)
if dtype.is_integer:
if (minval is None) != (maxval is None):
raise ValueError(
f"For integer `dtype` argument {dtype}, argument `minval` and "
f"`maxval` must be both None or not None. Got `minval`={minval} and "
f"`maxval`={maxval}."
)
if minval is not None and dtype in (dtypes.uint32, dtypes.uint64):
raise ValueError(
f"Argument `dtype` got invalid value {dtype} when argument `minval` "
"is not None. Please don't use unsigned integers in this case."
)
elif maxval is None:
maxval = 1
with ops.name_scope(
name, "stateless_random_uniform", [shape, seed, minval, maxval]
) as name:
shape = shape_util.shape_tensor(shape)
if dtype.is_integer and minval is None:
key, counter, alg = random_ops_util.get_key_counter_alg(seed, alg)
result = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg, name=name
)
else:
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
if dtype.is_integer:
key, counter, alg = random_ops_util.get_key_counter_alg(seed, alg)
result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
shape,
key=key,
counter=counter,
minval=minval,
maxval=maxval,
alg=alg,
name=name,
)
else:
key, counter, alg = random_ops_util.get_key_counter_alg(seed, alg)
rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg
)
result = math_ops.add(rnd * (maxval - minval), minval, name=name)
shape_util.maybe_set_static_shape(result, shape)
return result
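
# A minimal sketch contrasting the three modes of `stateless_uniform`
# described above (assumption: eager execution):
#
#   seed = [2, 3]
#   floats = tf.random.stateless_uniform([4], seed=seed)           # [0, 1)
#   bounded = tf.random.stateless_uniform(
#       [4], seed=seed, minval=0, maxval=10, dtype=tf.int32)       # [0, 10)
#   full_range = tf.random.stateless_uniform(
#       [4], seed=seed, minval=None, maxval=None, dtype=tf.int32)  # all int32
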
@tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(
shape, seed, counts, probs, output_dtype=dtypes.int32, name=None
):
"""Outputs deterministic pseudorandom values from a binomial distribution.
The generated values follow a binomial distribution with specified count and
probability of success parameters.
This is a stateless version of `tf.random.Generator.binomial`: if run twice
with the same seeds and shapes, it will produce the same pseudorandom numbers.
The output is consistent across multiple runs on the same hardware (and
between CPU and GPU), but may change between versions of TensorFlow or on
non-CPU/GPU hardware.
Example:
```python
counts = [10., 20.]
# Probability of success.
probs = [0.8]
binomial_samples = tf.random.stateless_binomial(
shape=[2], seed=[123, 456], counts=counts, probs=probs)
counts = ... # Shape [3, 1, 2]
probs = ... # Shape [1, 4, 2]
shape = [3, 4, 3, 4, 2]
# Sample shape will be [3, 4, 3, 4, 2]
binomial_samples = tf.random.stateless_binomial(
shape=shape, seed=[123, 456], counts=counts, probs=probs)
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
counts: Tensor. The counts of the binomial distribution. Must be
broadcastable with `probs`, and broadcastable with the rightmost
dimensions of `shape`.
probs: Tensor. The probability of success for the binomial distribution.
Must be broadcastable with `counts` and broadcastable with the rightmost
dimensions of `shape`.
output_dtype: The type of the output. Default: tf.int32
name: A name for the operation (optional).
Returns:
samples: A Tensor of the specified shape filled with random binomial
values. For each i, each `samples[..., i]` is an independent draw from
the binomial distribution on `counts[i]` trials with probability of
success `probs[i]`.
"""
with ops.name_scope(
name, "stateless_random_binomial", [shape, seed, counts, probs]
) as name:
shape = shape_util.shape_tensor(shape)
probs = ops.convert_to_tensor(
probs, dtype_hint=dtypes.float32, name="probs"
)
counts = ops.convert_to_tensor(
counts, dtype_hint=probs.dtype, name="counts"
)
result = gen_stateless_random_ops.stateless_random_binomial(
shape=shape, seed=seed, counts=counts, probs=probs, dtype=output_dtype
)
shape_util.maybe_set_static_shape(result, shape)
return result
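
# A minimal broadcasting sketch for `stateless_binomial`: `counts` and
# `probs` broadcast together and against the rightmost dimensions of `shape`.
#
#   counts = tf.constant([10., 20.])
#   probs = tf.constant([0.8])   # broadcasts against counts
#   samples = tf.random.stateless_binomial(
#       shape=[5, 2], seed=[123, 456], counts=counts, probs=probs)
#   # samples[i, j] is a draw from Binomial(counts[j], probs[0]).
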
@tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(
shape, seed, alpha, beta=None, dtype=dtypes.float32, name=None
):
"""Outputs deterministic pseudorandom values from a gamma distribution.
The generated values follow a gamma distribution with specified concentration
(`alpha`) and inverse scale (`beta`) parameters.
This is a stateless version of `tf.random.gamma`: if run twice with the same
seeds and shapes, it will produce the same pseudorandom numbers. The output is
consistent across multiple runs on the same hardware (and between CPU and
GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
A slight difference exists in the interpretation of the `shape` parameter
between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
`stateless_gamma` the `shape` parameter must always encompass the shapes of
each of `alpha` and `beta` (which must broadcast together to match the
trailing dimensions of `shape`).
Note: Because internal calculations are done using `float64` and casting has
`floor` semantics, we must manually map zero outcomes to the smallest
possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This
means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
should. This bias can only happen for small values of `alpha`, i.e.,
`alpha << 1` or large values of `beta`, i.e., `beta >> 1`.
The samples are differentiable w.r.t. alpha and beta.
The derivatives are computed using the approach described in
(Figurnov et al., 2018).
Example:
```python
samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
# the samples drawn from each distribution
samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
# represents the 7x5 samples drawn from each of the two distributions
alpha = tf.constant([[1.], [3.], [5.]])
beta = tf.constant([[3., 4.]])
samples = tf.random.stateless_gamma(
[30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
# samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.
with tf.GradientTape() as tape:
tape.watch([alpha, beta])
loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
[30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
# unbiased stochastic derivatives of the loss function
alpha.shape == dloss_dalpha.shape # True
beta.shape == dloss_dbeta.shape # True
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
alpha: Tensor. The concentration parameter of the gamma distribution. Must
be broadcastable with `beta`, and broadcastable with the rightmost
dimensions of `shape`.
beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
broadcastable with `alpha` and broadcastable with the rightmost dimensions
of `shape`.
dtype: Floating point dtype of `alpha`, `beta`, and the output.
name: A name for the operation (optional).
Returns:
samples: A Tensor of the specified shape filled with random gamma values.
For each i, each `samples[..., i]` is an independent draw from the gamma
distribution with concentration `alpha[i]` and inverse scale `beta[i]`.
"""
with ops.name_scope(
name, "stateless_random_gamma", [shape, seed, alpha, beta]
) as name:
shape = shape_util.shape_tensor(shape)
alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
beta = ops.convert_to_tensor(
beta if beta is not None else 1, name="beta", dtype=dtype
)
broadcast_shape = array_ops.broadcast_dynamic_shape(
array_ops.shape(alpha), array_ops.shape(beta)
)
alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape)
alg = "auto_select"
key, counter, alg = random_ops_util.get_key_counter_alg(seed, alg)
rnd = gen_stateless_random_ops_v2.stateless_random_gamma_v3(
shape, key=key, counter=counter, alg=alg, alpha=alpha_broadcast
)
result = math_ops.maximum(
np.finfo(alpha.dtype.as_numpy_dtype).tiny, rnd / beta
)
shape_util.maybe_set_static_shape(result, shape)
return result
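
# A minimal sanity sketch for the rate parameterization of `stateless_gamma`:
# with concentration `alpha` and inverse scale `beta`, the sample mean should
# approach alpha / beta (assumption: eager execution).
#
#   samples = tf.random.stateless_gamma(
#       [10000, 2], seed=[12, 34], alpha=[4., 4.], beta=[1., 2.])
#   means = tf.reduce_mean(samples, axis=0)  # approximately [4.0, 2.0]
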
@tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape, seed, lam, dtype=dtypes.int32, name=None):
"""Outputs deterministic pseudorandom values from a Poisson distribution.
The generated values follow a Poisson distribution with specified rate
parameter.
This is a stateless version of `tf.random.poisson`: if run twice with the same
seeds and shapes, it will produce the same pseudorandom numbers. The output is
consistent across multiple runs on the same hardware, but may change between
versions of TensorFlow or on non-CPU/GPU hardware.
A slight difference exists in the interpretation of the `shape` parameter
between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always
prepended to the shape of `lam`; whereas in `stateless_poisson` the shape of
`lam` must match the trailing dimensions of `shape`.
Example:
```python
samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
# the samples drawn from each distribution
samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
# represents the 7x5 samples drawn from each of the two distributions
rate = tf.constant([[1.], [3.], [5.]])
samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
# samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape
must match the rightmost dimensions of `shape`.
dtype: Dtype of the samples (int or float dtypes are permissible, as samples
are discrete). Default: int32.
name: A name for the operation (optional).
Returns:
samples: A Tensor of the specified shape filled with random Poisson values.
For each i, each `samples[..., i]` is an independent draw from the Poisson
distribution with rate `lam[i]`.
"""
with ops.name_scope(
name, "stateless_random_poisson", [shape, seed, lam]
) as name:
shape = shape_util.shape_tensor(shape)
result = gen_stateless_random_ops.stateless_random_poisson(
shape, seed=seed, lam=lam, dtype=dtype
)
shape_util.maybe_set_static_shape(result, shape)
return result
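
# A minimal sketch for `stateless_poisson`: `lam` matches the trailing
# dimension of `shape` rather than being prepended as in `tf.random.poisson`.
#
#   samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
#   # samples[:, 0] are draws from Poisson(5); samples[:, 1] from Poisson(15).
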
@tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(
shape,
seed,
mean=0.0,
stddev=1.0,
dtype=dtypes.float32,
name=None,
alg="auto_select",
):
"""Outputs deterministic pseudorandom values from a normal distribution.
This is a stateless version of `tf.random.normal`: if run twice with the
same seeds and shapes, it will produce the same pseudorandom numbers. The
output is consistent across multiple runs on the same hardware (and between
CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
distribution.
stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
of the normal distribution.
dtype: The float type of the output: `float16`, `bfloat16`, `float32`,
`float64`. Defaults to `float32`.
name: A name for the operation (optional).
alg: The RNG algorithm used to generate the random numbers. See
`tf.random.stateless_uniform` for a detailed explanation.
Returns:
A tensor of the specified shape filled with random normal values.
"""
with ops.name_scope(
name, "stateless_random_normal", [shape, seed, mean, stddev]
) as name:
shape = shape_util.shape_tensor(shape)
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
key, counter, alg = random_ops_util.get_key_counter_alg(seed, alg)
rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg
)
result = math_ops.add(rnd * stddev, mean, name=name)
shape_util.maybe_set_static_shape(result, shape)
return result
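
# A minimal determinism sketch for `stateless_normal` (assumption: same
# hardware across runs):
#
#   a = tf.random.stateless_normal([3], seed=[1, 2], mean=0., stddev=2.)
#   b = tf.random.stateless_normal([3], seed=[1, 2], mean=0., stddev=2.)
#   # tf.reduce_all(a == b) evaluates to True.
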
@tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(
shape,
seed,
mean=0.0,
stddev=1.0,
dtype=dtypes.float32,
name=None,
alg="auto_select",
):
"""Outputs deterministic pseudorandom values, truncated normally distributed.
This is a stateless version of `tf.random.truncated_normal`: if run twice with
the same seeds and shapes, it will produce the same pseudorandom numbers. The
output is consistent across multiple runs on the same hardware (and between
CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 standard
deviations from the mean are dropped and re-picked.
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
truncated normal distribution.
stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
of the normal distribution, before truncation.
dtype: The type of the output.
name: A name for the operation (optional).
alg: The RNG algorithm used to generate the random numbers. See
`tf.random.stateless_uniform` for a detailed explanation.
Returns:
A tensor of the specified shape filled with random truncated normal values.
"""
with ops.name_scope(
name, "stateless_truncated_normal", [shape, seed, mean, stddev]
) as name:
shape = shape_util.shape_tensor(shape)
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
key, counter, alg = random_ops_util.get_key_counter_alg(seed, alg)
rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg
)
result = math_ops.add(rnd * stddev, mean, name=name)
shape_util.maybe_set_static_shape(result, shape)
return result
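
# A minimal sketch for `stateless_truncated_normal`: every draw lies within
# two standard deviations of the mean, per the truncation rule above.
#
#   x = tf.random.stateless_truncated_normal(
#       [1000], seed=[7, 17], mean=0., stddev=1.)
#   # tf.reduce_max(tf.abs(x)) <= 2.0
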
@tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
date=None, instructions="Use `tf.random.stateless_categorical` instead."
)
def stateless_multinomial(
logits, num_samples, seed, output_dtype=dtypes.int64, name=None
):
"""Draws deterministic pseudorandom samples from a multinomial distribution.
This is a stateless version of `tf.random.categorical`: if run twice with the
same seeds and shapes, it will produce the same pseudorandom numbers. The
output is consistent across multiple runs on the same hardware (and between
CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
Example:
```python
# samples has shape [1, 5], where each value is either 0 or 1 with equal
# probability.
samples = tf.random.stateless_categorical(
tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
```
Args:
logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i,
:]` represents the unnormalized log-probabilities for all classes.
num_samples: 0-D. Number of independent samples to draw for each row slice.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
output_dtype: The integer type of the output: `int32` or `int64`. Defaults
to `int64`.
name: Optional name for the operation.
Returns:
The drawn samples of shape `[batch_size, num_samples]`.
"""
with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
return stateless_multinomial_categorical_impl(
logits, num_samples, output_dtype, seed
)
@tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(
logits, num_samples, seed, dtype=dtypes.int64, name=None
):
"""Draws deterministic pseudorandom samples from a categorical distribution.
This is a stateless version of `tf.random.categorical`: if run twice with the
same seeds and shapes, it will produce the same pseudorandom numbers. The
output is consistent across multiple runs on the same hardware (and between
CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
hardware.
Example:
```python
# samples has shape [1, 5], where each value is either 0 or 1 with equal
# probability.
samples = tf.random.stateless_categorical(
tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
```
Args:
logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i,
:]` represents the unnormalized log-probabilities for all classes.
num_samples: 0-D. Number of independent samples to draw for each row slice.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
dtype: The integer type of the output: `int32` or `int64`. Defaults to
`int64`.
name: Optional name for the operation.
Returns:
The drawn samples of shape `[batch_size, num_samples]`.
"""
with ops.name_scope(name, "stateless_categorical", [logits, seed]):
return stateless_multinomial_categorical_impl(
logits, num_samples, dtype, seed
)
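
# A minimal sketch for `stateless_categorical` with batched logits: one row
# of samples is drawn per row of logits (assumption: eager execution).
#
#   logits = tf.math.log([[0.1, 0.9], [0.5, 0.5]])  # shape [2, 2]
#   samples = tf.random.stateless_categorical(logits, 5, seed=[7, 17])
#   # samples has shape [2, 5] with values in {0, 1}.
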
def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
"""Implementation for stateless multinomial/categorical ops (v1/v2)."""
logits = ops.convert_to_tensor(logits, name="logits")
dtype = dtypes.as_dtype(dtype) if dtype else dtypes.int64
accepted_dtypes = (dtypes.int32, dtypes.int64)
if dtype not in accepted_dtypes:
raise ValueError(
f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are "
f"{accepted_dtypes}."
)
return gen_stateless_random_ops.stateless_multinomial(
logits, num_samples, seed, output_dtype=dtype
)
@dispatch.add_dispatch_support
@tf_export("random.stateless_parameterized_truncated_normal")
def stateless_parameterized_truncated_normal(
shape, seed, means=0.0, stddevs=1.0, minvals=-2.0, maxvals=2.0, name=None
):
"""Outputs random values from a truncated normal distribution.
The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 standard
deviations from the mean are dropped and re-picked.
Examples:
Sample from a truncated normal, with differing shape parameters that
broadcast.
>>> means = 0.
>>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
>>> minvals = [-1., -2., -1000.]
>>> maxvals = [[10000.], [1.]]
>>> y = tf.random.stateless_parameterized_truncated_normal(
... shape=[10, 2, 3], seed=[7, 17],
... means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
>>> y.shape
TensorShape([10, 2, 3])
Args:
shape: A 1-D integer `Tensor` or Python array. The shape of the output
tensor.
seed: A shape [2] Tensor, the seed to the random number generator. Must have
dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
means: A `Tensor` or Python value of type `dtype`. The mean of the truncated
normal distribution. This must broadcast with `stddevs`, `minvals` and
`maxvals`, and the broadcasted shape must be dominated by `shape`.
stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation
of the truncated normal distribution. This must broadcast with `means`,
`minvals` and `maxvals`, and the broadcasted shape must be dominated by
`shape`.
minvals: A `Tensor` or Python value of type `dtype`. The minimum value of
the truncated normal distribution. This must broadcast with `means`,
`stddevs` and `maxvals`, and the broadcasted shape must be dominated by
`shape`.
maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of
the truncated normal distribution. This must broadcast with `means`,
`stddevs` and `minvals`, and the broadcasted shape must be dominated by
`shape`.
name: A name for the operation (optional).
Returns:
A tensor of the specified shape filled with random truncated normal values.
"""
with ops.name_scope(
name,
"stateless_parameterized_truncated_normal",
[shape, means, stddevs, minvals, maxvals],
) as name:
shape_tensor = shape_util.shape_tensor(shape)
means_tensor = ops.convert_to_tensor(means, name="means")
stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
shape_tensor,
seed,
means_tensor,
stddevs_tensor,
minvals_tensor,
maxvals_tensor,
)
shape_util.maybe_set_static_shape(rnd, shape)
return rnd
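
# A minimal sketch with asymmetric truncation bounds, which the plain
# `stateless_truncated_normal` (fixed at +/- 2 stddevs) cannot express:
#
#   y = tf.random.stateless_parameterized_truncated_normal(
#       shape=[1000], seed=[7, 17], means=1., stddevs=2.,
#       minvals=0., maxvals=5.)
#   # All samples lie in [0.0, 5.0].
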