/
init_ops.py
1832 lines (1502 loc) · 65.2 KB
/
init_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations often used for initializing tensors.

All variable initializers returned by functions in this file should have the
following signature:

def _initializer(shape, dtype=dtypes.float32, partition_info=None):
  Args:
    shape: List of `int` representing the shape of the output `Tensor`. Some
      initializers may also be able to accept a `Tensor`.
    dtype: (Optional) Type of the output `Tensor`.
    partition_info: (Optional) variable_scope._PartitionInfo object holding
      additional information about how the variable is partitioned. May be
      `None` if the variable is not partitioned.

  Returns:
    A `Tensor` of type `dtype` and `shape`.
"""
import math
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_arg_values
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
class Initializer(object):
  """Common base class shared by every initializer defined in this module.

  The base class only fixes the interface: `__call__` produces the initial
  value, while `get_config` / `from_config` round-trip the constructor
  arguments through a plain Python dict.
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    """Produces a tensor initialized the way this initializer prescribes.

    The base implementation is a placeholder that subclasses must override.

    Args:
      shape: Shape of the tensor.
      dtype: Optional dtype of the tensor. If not provided use the initializer
        dtype.
      partition_info: Optional information about the possible partitioning of a
        tensor.

    Raises:
      NotImplementedError: Always, when invoked on the base class itself.
    """
    raise NotImplementedError

  def get_config(self):
    """Describes the initializer as a JSON-serializable dict.

    Returns:
      A JSON-serializable Python dict. The base class carries no
      configuration, so the dict is empty.
    """
    return {}

  @classmethod
  def from_config(cls, config):
    """Builds an initializer from a configuration dictionary.

    Example:

    ```python
    initializer = RandomUniform(-1, 1)
    config = initializer.get_config()
    initializer = RandomUniform.from_config(config)
    ```

    Args:
      config: A Python dictionary. It will typically be the output of
        `get_config`.

    Returns:
      An Initializer instance.
    """
    # Every config entry is forwarded verbatim as a constructor keyword.
    return cls(**config)
@tf_export(v1=["initializers.zeros", "zeros_initializer"])
@deprecation.deprecated_endpoints("initializers.zeros")
class Zeros(Initializer):
  """Initializer that generates tensors initialized to 0.

  @compatibility(TF2)
  `tf.compat.v1.zeros_initializer` is compatible with eager execution
  and `tf.function`.

  To migrate to TF2, please use `tf.zeros_initializer` instead. The `dtype`
  argument in `tf.compat.v1.zeros_initializer.__init__()` does not exist in
  `tf.zeros_initializer.__init__()`. However, you can specify the `dtype` in
  `__call__()` in both cases.

  #### Structural Mapping to TF2

  Before:

  ```python
  initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32)
  variable = tf.Variable(initializer(shape=[3, 3]))
  ```

  After:

  ```python
  initializer = tf.zeros_initializer()
  variable = tf.Variable(initializer(shape=[3, 3], dtype=tf.float32))
  ```

  #### How to Map Arguments

  | TF1 Arg Name         | TF2 Arg Name     | Note                       |
  | :------------------- | :--------------- | :------------------------- |
  | `dtype`              | `dtype`          | In `__call__()` method     |
  | `partition_info`     | - |  (`__call__` arg in TF1) Not supported    |

  #### Before & After Usage Example

  Before:

  >>> initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32)
  >>> tf.Variable(initializer(shape=[3])).numpy()
  array([0., 0., 0.], dtype=float32)
  >>> tf.Variable(initializer(shape=[3, 3])).numpy()
  array([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], dtype=float32)
  >>> initializer = tf.compat.v1.zeros_initializer()
  >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy()
  array([0., 0., 0.], dtype=float32)
  >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy()
  array([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], dtype=float32)

  After:

  >>> initializer = tf.zeros_initializer()
  >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy()
  array([0., 0., 0.], dtype=float32)
  >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy()
  array([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], dtype=float32)
  @end_compatibility
  """

  @deprecated_args(
      None,
      "Call initializer instance with the dtype argument instead "
      "of passing it to the constructor",
      "dtype")
  def __init__(self, dtype=dtypes.float32):
    """Stores the default dtype, normalized through `dtypes.as_dtype`.

    Args:
      dtype: Default data type, used when `__call__` receives no `dtype`.
    """
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns `array_ops.zeros` of the requested shape.

    Args:
      shape: Shape of the tensor to create.
      dtype: Optional dtype; the constructor dtype is used when this is `None`.
      partition_info: Accepted for interface compatibility; not used here.
    """
    effective_dtype = self.dtype if dtype is None else dtype
    return array_ops.zeros(shape, effective_dtype)

  def get_config(self):
    """Returns the constructor arguments, with the dtype given by name."""
    return {"dtype": self.dtype.name}
@tf_export(v1=["initializers.ones", "ones_initializer"])
@deprecation.deprecated_endpoints("initializers.ones", "ones_initializer")
class Ones(Initializer):
  """Initializer that generates tensors initialized to 1.

  @compatibility(TF2)
  This API is compatible with TF2 behavior and `tf.function`, and can be
  migrated immediately with `tf.keras.initializers.ones`.

  Before:
  >>> initializer = tf.compat.v1.keras.initializers.ones()
  >>> initializer((1, 1))
  <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>

  After:
  >>> initializer = tf.keras.initializers.ones()
  >>> initializer((1, 1))
  <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>

  @end_compatibility
  """

  @deprecated_args(
      None,
      "Call initializer instance with the dtype argument instead "
      "of passing it to the constructor",
      "dtype")
  def __init__(self, dtype=dtypes.float32):
    """Stores the default dtype, normalized through `dtypes.as_dtype`.

    Args:
      dtype: Default data type, used when `__call__` receives no `dtype`.
    """
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns `array_ops.ones` of the requested shape.

    Args:
      shape: Shape of the tensor to create.
      dtype: Optional dtype; the constructor dtype is used when this is `None`.
      partition_info: Accepted for interface compatibility; not used here.
    """
    effective_dtype = self.dtype if dtype is None else dtype
    return array_ops.ones(shape, effective_dtype)

  def get_config(self):
    """Returns the constructor arguments, with the dtype given by name."""
    return {"dtype": self.dtype.name}
@tf_export(v1=["initializers.constant", "constant_initializer"])
@deprecation.deprecated_endpoints("constant_initializer")
class Constant(Initializer):
  """Initializer that generates tensors with constant values.

  The resulting tensor is populated with values of type `dtype`, as
  specified by the argument `value`, following the desired `shape` of the
  new tensor (see examples below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the desired shape of the
  tensor. In the case where the total number of elements in `value` is less
  than the number of elements required by the tensor shape, the last element
  in `value` will be used to fill the remaining entries. If the total number of
  elements in `value` is greater than the number of elements required by the
  tensor shape, the initializer will raise a `ValueError`.

  Args:
    value: A Python scalar, list or tuple of values, or a N-dimensional numpy
      array. All elements of the initialized variable will be set to the
      corresponding value in the `value` argument.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer.
    verify_shape: Boolean that enables verification of the shape of `value`. If
      `True`, the initializer will throw an error if the shape of `value` is not
      compatible with the shape of the initialized tensor.

  Raises:
    TypeError: If the input `value` is not one of the expected types.

  Examples:
    The following example can be rewritten using a numpy.ndarray instead
    of the `value` list, even reshaped, as shown in the two commented lines
    below the `value` list initialization.

  >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
  >>> init = tf.compat.v1.constant_initializer(value)
  >>> # fitting shape
  >>> with tf.compat.v1.Session():
  ...   x = tf.compat.v1.get_variable('x', shape=[2, 4], initializer=init)
  ...   x.initializer.run()
  ...   print(x.eval())
  [[0. 1. 2. 3.]
   [4. 5. 6. 7.]]
  >>> # Larger shape
  >>> with tf.compat.v1.Session():
  ...   y = tf.compat.v1.get_variable('y', shape=[3, 4], initializer=init)
  ...   y.initializer.run()
  ...   print(y.eval())
  [[0.  1.  2.  3.]
   [4.  5.  6.  7.]
   [7.  7.  7.  7.]]
  >>> # Smaller shape
  >>> with tf.compat.v1.Session():
  ...   z = tf.compat.v1.get_variable('z', shape=[2, 3], initializer=init)
  Traceback (most recent call last):
  ...
  ValueError: Too many elements provided. Needed at most 6, but received 8
  >>> # Shape verification
  >>> init_verify = tf.compat.v1.constant_initializer(value, verify_shape=True)
  >>> with tf.compat.v1.Session():
  ...  u = tf.compat.v1.get_variable('u', shape=[3, 4],
  ...                                initializer=init_verify)
  Traceback (most recent call last):
  ...
  TypeError: Expected Tensor's shape: (3, 4), got (8,).

  @compatibility(TF2)
  Although it is a legacy API endpoint, `tf.compat.v1.constant_initializer`
  is compatible with eager execution and `tf.function`.

  To migrate to a non-legacy TF2 API, please use `tf.constant_initializer`
  instead. The `dtype`
  argument in `tf.compat.v1.constant_initializer.__init__()` does not exist in
  `tf.constant_initializer.__init__()`. However, you can specify the `dtype` in
  `__call__()` in both cases.

  In the `compat.v1` symbol, if `verify_shape` is set to `True`, an exception
  is raised when initializing a variable with a different shape from
  `value`. If set to `False`, `value` is reshaped to initialize the variable
  if necessary. An exception would only be raised when the number of
  elements are different.

  The `verify_shape` argument is not supported in TF2. Using
  `tf.constant_initializer` is equivalent to setting `verify_shape` to `False`.

  #### Structural Mapping to TF2

  Before:

  ```python
  value = [0, 1, 2, 3, 4, 5, 6, 7]
  initializer = tf.compat.v1.constant_initializer(
      value=value,
      dtype=tf.float32,
      verify_shape=False)
  variable = tf.Variable(initializer(shape=[2, 4]))
  ```

  After:

  ```python
  value = [0, 1, 2, 3, 4, 5, 6, 7]
  initializer = tf.constant_initializer(value=value)
  tf.Variable(initializer(shape=[2, 4], dtype=tf.float32))
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name     | Note                        |
  | :-------------------- | :--------------- | :-------------------------- |
  | `value`               | `value`          | In constructor              |
  | `dtype`               | `dtype`          | In `__call__()` method      |
  | `verify_shape`        | Not Supported    | Equivalent to set to `False`|
  | `partition_info`      | - |  (`__call__` arg in TF1) Not supported     |

  #### Before & After Usage Example

  Before:

  >>> value = [1., 2., 3., 4.]
  >>> initializer = tf.compat.v1.constant_initializer(
  ...     value=value, dtype=tf.float32, verify_shape=True)
  >>> tf.Variable(initializer(shape=[2, 2])).numpy()
  Traceback (most recent call last):
  ...
  TypeError: Expected Tensor's shape: (2, 2), got (4,).
  >>> initializer = tf.compat.v1.constant_initializer(
  ...     value=value, dtype=tf.float32, verify_shape=False)
  >>> tf.Variable(initializer(shape=[2, 2])).numpy()
  array([[1., 2.],
         [3., 4.]], dtype=float32)

  After:

  >>> value = [1., 2., 3., 4.]
  >>> initializer = tf.constant_initializer(value=value)
  >>> tf.Variable(initializer(shape=[2, 2], dtype=tf.float32)).numpy()
  array([[1., 2.],
         [3., 4.]], dtype=float32)
  @end_compatibility
  """

  @deprecated_args(
      None,
      "Call initializer instance with the dtype argument instead "
      "of passing it to the constructor",
      "dtype")
  @deprecated_args(
      None,
      "Objects must now be the required shape or no shape "
      "can be specified",
      "verify_shape")
  def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
    """Validates the type of `value` and records the constructor arguments.

    Raises:
      TypeError: If `value` is neither a scalar (per `np.isscalar`) nor a
        list, tuple or `np.ndarray`.
    """
    is_supported_value = np.isscalar(value) or isinstance(
        value, (list, tuple, np.ndarray))
    if not is_supported_value:
      raise TypeError(
          f"Invalid type for initial value={value} of type: "
          f"{type(value).__name__}. Expected Python scalar, list or tuple of "
          "values, or numpy.ndarray.")

    self.value = value
    self.dtype = dtypes.as_dtype(dtype)
    self._verify_shape = verify_shape

  def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
    """Builds the constant through `constant_op.constant_v1`.

    Args:
      shape: Shape of the tensor to create.
      dtype: Optional dtype; the constructor dtype is used when this is `None`.
      partition_info: Accepted for interface compatibility; not used here.
      verify_shape: Optional per-call override; the constructor setting is
        used when this is `None`.
    """
    effective_dtype = self.dtype if dtype is None else dtype
    effective_verify = (
        self._verify_shape if verify_shape is None else verify_shape)
    return constant_op.constant_v1(
        self.value,
        dtype=effective_dtype,
        shape=shape,
        verify_shape=effective_verify)

  def get_config(self):
    """Returns `value` and the dtype name; `verify_shape` is left out."""
    # We don't include `verify_shape` for compatibility with Keras.
    # `verify_shape` should be passed as an argument to `__call__` rather
    # than as a constructor argument: conceptually it isn't a property
    # of the initializer.
    return {"value": self.value, "dtype": self.dtype.name}
@tf_export(v1=["initializers.random_uniform", "random_uniform_initializer"])
@deprecation.deprecated_endpoints("initializers.random_uniform")
class RandomUniform(Initializer):
  """Initializer that generates tensors with a uniform distribution.

  Args:
    minval: A python scalar or a scalar tensor. Lower bound of the range of
      random values to generate.
    maxval: A python scalar or a scalar tensor. Upper bound of the range of
      random values to generate.  Defaults to 1 for float types.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer.

  @compatibility(TF2)
  Although it is a legacy compat.v1 API, this symbol is compatible with eager
  execution and `tf.function`.

  To switch to TF2, switch to using either
  `tf.initializers.RandomUniform` or `tf.keras.initializers.RandomUniform`
  (neither from `compat.v1`) and
  pass the dtype when calling the initializer. Keep in mind that
  the default minval, maxval and the behavior of fixed seeds have changed.

  #### Structural Mapping to TF2

  Before:

  ```python
  initializer = tf.compat.v1.random_uniform_initializer(
    minval=minval,
    maxval=maxval,
    seed=seed,
    dtype=dtype)

  weight_one = tf.Variable(initializer(shape_one))
  weight_two = tf.Variable(initializer(shape_two))
  ```

  After:

  ```python
  initializer = tf.initializers.RandomUniform(
    minval=minval,
    maxval=maxval,
    seed=seed)

  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `minval`              | `minval`        | Default changes from 0 to -0.05 |
  | `maxval`              | `maxval`        | Default changes from 1.0 to 0.05 |
  | `seed`                | `seed`          |                            |
  | `dtype`               | `dtype`   | The TF2 native api only takes it    |
  :                       :           : as a `__call__` arg, not a constructor arg. :
  | `partition_info`      | - |  (`__call__` arg in TF1) Not supported     |

  @end_compatibility
  """

  @deprecated_args(
      None,
      "Call initializer instance with the dtype argument instead "
      "of passing it to the constructor",
      "dtype")
  def __init__(self, minval=.0, maxval=None, seed=None, dtype=dtypes.float32):
    """Records the sampling bounds, the seed and the default dtype."""
    self.minval = minval
    self.maxval = maxval
    self.seed = seed
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples through `random_ops.random_uniform` with the stored bounds.

    Args:
      shape: Shape of the tensor to create.
      dtype: Optional dtype; the constructor dtype is used when this is `None`.
      partition_info: Accepted for interface compatibility; not used here.
    """
    effective_dtype = self.dtype if dtype is None else dtype
    return random_ops.random_uniform(
        shape, self.minval, self.maxval, effective_dtype, seed=self.seed)

  def get_config(self):
    """Returns the constructor arguments, with the dtype given by name."""
    config = {
        "minval": self.minval,
        "maxval": self.maxval,
        "seed": self.seed,
        "dtype": self.dtype.name,
    }
    return config
@tf_export(v1=["initializers.random_normal", "random_normal_initializer"])
@deprecation.deprecated_endpoints("initializers.random_normal")
class RandomNormal(Initializer):
  """Initializer that generates tensors with a normal distribution.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  @compatibility(TF2)
  Although it is a legacy `compat.v1` API, this symbol is compatible with eager
  execution and `tf.function`.

  To switch to TF2, switch to using either
  `tf.initializers.RandomNormal` or `tf.keras.initializers.RandomNormal`
  (neither from `compat.v1`) and
  pass the dtype when calling the initializer. Keep in mind that
  the default stddev and the behavior of fixed seeds have changed.

  #### Structural Mapping to TF2

  Before:

  ```python
  initializer = tf.compat.v1.random_normal_initializer(
    mean=mean,
    stddev=stddev,
    seed=seed,
    dtype=dtype)

  weight_one = tf.Variable(initializer(shape_one))
  weight_two = tf.Variable(initializer(shape_two))
  ```

  After:

  ```python
  initializer = tf.initializers.RandomNormal(
    mean=mean,
    seed=seed,
    stddev=stddev)

  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
  ```

  #### How to Map Arguments

  | TF1 Arg Name       | TF2 Arg Name    | Note                       |
  | :----------------- | :-------------- | :------------------------- |
  | `mean`             | `mean`          | No change to defaults |
  | `stddev`           | `stddev`        | Default changes from 1.0 to 0.05 |
  | `seed`             | `seed`          |                                  |
  | `dtype`            | `dtype`  | The TF2 native api only takes it as a |
  :                    :          : `__call__` arg, not a constructor arg. :
  | `partition_info`   | -     |  (`__call__` arg in TF1) Not supported.  |

  @end_compatibility
  """

  @deprecated_args(
      None,
      "Call initializer instance with the dtype argument instead "
      "of passing it to the constructor",
      "dtype")
  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    """Records the distribution parameters, the seed and the default dtype."""
    self.mean = mean
    self.stddev = stddev
    self.seed = seed
    # `_assert_float_dtype` lives elsewhere in this module; presumably it
    # rejects non-floating dtypes and hands the dtype back -- verify there.
    normalized_dtype = dtypes.as_dtype(dtype)
    self.dtype = _assert_float_dtype(normalized_dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples through `random_ops.random_normal` with the stored parameters.

    Args:
      shape: Shape of the tensor to create.
      dtype: Optional dtype; the constructor dtype is used when this is `None`.
      partition_info: Accepted for interface compatibility; not used here.
    """
    effective_dtype = self.dtype if dtype is None else dtype
    return random_ops.random_normal(
        shape, self.mean, self.stddev, effective_dtype, seed=self.seed)

  def get_config(self):
    """Returns the constructor arguments, with the dtype given by name."""
    config = {
        "mean": self.mean,
        "stddev": self.stddev,
        "seed": self.seed,
        "dtype": self.dtype.name,
    }
    return config
@tf_export(v1=["initializers.truncated_normal", "truncated_normal_initializer"])
@deprecation.deprecated_endpoints("initializers.truncated_normal",
                                  "truncated_normal_initializer")
class TruncatedNormal(Initializer):
  """Initializer that generates a truncated normal distribution.

  These values are similar to values from a `random_normal_initializer`
  except that values more than two standard deviations from the mean
  are discarded and re-drawn. This is the recommended initializer for
  neural network weights and filters.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  @compatibility(TF2)
  Although it is a legacy `compat.v1` API, this symbol is compatible with eager
  execution and `tf.function`.

  To switch to TF2, switch to using either
  `tf.initializers.truncated_normal` or `tf.keras.initializers.TruncatedNormal`
  (neither from `compat.v1`) and
  pass the dtype when calling the initializer. Keep in mind that
  the default stddev and the behavior of fixed seeds have changed.

  #### Structural Mapping to TF2

  Before:

  ```python
  initializer = tf.compat.v1.truncated_normal_initializer(
    mean=mean,
    stddev=stddev,
    seed=seed,
    dtype=dtype)

  weight_one = tf.Variable(initializer(shape_one))
  weight_two = tf.Variable(initializer(shape_two))
  ```

  After:

  ```python
  initializer = tf.initializers.truncated_normal(
    mean=mean,
    seed=seed,
    stddev=stddev)

  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `mean`                | `mean`          | No change to defaults |
  | `stddev`              | `stddev`        | Default changes from 1.0 to 0.05 |
  | `seed`                | `seed`          |                            |
  | `dtype`               | `dtype`   | The TF2 native api only takes it    |
  :                       :           : as a `__call__` arg, not a constructor arg. :
  | `partition_info`      | - |  (`__call__` arg in TF1) Not supported     |

  @end_compatibility
  """

  @deprecated_args(
      None,
      "Call initializer instance with the dtype argument instead "
      "of passing it to the constructor",
      "dtype")
  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    """Records the distribution parameters, the seed and the default dtype."""
    self.mean = mean
    self.stddev = stddev
    self.seed = seed
    # `_assert_float_dtype` lives elsewhere in this module; presumably it
    # rejects non-floating dtypes and hands the dtype back -- verify there.
    normalized_dtype = dtypes.as_dtype(dtype)
    self.dtype = _assert_float_dtype(normalized_dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples through `random_ops.truncated_normal` with the stored parameters.

    Args:
      shape: Shape of the tensor to create.
      dtype: Optional dtype; the constructor dtype is used when this is `None`.
      partition_info: Accepted for interface compatibility; not used here.
    """
    effective_dtype = self.dtype if dtype is None else dtype
    return random_ops.truncated_normal(
        shape, self.mean, self.stddev, effective_dtype, seed=self.seed)

  def get_config(self):
    """Returns the constructor arguments, with the dtype given by name."""
    config = {
        "mean": self.mean,
        "stddev": self.stddev,
        "seed": self.seed,
        "dtype": self.dtype.name,
    }
    return config
@tf_export(v1=[
    "initializers.uniform_unit_scaling", "uniform_unit_scaling_initializer"
])
@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer",
                                  "initializers.uniform_unit_scaling")
class UniformUnitScaling(Initializer):
  """Initializer that generates tensors without scaling variance.

  When initializing a deep network, it is in principle advantageous to keep
  the scale of the input variance constant, so it does not explode or diminish
  by reaching the final layer. If the input is `x` and the operation `x * W`,
  and we want to initialize `W` uniformly at random, we need to pick `W` from

      [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]

  to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
  A similar calculation for convolutional networks gives an analogous result
  with `dim` equal to the product of the first 3 dimensions.  When
  nonlinearities are present, we need to multiply this by a constant `factor`.
  See (Sussillo et al., 2014) for deeper motivation, experiments
  and the calculation of constants. In section 2.3 there, the constants were
  numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.

  Args:
    factor: Float.  A multiplicative factor by which the values will be scaled.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
      ([pdf](http://arxiv.org/pdf/1412.6558.pdf))
  """

  @deprecated_args(
      None,
      "Call initializer instance with the dtype argument instead "
      "of passing it to the constructor",
      "dtype")
  @deprecated(
      None,
      "Use tf.initializers.variance_scaling instead with distribution="
      "uniform to get equivalent behavior.")
  def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
    """Records the scaling factor, the seed and the default dtype."""
    self.factor = factor
    self.seed = seed
    # `_assert_float_dtype` lives elsewhere in this module; presumably it
    # rejects non-floating dtypes and hands the dtype back -- verify there.
    normalized_dtype = dtypes.as_dtype(dtype)
    self.dtype = _assert_float_dtype(normalized_dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples uniformly from `[-limit, limit]`.

    `limit` is `sqrt(3 / input_size) * factor`, where `input_size` is the
    product of all dimensions but the last one.

    Args:
      shape: Shape of the tensor to create.
      dtype: Optional dtype; the constructor dtype is used when this is `None`.
      partition_info: Optional partitioning information. When given, its
        `full_shape` (rather than `shape`) drives the `input_size` estimate.
    """
    effective_dtype = self.dtype if dtype is None else dtype
    reference_shape = (
        shape if partition_info is None else partition_info.full_shape)

    # Estimating input size is not possible to do perfectly, but we try.
    # The estimate, obtained by multiplying all dimensions but the last one,
    # is the right thing for matrix multiply and convolutions (see above).
    estimated_inputs = 1.0
    for extent in reference_shape[:-1]:
      estimated_inputs *= float(extent)
    # Avoid errors when initializing zero-size tensors.
    estimated_inputs = max(estimated_inputs, 1.0)

    limit = math.sqrt(3 / estimated_inputs) * self.factor
    return random_ops.random_uniform(
        shape, -limit, limit, effective_dtype, seed=self.seed)

  def get_config(self):
    """Returns the constructor arguments, with the dtype given by name."""
    config = {
        "factor": self.factor,
        "seed": self.seed,
        "dtype": self.dtype.name,
    }
    return config
@tf_export(v1=["initializers.variance_scaling", "variance_scaling_initializer"])
@deprecation.deprecated_endpoints("initializers.variance_scaling",
"variance_scaling_initializer")
class VarianceScaling(Initializer):
"""Initializer capable of adapting its scale to the shape of weights tensors.
@compatibility(TF2)
Although it is a legacy `compat.v1` API, this symbol is compatible with eager
execution and `tf.function`.
To switch to TF2 APIs, move to using either
`tf.initializers.variance_scaling` or `tf.keras.initializers.VarianceScaling`
(neither from `compat.v1`) and
pass the dtype when calling the initializer.
#### Structural Mapping to TF2
Before:
```python
initializer = tf.compat.v1.variance_scaling_initializer(
scale=scale,
mode=mode,
distribution=distribution
seed=seed,
dtype=dtype)
weight_one = tf.Variable(initializer(shape_one))
weight_two = tf.Variable(initializer(shape_two))
```
After:
```python
initializer = tf.keras.initializers.VarianceScaling(
scale=scale,
mode=mode,
distribution=distribution
seed=seed)
weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
```
#### How to Map Arguments
| TF1 Arg Name | TF2 Arg Name | Note |
| :----------------- | :-------------- | :------------------------- |
| `scale` | `scale` | No change to defaults |
| `mode` | `mode` | No change to defaults |
| `distribution` | `distribution` | No change to defaults. |
: : : 'normal' maps to 'truncated_normal' :
| `seed` | `seed` | |
| `dtype` | `dtype` | The TF2 api only takes it |
: : : as a `__call__` arg, not a constructor arg. :
| `partition_info` | - | (`__call__` arg in TF1) Not supported |
@end_compatibility
With `distribution="truncated_normal" or "untruncated_normal"`,
samples are drawn from a truncated/untruncated normal
distribution with a mean of zero and a standard deviation (after truncation,
if used) `stddev = sqrt(scale / n)`
where n is:
- number of input units in the weight tensor, if mode = "fan_in"
- number of output units, if mode = "fan_out"
- average of the numbers of input and output units, if mode = "fan_avg"
With `distribution="uniform"`, samples are drawn from a uniform distribution
within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
Args:
scale: Scaling factor (positive float).
mode: One of "fan_in", "fan_out", "fan_avg".
distribution: Random distribution to use. One of "normal", "uniform".
seed: A Python integer. Used to create random seeds. See
`tf.compat.v1.set_random_seed` for behavior.
dtype: Default data type, used if no `dtype` argument is provided when
calling the initializer. Only floating point types are supported.
Raises:
ValueError: In case of an invalid value for the "scale", mode" or
"distribution" arguments.
"""
@deprecated_args(None,
"Call initializer instance with the dtype argument instead "
"of passing it to the constructor", "dtype")
@deprecated_arg_values(
None,
"`normal` is a deprecated alias for `truncated_normal`",
distribution="normal")
def __init__(self,
scale=1.0,
mode="fan_in",
distribution="truncated_normal",
seed=None,
dtype=dtypes.float32):
if scale <= 0.:
raise ValueError("Argument `scale` must be a positive float. Received: "
f"{scale}")
if mode not in {"fan_in", "fan_out", "fan_avg"}:
raise ValueError("Argument `mode` should be one of ('fan_in', 'fan_out', "
f"'fan_avg'). Received: {mode}")
distribution = distribution.lower()
if distribution not in {
"normal", "uniform", "truncated_normal", "untruncated_normal"
}:
raise ValueError("Argument `distribution` should be one of ('normal', "
"uniform', 'truncated_normal', 'untruncated_normal'). "
f"Received: {distribution}")
self.scale = scale
self.mode = mode
self.distribution = distribution
self.seed = seed
self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
dtype = self.dtype
scale = self.scale
scale_shape = shape
if partition_info is not None:
scale_shape = partition_info.full_shape
fan_in, fan_out = _compute_fans(scale_shape)
if self.mode == "fan_in":
scale /= max(1., fan_in)
elif self.mode == "fan_out":
scale /= max(1., fan_out)
else:
scale /= max(1., (fan_in + fan_out) / 2.)
if self.distribution == "normal" or self.distribution == "truncated_normal":
# constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
stddev = math.sqrt(scale) / .87962566103423978
return random_ops.truncated_normal(
shape, 0.0, stddev, dtype, seed=self.seed)
elif self.distribution == "untruncated_normal":
stddev = math.sqrt(scale)
return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed)
else:
limit = math.sqrt(3.0 * scale)
return random_ops.random_uniform(
shape, -limit, limit, dtype, seed=self.seed)
def get_config(self):
  """Returns the initializer configuration as a plain dictionary.

  Returns:
    A dict with the keys "scale", "mode", "distribution", "seed" and "dtype",
    where "dtype" holds the `name` attribute of `self.dtype`.
  """
  config = {
      field: getattr(self, field)
      for field in ("scale", "mode", "distribution", "seed")
  }
  # The dtype is reported by name rather than as the dtype object itself.
  config["dtype"] = self.dtype.name
  return config
@tf_export(v1=["initializers.orthogonal", "orthogonal_initializer"])
@deprecation.deprecated_endpoints("initializers.orthogonal",
                                  "orthogonal_initializer")
class Orthogonal(Initializer):
  """Initializer that generates an orthogonal matrix.

  If the shape of the tensor to initialize is two-dimensional, it is initialized
  with an orthogonal matrix obtained from the QR decomposition of a matrix of
  random numbers drawn from a normal distribution.
  If the matrix has fewer rows than columns then the output will have orthogonal
  rows. Otherwise, the output will have orthogonal columns.

  If the shape of the tensor to initialize is more than two-dimensional,
  a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
  is initialized, where `n` is the length of the shape vector.
  The matrix is subsequently reshaped to give a tensor of the desired shape.

  Args:
    gain: multiplicative factor to apply to the orthogonal matrix
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  References:
      [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
      ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Builds a tensor of `shape` from the Q factor of a random matrix.

    Args:
      shape: Shape of the tensor to generate; needs at least two entries.
      dtype: Type of the generated values. Falls back to `self.dtype` when
        `None`.
      partition_info: Accepted for interface compatibility; not used here.

    Returns:
      The sign-corrected Q factor reshaped to `shape` and scaled by
      `self.gain`.

    Raises:
      ValueError: If `shape` has fewer than two entries.
    """
    if dtype is None:
      dtype = self.dtype

    if len(shape) < 2:
      raise ValueError("The tensor to initialize, specified by argument `shape`"
                       " must be at least two-dimensional. Received shape="
                       f"{shape}")

    # Collapse every leading dimension into one so that the last dimension
    # keeps its size (this is what makes it work for conv2d kernels).
    leading = 1
    for extent in shape[:-1]:
      leading *= extent
    num_rows = int(leading)
    num_cols = int(shape[-1])

    # QR is always run on a matrix with at least as many rows as columns; a
    # wide target is produced by transposing the result afterwards.
    needs_transpose = num_rows < num_cols
    flat_shape = (max(num_rows, num_cols), min(num_rows, num_cols))

    random_matrix = random_ops.random_normal(
        flat_shape, dtype=dtype, seed=self.seed)
    q, r = gen_linalg_ops.qr(random_matrix, full_matrices=False)

    # Make Q uniform by multiplying with the signs of R's diagonal.
    q = q * math_ops.sign(array_ops.diag_part(r))
    if needs_transpose:
      q = array_ops.matrix_transpose(q)
    return self.gain * array_ops.reshape(q, shape)

  def get_config(self):
    """Returns the constructor arguments as a dict, with dtype by name."""
    return dict(gain=self.gain, seed=self.seed, dtype=self.dtype.name)
# Note: these haven't been ported to TF2.0. They are not currently visible and
# the tests are non-trivial to port.
class ConvolutionDeltaOrthogonal(Initializer):
"""Initializer that generates a delta orthogonal kernel for ConvNets.
The shape of the tensor must have length 3, 4 or 5. The number of input
filters must not exceed the number of output filters. The center pixels of the
tensor form an orthogonal matrix. Other pixels are set to be zero. See
algorithm 2 in (Xiao et al., 2018).
Args:
gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
The 2-norm of an input is multiplied by a factor of `gain` after applying
this convolution.
seed: A Python integer. Used to create random seeds. See
`tf.compat.v1.set_random_seed` for behavior.
dtype: Default data type, used if no `dtype` argument is provided when
calling the initializer. Only floating point types are supported.
References:
[Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
"""
def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
self.gain = gain
self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
self.seed = seed
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
dtype = self.dtype
# Check the shape
if len(shape) < 3 or len(shape) > 5:
raise ValueError("The tensor to initialize, specified by argument `shape`"
" must be at least three-dimensional and at most "