-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnn_impl.py
2420 lines (2083 loc) · 98.8 KB
/
nn_impl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Implementation of Neural Net (NN) functions."""
import math
from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import candidate_sampling_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import cond as tf_cond
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.platform import device_context
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export
@tf_export("nn.log_poisson_loss")
@dispatch.add_dispatch_support
def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
  """Computes log Poisson loss given `log_input`.

  Gives the log-likelihood loss between the prediction and the target under the
  assumption that the target has a Poisson distribution.
  Caveat: By default, this is not the exact loss, but the loss minus a
  constant term [log(z!)]. That has no effect for optimization, but
  does not play well with relative loss comparisons. To compute an
  approximation of the log factorial term, specify
  compute_full_loss=True to enable Stirling's Approximation.

  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
  loss is

        -log(exp(-x) * (x^z) / z!)
      = -log(exp(-x) * (x^z)) + log(z!)
      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
          [ Note the second term is the Stirling's Approximation for log(z!).
            It is invariant to x and does not affect optimization, though
            important for correct relative loss comparisons. It is only
            computed when compute_full_loss == True. ]
      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]

  When `compute_full_loss` is True, the Stirling term is replaced by 0 for
  targets in the closed interval [0, 1] (exact for z = 0 and z = 1, where
  log(z!) = 0).

  Args:
    targets: A `Tensor` of the same type and shape as `log_input`.
    log_input: A `Tensor` of type `float32` or `float64`.
    compute_full_loss: whether to compute the full loss. If false, a constant
      term is dropped in favor of more efficient optimization.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `log_input` with the componentwise
    log Poisson losses.

  Raises:
    ValueError: If `log_input` and `targets` do not have the same shape.
  """
  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
    log_input = ops.convert_to_tensor(log_input, name="log_input")
    targets = ops.convert_to_tensor(targets, name="targets")
    try:
      targets.get_shape().assert_is_compatible_with(log_input.get_shape())
    except ValueError:
      raise ValueError(
          "`log_input` and `targets` must have the same shape, received "
          f"({log_input.get_shape()} vs {targets.get_shape()}).")

    result = math_ops.exp(log_input) - log_input * targets
    if compute_full_loss:
      # need to create constant tensors here so that their dtypes can be matched
      # to that of the targets.
      point_five = constant_op.constant(0.5, dtype=targets.dtype)
      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)
      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
      ones = array_ops.ones_like(targets, dtype=targets.dtype)
      # Targets in [0, 1] get no Stirling correction.
      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
      # Evaluate the Stirling term on a "safe" copy of the targets in which the
      # masked entries are replaced by 1. With the raw targets, z == 0 gives
      # 0 * log(0) = NaN in the unselected branch of `where`, and that NaN
      # leaks into gradients with respect to `targets` even though the
      # forward value is masked out. Unmasked entries are unchanged, so the
      # forward result is identical.
      safe_targets = array_ops.where(cond, ones, targets)
      stirling_approx = (
          safe_targets * math_ops.log(safe_targets) - safe_targets +
          point_five * math_ops.log(two_pi * safe_targets))
      result += array_ops.where(cond, zeros, stirling_approx)
    return result
@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits(
    labels=None,
    logits=None,
    name=None):
  """V1 endpoint; documented by `sigmoid_cross_entropy_with_logits_v2`."""
  # Argument validation is delegated to the shared cross-entropy helper.
  # pylint: disable=protected-access
  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", labels, logits)
  # pylint: enable=protected-access
  with ops.name_scope(name, "logistic_loss", [logits, labels]) as scope:
    x = ops.convert_to_tensor(logits, name="logits")
    z = ops.convert_to_tensor(labels, name="labels")
    try:
      z.get_shape().assert_is_compatible_with(x.get_shape())
    except ValueError:
      raise ValueError("`logits` and `labels` must have the same shape, "
                       f"received ({x.get_shape()} vs "
                       f"{z.get_shape()}).")

    # Overflow-safe logistic loss:
    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
    # Both max(x, 0) and -abs(x) are expressed through `where` on a single
    # predicate, which keeps the gradient well defined at x == 0.
    x_zeros = array_ops.zeros_like(x, dtype=x.dtype)
    x_nonnegative = (x >= x_zeros)
    max_x_zero = array_ops.where(x_nonnegative, x, x_zeros)
    minus_abs_x = array_ops.where(x_nonnegative, -x, x)  # pylint: disable=invalid-unary-operand-type
    linear_term = max_x_zero - x * z
    softplus_term = math_ops.log1p(math_ops.exp(minus_abs_x))
    return math_ops.add(linear_term, softplus_term, name=scope)
# Note: intentionally calling this v2 to not allow existing code with indirect
# imports to ignore the sentinel behavior.
@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
@dispatch.register_binary_elementwise_api
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    name=None):
  r"""Computes sigmoid cross entropy given `logits`.

  Measures the probability error in tasks with two outcomes in which each
  outcome is independent and need not have a fully certain label. For instance,
  one could perform a regression where the probability of an event happening is
  known and used as a label. This loss may also be used for binary
  classification, where labels are either zero or one.

  For brevity, let `x = logits`, `z = labels`.  The logistic loss is

        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

  For x < 0, to avoid overflow in exp(-x), we reformulate the above

        x - x * z + log(1 + exp(-x))
      = log(exp(x)) - x * z + log(1 + exp(-x))
      = - x * z + log(1 + exp(x))

  Hence, to ensure stability and avoid overflow, the implementation uses this
  equivalent formulation

      max(x, 0) - x * z + log(1 + exp(-abs(x)))

  `logits` and `labels` must have the same type and shape.

  >>> logits = tf.constant([1., -1., 0., 1., -1., 0., 0.])
  >>> labels = tf.constant([0., 0., 0., 1., 1., 1., 0.5])
  >>> tf.nn.sigmoid_cross_entropy_with_logits(
  ...     labels=labels, logits=logits).numpy()
  array([1.3132617, 0.3132617, 0.6931472, 0.3132617, 1.3132617, 0.6931472,
         0.6931472], dtype=float32)

  Compared to the losses which handle multiple outcomes,
  `tf.nn.softmax_cross_entropy_with_logits` for general multi-class
  classification and `tf.nn.sparse_softmax_cross_entropy_with_logits` for more
  efficient multi-class classification with hard labels,
  `sigmoid_cross_entropy_with_logits` is a slight simplification for binary
  classification:

        sigmoid(x) = softmax([x, 0])[0]

  $$\frac{1}{1 + e^{-x}} = \frac{e^x}{e^x + e^0}$$

  While `sigmoid_cross_entropy_with_logits` works for soft binary labels
  (probabilities between 0 and 1), it can also be used for binary classification
  where the labels are hard. There is an equivalence between all three functions
  in this case, with a probability 0 indicating the second class or 1 indicating
  the first class:

  >>> sigmoid_logits = tf.constant([1., -1., 0.])
  >>> softmax_logits = tf.stack([sigmoid_logits, tf.zeros_like(sigmoid_logits)],
  ...                           axis=-1)
  >>> soft_binary_labels = tf.constant([1., 1., 0.])
  >>> soft_multiclass_labels = tf.stack(
  ...     [soft_binary_labels, 1. - soft_binary_labels], axis=-1)
  >>> hard_labels = tf.constant([0, 0, 1])
  >>> tf.nn.sparse_softmax_cross_entropy_with_logits(
  ...     labels=hard_labels, logits=softmax_logits).numpy()
  array([0.31326166, 1.3132616 , 0.6931472 ], dtype=float32)
  >>> tf.nn.softmax_cross_entropy_with_logits(
  ...     labels=soft_multiclass_labels, logits=softmax_logits).numpy()
  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)
  >>> tf.nn.sigmoid_cross_entropy_with_logits(
  ...     labels=soft_binary_labels, logits=sigmoid_logits).numpy()
  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)

  Args:
    labels: A `Tensor` of the same type and shape as `logits`. Between 0 and 1,
      inclusive.
    logits: A `Tensor` of type `float32` or `float64`. Any real number.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  # The v1 implementation performs the computation; this endpoint only exists
  # to expose the keyword-only-style v2 API.
  return sigmoid_cross_entropy_with_logits(
      logits=logits, labels=labels, name=name)
# Make the v1 endpoint display the same (full) docstring as the v2 endpoint,
# so both public symbols document the same contract.
sigmoid_cross_entropy_with_logits.__doc__ = (
    sigmoid_cross_entropy_with_logits_v2.__doc__)
@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
                                          name=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`,
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  >>> labels = tf.constant([1., 0.5, 0.])
  >>> logits = tf.constant([1.5, -0.1, -10.])
  >>> tf.nn.weighted_cross_entropy_with_logits(
  ...     labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
  array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
  >>> tf.nn.weighted_cross_entropy_with_logits(
  ...     labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
  array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)

  Args:
    labels: A `Tensor` of the same type and shape as `logits`, with values
      between 0 and 1 inclusive.
    logits: A `Tensor` of type `float32` or `float64`, any real numbers.
    pos_weight: A coefficient to use on the positive examples, typically a
      scalar but otherwise broadcastable to the shape of `logits`. Its value
      should be non-negative.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    try:
      labels.get_shape().assert_is_compatible_with(logits.get_shape())
    except ValueError:
      raise ValueError("`logits` and `labels` must have the same shape, "
                       f"received ({logits.get_shape()} vs "
                       f"{labels.get_shape()}).")

    # With l = 1 + (q - 1) * z, the logistic loss formula from above is
    #   (1 - z) * x + l * log(1 + exp(-x))
    # For x < 0, exp(-x) can overflow; since log(1 + exp(-x)) equals
    # log(1 + exp(x)) - x, a more numerically stable formula is
    #   (1 - z) * x + l * (log(1 + exp(x)) - x)
    # To avoid branching, we use the combined version
    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
    log_weight = 1 + (pos_weight - 1) * labels
    return math_ops.add(
        (1 - labels) * logits,
        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
                      nn_ops.relu(-logits)),  # pylint: disable=invalid-unary-operand-type
        name=name)
@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
def weighted_cross_entropy_with_logits(labels=None,
                                       logits=None,
                                       pos_weight=None,
                                       name=None,
                                       targets=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`,
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  Args:
    labels: A `Tensor` of the same type and shape as `logits`.
    logits: A `Tensor` of type `float32` or `float64`.
    pos_weight: A coefficient to use on the positive examples.
    name: A name for the operation (optional).
    targets: Deprecated alias for labels.

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  # Resolve the deprecated `targets` alias into `labels`, then defer to v2.
  labels = deprecated_argument_lookup("labels", labels, "targets", targets)
  return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)
@tf_export("nn.compute_average_loss")
@dispatch.add_dispatch_support
def compute_average_loss(per_example_loss,
                         sample_weight=None,
                         global_batch_size=None):
  """Scales per-example losses with sample_weights and computes their average.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(labels, predictions, sample_weight=None):

      # If you are using a `Loss` class instead, set reduction to `NONE` so that
      # we can do the reduction afterwards and divide by global batch size.
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      return tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)
  ```

  Args:
    per_example_loss: Per-example loss.
    sample_weight: Optional weighting for each example.
    global_batch_size: Optional global batch size value. Defaults to (size of
      first dimension of `per_example_loss`) * (number of replicas).

  Returns:
    Scalar loss value: the sum of the (weighted) per-example losses divided by
    `global_batch_size`, in the dtype of `per_example_loss`.

  Raises:
    RuntimeError: If `global_batch_size` is not given and this is called in the
      cross-replica context of a distribution strategy.
  """  # pylint: disable=g-doc-exception
  per_example_loss = ops.convert_to_tensor(per_example_loss)
  input_dtype = per_example_loss.dtype

  with losses_util.check_per_example_loss_rank(per_example_loss):
    if sample_weight is not None:
      sample_weight = ops.convert_to_tensor(sample_weight)
      per_example_loss = losses_util.scale_losses_by_sample_weight(
          per_example_loss, sample_weight)
    # Restore the caller's dtype in case weighting changed it.
    per_example_loss = math_ops.cast(per_example_loss, input_dtype)

    if global_batch_size is None:
      if ds.has_strategy() and ds.in_cross_replica_context():
        raise RuntimeError(
            "You are calling `compute_average_loss` in cross replica context, "
            "while it was expected to be called in replica context.")

      # Infer the global batch from this replica's batch and the replica count.
      num_replicas = ds.get_strategy().num_replicas_in_sync
      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
      global_batch_size = per_replica_batch_size * num_replicas

    check_ops.assert_scalar_v2(
        global_batch_size, message="global_batch_size must be scalar.")
    check_ops.assert_integer_v2(
        global_batch_size,
        message="global_batch_size must be an integer.")
    check_ops.assert_positive_v2(
        global_batch_size, message="global_batch_size must be positive.")

    global_batch_size = math_ops.cast(global_batch_size, input_dtype)
    return math_ops.reduce_sum(per_example_loss) / global_batch_size
@tf_export("nn.scale_regularization_loss")
@dispatch.add_dispatch_support
def scale_regularization_loss(regularization_loss):
  """Scales the sum of the given regularization losses by number of replicas.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(labels, predictions, sample_weight=None):
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      loss = tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)

      # Add scaled regularization losses.
      loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
      return loss
  ```

  Args:
    regularization_loss: Regularization loss.

  Returns:
    Scalar loss value: the sum of `regularization_loss` divided by the number
    of replicas in sync.

  Raises:
    RuntimeError: If called in the cross-replica context of a distribution
      strategy.
  """  # pylint: disable=g-doc-exception
  if ds.has_strategy() and ds.in_cross_replica_context():
    raise RuntimeError(
        "You are calling `scale_regularization_loss` in cross replica context, "
        "while it was expected to be called in replica context.")

  # Each replica contributes its share so that the cross-replica sum equals
  # the unscaled regularization loss.
  num_replicas = ds.get_strategy().num_replicas_in_sync
  return math_ops.reduce_sum(regularization_loss) / num_replicas
@tf_export(v1=["nn.relu_layer"])
@dispatch.add_dispatch_support
def relu_layer(x, weights, biases, name=None):
  """Computes `relu(matmul(x, weights) + biases)`.

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "relu_layer" is used.

  Returns:
    A 2-D Tensor computing relu(matmul(x, weights) + biases).
    Dimensions typically: batch, out_units.
  """
  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
    x = ops.convert_to_tensor(x, name="x")
    weights = ops.convert_to_tensor(weights, name="weights")
    biases = ops.convert_to_tensor(biases, name="biases")
    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
    # The final op takes the scope name, so the layer output is addressable
    # by `name` itself.
    return nn_ops.relu(xw_plus_b, name=name)
@tf_export("nn.silu", "nn.swish")
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def swish(features, beta=1.0):
  # pylint: disable=g-doc-args
  """Computes the SiLU or Swish activation function: `x * sigmoid(beta * x)`.

  The SiLU activation function was introduced in "Gaussian Error Linear Units
  (GELUs)" [Hendrycks et al. 2016](https://arxiv.org/abs/1606.08415) and
  "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in
  Reinforcement Learning"
  [Elfwing et al. 2017](https://arxiv.org/abs/1702.03118) and was independently
  discovered (and called swish) in "Searching for Activation Functions"
  [Ramachandran et al. 2017](https://arxiv.org/abs/1710.05941)

  Args:
    features: A `Tensor` representing preactivation values.
    beta: A `Tensor` (or Python number) representing the value of the beta
      hyperparameter. Defaults to 1.0, which gives the SiLU. It is cast to
      `features.dtype`. The gradient with respect to `beta` is summed over all
      elements into a single scalar, so `beta` is expected to be a scalar.

  Returns:
    The activation value.
  """
  # pylint: enable=g-doc-args
  features = ops.convert_to_tensor(features, name="features")
  beta = ops.convert_to_tensor(beta, name="beta")
  beta = math_ops.cast(beta, features.dtype)

  @custom_gradient.custom_gradient
  def swish_impl(features, beta):

    def grad(dy):
      """Gradient for the Swish activation function."""
      # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
      # around for backprop, effectively doubling the tensor's memory
      # consumption. We use a control dependency here so that sigmoid(features)
      # is re-computed during backprop (the control dep prevents it being
      # de-duped with the forward pass) and we can free the sigmoid(features)
      # expression immediately after use during the forward pass.
      with ops.control_dependencies([dy]):
        sigmoid_features = math_ops.sigmoid(beta * features)

      # d/dx [x * s(bx)] = s(bx) * (1 + bx * (1 - s(bx)))
      activation_grad = (
          sigmoid_features * (1.0 + (beta * features) *
                              (1.0 - sigmoid_features)))
      # d/db [x * s(bx)] = x^2 * s(bx) * (1 - s(bx)), summed to a scalar.
      beta_grad = math_ops.reduce_sum(
          dy * math_ops.square(features) * sigmoid_features *
          (1.0 - sigmoid_features))
      return (dy * activation_grad, beta_grad)

    return features * math_ops.sigmoid(beta * features), grad

  return swish_impl(features, beta)
# pylint: disable=redefined-builtin
@tf_export("linalg.normalize")
@dispatch.add_dispatch_support
def normalize(tensor, ord="euclidean", axis=None, name=None):
  """Normalizes `tensor` along dimension `axis` using specified norm.

  This uses `tf.linalg.norm` to compute the norm along `axis`.

  This function can compute several different vector norms (the 1-norm, the
  Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
  matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).

  No epsilon is applied: `tensor` is divided directly by its norm, so a slice
  whose norm is zero yields non-finite values in `normalized`.

  Args:
    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`.
    ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
      `2`, `np.inf` and any positive real number yielding the corresponding
      p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
      `tensor` is a matrix and equivalent to 2-norm for vectors.
      Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
        vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
        `'fro'`, `1`, `2`, `np.inf` are supported. See the description of
        `axis` on how to compute norms for a batch of vectors or matrices
        stored in a tensor.
    axis: If `axis` is `None` (the default), the input is considered a vector
      and a single vector norm is computed over the entire set of values in the
      tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
      `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
      input is considered a batch of vectors, and `axis` determines the axis in
      `tensor` over which to compute vector norms. If `axis` is a 2-tuple of
      Python integers it is considered a batch of matrices and `axis` determines
      the axes in `tensor` over which to compute a matrix norm.
      Negative indices are supported. Example: If you are passing a tensor that
      can be either a matrix or a batch of matrices at runtime, pass
      `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
      computed.
    name: The name of the op.

  Returns:
    normalized: A normalized `Tensor` with the same shape as `tensor`.
    norm: The computed norms, with the same dtype as `tensor`. Its shape
      matches `tensor` except that the reduced axis (or axes) have size 1,
      because the norm is computed with `keepdims=True`. Same as running
      `tf.cast(tf.linalg.norm(tensor, ord, axis, keepdims=True), tensor.dtype)`.

  Raises:
    ValueError: If `ord` or `axis` is invalid.
  """
  with ops.name_scope(name, "normalize", [tensor]):
    tensor = ops.convert_to_tensor(tensor)
    # keepdims=True so `norm` broadcasts against `tensor` in the division.
    norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
    # Bring the norm back to `tensor.dtype` so the division is between
    # same-dtype tensors (presumably needed for complex inputs, whose norm is
    # real-valued).
    norm = math_ops.cast(norm, tensor.dtype)
    normalized = tensor / norm
    return normalized, norm
@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize",
           v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
  """Normalizes along dimension `axis` using an L2 norm.

  Each slice of `x` along `axis` is scaled to unit L2 norm. For a 1-D tensor
  with `axis = 0` this computes

      output = x / sqrt(max(sum(x**2), epsilon))

  Higher-rank inputs are handled by normalizing every 1-D slice along `axis`
  independently. Complex inputs use `|x|^2 = real(x)^2 + imag(x)^2` for the
  squared norm and scale the real and imaginary parts by the same factor.

  1-D tensor example:
  >>> x = tf.constant([3.0, 4.0])
  >>> tf.math.l2_normalize(x).numpy()
  array([0.6, 0.8], dtype=float32)

  2-D tensor example:
  >>> x = tf.constant([[3.0], [4.0]])
  >>> tf.math.l2_normalize(x, 0).numpy()
  array([[0.6],
         [0.8]], dtype=float32)

  >>> x = tf.constant([[3.0], [4.0]])
  >>> tf.math.l2_normalize(x, 1).numpy()
  array([[1.],
         [1.]], dtype=float32)

  Args:
    x: A `Tensor`.
    axis: Dimension along which to normalize.  A scalar or a vector of
      integers.
    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
      divisor if `norm < sqrt(epsilon)`.
    name: A name for this operation (optional).
    dim: Deprecated, do not use.

  Returns:
    A `Tensor` with the same shape as `x`.
  """
  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
  with ops.name_scope(name, "l2_normalize", [x]) as scope:
    x = ops.convert_to_tensor(x, name="x")

    if not x.dtype.is_complex:
      sum_of_squares = math_ops.reduce_sum(
          math_ops.square(x), axis, keepdims=True)
      # Clamping with epsilon keeps rsqrt finite for all-zero slices.
      inv_norm = math_ops.rsqrt(math_ops.maximum(sum_of_squares, epsilon))
      return math_ops.multiply(x, inv_norm, name=scope)

    # Complex input: build the (real) squared magnitude from both parts, then
    # rescale each part and recombine.
    re_part = math_ops.real(x)
    im_part = math_ops.imag(x)
    sum_of_squares = math_ops.real(
        math_ops.reduce_sum(
            math_ops.square(re_part) + math_ops.square(im_part),
            axis, keepdims=True))
    inv_norm = math_ops.rsqrt(math_ops.maximum(sum_of_squares, epsilon))
    return math_ops.complex(
        math_ops.multiply(re_part, inv_norm),
        math_ops.multiply(im_part, inv_norm),
        name=scope)
def _count_nonzero(input_tensor, dtype=dtypes.int64):
  """Counts the non-zero entries of `input_tensor`, accumulating in `dtype`.

  Behaves like `math_ops.count_nonzero`, but lets the caller choose the
  reduction dtype; a 32-bit accumulator can be faster than a 64-bit one.

  Args:
    input_tensor: A numeric tensor.
    dtype: The dtype of the per-element indicator and of the reduction.

  Returns:
    A scalar tensor of type `dtype` holding the number of non-zero values.
  """
  with ops.name_scope("count_nonzero", values=[input_tensor]):
    is_nonzero = math_ops.not_equal(
        input_tensor, array_ops.zeros([], dtype=input_tensor.dtype))
    return math_ops.reduce_sum(
        math_ops.cast(is_nonzero, dtype=dtype), name="nonzero_count")
@tf_export("math.zero_fraction", "nn.zero_fraction")
@dispatch.add_dispatch_support
def zero_fraction(value, name=None):
  """Returns the fraction of zeros in `value`.

  An empty `value` yields `nan` (zero divided by zero).

  A typical use is reporting sparsity in summaries, for example:

  ```python
      z = tf.nn.relu(...)
      summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
  ```

  Args:
    value: A tensor of numeric type.
    name: A name for the operation (optional).

  Returns:
    The fraction of zeros in `value`, with type `float32`.
  """
  with ops.name_scope(name, "zero_fraction", [value]):
    value = ops.convert_to_tensor(value, name="value")
    num_elements = array_ops.size(value, out_type=dtypes.int64)

    def _count_in_int32():
      # The element count fits in int32, so a cheaper 32-bit reduction is used
      # and the result widened back to int64.
      return math_ops.cast(
          _count_nonzero(value, dtype=dtypes.int32), dtype=dtypes.int64)

    def _count_in_int64():
      return _count_nonzero(value, dtype=dtypes.int64)

    num_nonzero = tf_cond.cond(
        num_elements <= dtypes.int32.max,
        true_fn=_count_in_int32,
        false_fn=_count_in_int64)

    with ops.name_scope("counts_to_fraction"):
      num_zero_float32 = math_ops.cast(
          num_elements - num_nonzero, dtype=dtypes.float32)
      num_elements_float32 = math_ops.cast(num_elements, dtype=dtypes.float32)
      return array_ops.identity(
          num_zero_float32 / num_elements_float32, "fraction")
# pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"])
@dispatch.add_dispatch_support
def depthwise_conv2d(input,
                     filter,
                     strides,
                     padding,
                     rate=None,
                     name=None,
                     data_format=None,
                     dilations=None):
  """Depthwise 2-D convolution.

  Takes a 4-D input (in 'NHWC' or 'NCHW' layout) and a filter of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`, i.e.
  `in_channels` separate filters of depth 1. Each input channel is convolved
  with its own filter, expanding that single channel into
  `channel_multiplier` channels, and the per-channel results are
  concatenated. The output therefore has `in_channels * channel_multiplier`
  channels.

  Written out for the default NHWC layout:

      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
                                           strides[2] * j + rate[1] * dj, k]

  `strides[0]` and `strides[3]` must both be 1. Equal horizontal and
  vertical strides are usually written as `strides = [1, stride, stride, 1]`.
  If any entry of `rate` is greater than 1, the convolution is atrous
  (dilated), and then every entry of `strides` must be 1.

  Usage Example:

  >>> x = np.array([
  ...     [1., 2.],
  ...     [3., 4.],
  ...     [5., 6.]
  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
  >>> kernel = np.array([
  ...     [1., 2.],
  ...     [3., 4]
  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                                  padding='VALID').numpy()
  array([[[[10., 14.],
           [14., 20.]],
          [[18., 26.],
           [22., 32.]]]], dtype=float32)

  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                                  padding=[[0, 0], [1, 0], [1, 0], [0, 0]]
  ...                                 ).numpy()
  array([[[[ 0.,  0.],
           [ 3.,  4.],
           [ 6.,  8.]],
          [[ 0.,  0.],
           [10., 14.],
           [14., 20.]],
          [[ 0.,  0.],
           [18., 26.],
           [22., 32.]]]], dtype=float32)

  Args:
    input: 4-D `Tensor`, shaped according to `data_format`.
    filter: 4-D `Tensor` of shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4. The stride of the sliding window along each
      dimension of `input`.
    padding: How the input is padded before the convolution is applied.
      Either the string `"SAME"` or `"VALID"`, naming the padding algorithm,
      or a list of explicit paddings at the start and end of each dimension.
      When explicit padding is used and data_format is `"NHWC"`, this should
      be in the form `[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right],
      [0, 0]]`. When explicit padding is used and data_format is `"NCHW"`,
      this should be in the form `[[0, 0], [0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right]]`.
    rate: 1-D of size 2. The dilation rate used to sample input values along
      the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1. Treated as
      `[1, 1]` when neither `rate` nor `dilations` is given.
    name: A name for this operation (optional).
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: Alias of rate.

  Returns:
    A 4-D `Tensor` shaped according to `data_format`. For example, in
    "NHWC" format the shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier]`.
  """
  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
  with ops.name_scope(name, "depthwise", [input, filter]) as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    filter = ops.convert_to_tensor(filter, name="filter_in")
    rate = [1, 1] if rate is None else rate

    if device_context.enclosing_tpu_context() is None:
      # Outside a TPU context, the dilation rate goes to
      # `nn_ops.with_space_to_batch`. That helper calls back into the native
      # kernel below, and the callback does not pass `dilations` itself.
      def _native_depthwise(input_converted, _, padding):
        # The second argument is unused here. `padding` is whatever
        # `with_space_to_batch` hands in; it shadows the outer argument.
        return nn_ops.depthwise_conv2d_native(
            input=input_converted,
            filter=filter,
            strides=strides,
            padding=padding,
            data_format=data_format,
            name=name)

      return nn_ops.with_space_to_batch(
          input=input,
          filter_shape=array_ops.shape(filter),
          dilation_rate=rate,
          padding=padding,
          data_format=data_format,
          op=_native_depthwise)

    # Inside a TPU context, call the native kernel directly. The 2-element
    # rate is expanded to a 4-element dilation list laid out like
    # `data_format`; None is treated as NHWC.
    if data_format == "NCHW":
      native_dilations = [1, 1, rate[0], rate[1]]
    else:
      native_dilations = [1, rate[0], rate[1], 1]
    return nn_ops.depthwise_conv2d_native(
        input=input,
        filter=filter,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=native_dilations,
        name=name)
@tf_export("nn.depthwise_conv2d", v1=[])
@dispatch.add_dispatch_support
def depthwise_conv2d_v2(input,
                        filter,
                        strides,
                        padding,
                        data_format=None,
                        dilations=None,
                        name=None):
  """Depthwise 2-D convolution.

  Takes a 4-D input (in 'NHWC' or 'NCHW' layout) and a filter of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`, i.e.
  `in_channels` separate filters of depth 1. Each input channel is convolved
  with its own filter, expanding that single channel into
  `channel_multiplier` channels, and the per-channel results are
  concatenated. The output therefore has `in_channels * channel_multiplier`
  channels.

  Written out for the default NHWC layout:

      output[b, i, j, k * channel_multiplier + q] =
          sum_{di, dj} filter[di, dj, k, q] *
                       input[b, strides[1] * i + dilations[0] * di,
                                strides[2] * j + dilations[1] * dj, k]

  `strides[0]` and `strides[3]` must both be 1. Equal horizontal and
  vertical strides are usually written as `strides = [1, stride, stride, 1]`.
  If any entry of `dilations` is greater than 1, the convolution is atrous
  (dilated), and then every entry of `strides` must be 1.

  Usage Example:

  >>> x = np.array([
  ...     [1., 2.],
  ...     [3., 4.],
  ...     [5., 6.]
  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
  >>> kernel = np.array([
  ...     [1., 2.],
  ...     [3., 4]
  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                        padding='VALID').numpy()
  array([[[[10., 14.],
           [14., 20.]],
          [[18., 26.],
           [22., 32.]]]], dtype=float32)

  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                        padding=[[0, 0], [1, 0], [1, 0], [0, 0]]).numpy()
  array([[[[ 0.,  0.],
           [ 3.,  4.],
           [ 6.,  8.]],
          [[ 0.,  0.],
           [10., 14.],
           [14., 20.]],
          [[ 0.,  0.],
           [18., 26.],
           [22., 32.]]]], dtype=float32)

  Args:
    input: 4-D `Tensor`, shaped according to `data_format`.
    filter: 4-D `Tensor` of shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4. The stride of the sliding window along each
      dimension of `input`.
    padding: How the input is padded before the convolution is applied.
      Either the string `"SAME"` or `"VALID"`, naming the padding algorithm,
      or a list of explicit paddings at the start and end of each dimension.
      See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information. When explicit padding is used and data_format is
      `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: 1-D of size 2. The dilation rate used to sample input values
      along the `height` and `width` dimensions in atrous convolution. If it
      is greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` shaped according to `data_format`. For example, in
    "NHWC" format the shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier]`.
  """
  # The TF2 endpoint is a thin forwarder: the v1 implementation takes the
  # dilation rate through its `rate` parameter.
  return depthwise_conv2d(
      input=input,
      filter=filter,
      strides=strides,
      padding=padding,
      rate=dilations,
      data_format=data_format,
      name=name)
# pylint: enable=redefined-builtin
# pylint: disable=redefined-builtin,line-too-long
@tf_export(v1=["nn.separable_conv2d"])